updates
This commit is contained in:
parent
a8596c6af4
commit
1cde55f7c9
24
train.py
24
train.py
|
@ -1,10 +1,8 @@
|
|||
import argparse
|
||||
import time
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.optim as optim
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import test # import test.py to get mAP after each epoch
|
||||
from models import *
|
||||
|
@ -97,7 +95,7 @@ def train(cfg,
|
|||
|
||||
cutoff = -1 # backbone reaches to cutoff layer
|
||||
start_epoch = 0
|
||||
best_fitness = 0.0
|
||||
best_fitness = 0.
|
||||
if opt.resume or opt.transfer: # Load previously saved model
|
||||
if opt.transfer: # Transfer learning
|
||||
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||
|
@ -164,21 +162,21 @@ def train(cfg,
|
|||
|
||||
# Initialize distributed training
|
||||
if torch.cuda.device_count() > 1:
|
||||
dist.init_process_group(backend='nccl', # 'distributed backend'
|
||||
init_method='tcp://127.0.0.1:9999', # distributed training init method
|
||||
world_size=1, # number of nodes for distributed training
|
||||
rank=0) # distributed training node rank
|
||||
torch.distributed.init_process_group(backend='nccl', # 'distributed backend'
|
||||
init_method='tcp://127.0.0.1:9999', # distributed training init method
|
||||
world_size=1, # number of nodes for distributed training
|
||||
rank=0) # distributed training node rank
|
||||
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
# sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
|
||||
# Dataloader
|
||||
dataloader = DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=opt.num_workers,
|
||||
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
||||
pin_memory=True,
|
||||
collate_fn=dataset.collate_fn)
|
||||
dataloader = torch.utils.data.DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=opt.num_workers,
|
||||
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
||||
pin_memory=True,
|
||||
collate_fn=dataset.collate_fn)
|
||||
|
||||
# Mixed precision training https://github.com/NVIDIA/apex
|
||||
mixed_precision = True
|
||||
|
|
Loading…
Reference in New Issue