Glenn Jocher 2019-07-24 15:56:10 +02:00
parent a8596c6af4
commit 1cde55f7c9
1 changed file with 11 additions and 13 deletions


```diff
@@ -1,10 +1,8 @@
 import argparse
 import time
 
-import torch.distributed as dist
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
-from torch.utils.data import DataLoader
 
 import test  # import test.py to get mAP after each epoch
 from models import *
```
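The net effect of this hunk is that the script reaches both submodules through the top-level `torch` namespace instead of local aliases. A minimal sketch of the two styles, assuming a PyTorch build where these submodules are importable (the toy `range(10)` dataset is illustrative only, not from this repo):

```python
import torch
import torch.utils.data  # explicit submodule import; `import torch` alone usually suffices in practice

# Aliased style removed by this commit:
#   from torch.utils.data import DataLoader  ->  DataLoader(...)
# Fully qualified style adopted by this commit:
loader = torch.utils.data.DataLoader(range(10), batch_size=4)
print(next(iter(loader)))  # tensor([0, 1, 2, 3])
```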
```diff
@@ -97,7 +95,7 @@ def train(cfg,
     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
-    best_fitness = 0.0
+    best_fitness = 0.
     if opt.resume or opt.transfer:  # Load previously saved model
         if opt.transfer:  # Transfer learning
             nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters'])  # yolo layer size (i.e. 255)
```
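The `nf` lookup in this hunk reads the filter count of the convolution feeding the first YOLO layer. For reference, the "i.e. 255" in the comment follows from the detection-head layout; a sketch of the arithmetic with COCO numbers (assumed here, not stated in the diff):

```python
num_classes = 80       # COCO class count (assumed)
anchors_per_scale = 3  # YOLOv3 predicts with 3 anchors per detection scale
# Each anchor predicts 4 box coordinates + 1 objectness score + num_classes class scores.
nf = anchors_per_scale * (4 + 1 + num_classes)
print(nf)  # 255, matching the 'i.e. 255' comment above
```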
```diff
@@ -164,7 +162,7 @@ def train(cfg,
     # Initialize distributed training
     if torch.cuda.device_count() > 1:
-        dist.init_process_group(backend='nccl',  # 'distributed backend'
-                                init_method='tcp://127.0.0.1:9999',  # distributed training init method
-                                world_size=1,  # number of nodes for distributed training
-                                rank=0)  # distributed training node rank
+        torch.distributed.init_process_group(backend='nccl',  # 'distributed backend'
+                                             init_method='tcp://127.0.0.1:9999',  # distributed training init method
+                                             world_size=1,  # number of nodes for distributed training
+                                             rank=0)  # distributed training node rank
```
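For readers unfamiliar with the call, below is a minimal standalone sketch of the same single-node initialization. The availability guards are added here for safety and are not part of the commit:

```python
import torch
import torch.distributed

# 'nccl' needs CUDA; 'gloo' is the CPU-friendly alternative backend.
if torch.distributed.is_available() and torch.distributed.is_nccl_available():
    torch.distributed.init_process_group(backend='nccl',                      # GPU collective backend
                                         init_method='tcp://127.0.0.1:9999',  # rendezvous address, as in the hunk
                                         world_size=1,                        # total number of processes
                                         rank=0)                              # index of this process
    torch.distributed.destroy_process_group()  # tear the group down when training ends
```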
```diff
@@ -173,7 +171,7 @@ def train(cfg,
     # sampler = torch.utils.data.distributed.DistributedSampler(dataset)
 
     # Dataloader
-    dataloader = DataLoader(dataset,
-                            batch_size=batch_size,
-                            num_workers=opt.num_workers,
-                            shuffle=not opt.rect,  # Shuffle=True unless rectangular training is used
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=batch_size,
+                                             num_workers=opt.num_workers,
+                                             shuffle=not opt.rect,  # Shuffle=True unless rectangular training is used
```
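As a usage note, the fully qualified constructor behaves exactly like the aliased one. A self-contained sketch with a stand-in dataset and stand-ins for `opt.rect` and `opt.num_workers` (none of these values come from this repo):

```python
import torch
import torch.utils.data

images = torch.randn(8, 3, 32, 32)         # stand-in images
labels = torch.zeros(8, dtype=torch.long)  # stand-in labels
dataset = torch.utils.data.TensorDataset(images, labels)

rect = False  # stand-in for opt.rect; rectangular training disables shuffling
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=4,
                                         num_workers=0,  # stand-in for opt.num_workers
                                         shuffle=not rect)
for imgs, targets in dataloader:
    print(imgs.shape, targets.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])
```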