updates
This commit is contained in:
parent
a8596c6af4
commit
1cde55f7c9
8
train.py
8
train.py
|
@ -1,10 +1,8 @@
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
import test # import test.py to get mAP after each epoch
|
import test # import test.py to get mAP after each epoch
|
||||||
from models import *
|
from models import *
|
||||||
|
@ -97,7 +95,7 @@ def train(cfg,
|
||||||
|
|
||||||
cutoff = -1 # backbone reaches to cutoff layer
|
cutoff = -1 # backbone reaches to cutoff layer
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
best_fitness = 0.0
|
best_fitness = 0.
|
||||||
if opt.resume or opt.transfer: # Load previously saved model
|
if opt.resume or opt.transfer: # Load previously saved model
|
||||||
if opt.transfer: # Transfer learning
|
if opt.transfer: # Transfer learning
|
||||||
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||||
|
@ -164,7 +162,7 @@ def train(cfg,
|
||||||
|
|
||||||
# Initialize distributed training
|
# Initialize distributed training
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
dist.init_process_group(backend='nccl', # 'distributed backend'
|
torch.distributed.init_process_group(backend='nccl', # 'distributed backend'
|
||||||
init_method='tcp://127.0.0.1:9999', # distributed training init method
|
init_method='tcp://127.0.0.1:9999', # distributed training init method
|
||||||
world_size=1, # number of nodes for distributed training
|
world_size=1, # number of nodes for distributed training
|
||||||
rank=0) # distributed training node rank
|
rank=0) # distributed training node rank
|
||||||
|
@ -173,7 +171,7 @@ def train(cfg,
|
||||||
# sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
# sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
dataloader = DataLoader(dataset,
|
dataloader = torch.utils.data.DataLoader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=opt.num_workers,
|
num_workers=opt.num_workers,
|
||||||
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
||||||
|
|
Loading…
Reference in New Issue