diff --git a/train.py b/train.py index da1a07fe..6e461431 100644 --- a/train.py +++ b/train.py @@ -7,6 +7,7 @@ import test # Import test.py to get mAP after each epoch from models import * from utils.datasets import * from utils.utils import * +import torch.distributed as dist def train( @@ -39,11 +40,7 @@ def train( # Optimizer lr0 = 0.001 # initial learning rate - optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9) - - # Dataloader - dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True) - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9,weight_decay = 0.0005) cutoff = -1 # backbone reaches to cutoff layer start_epoch = 0 @@ -62,10 +59,19 @@ def train( cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74') elif cfg.endswith('yolov3-tiny.cfg'): cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15') - + + #initialize for distributed training if torch.cuda.device_count() > 1: - print('WARNING: MultiGPU Issue: https://github.com/ultralytics/yolov3/issues/146') - model = nn.DataParallel(model) + dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=args.rank) + model = torch.nn.parallel.DistributedDataParallel(model) + + # Dataloader + dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True) + if torch.cuda.device_count() > 1: + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + train_sampler=None + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,sampler=train_sampler) # Transfer learning (train only YOLO layers) # for i, (name, p) in enumerate(model.named_parameters()): @@ -172,7 +178,7 @@ def train( # Save latest checkpoint checkpoint = {'epoch': epoch, 'best_loss': best_loss, - 'model': model.module.state_dict() if type(model) is nn.DataParallel else model.state_dict(), + 'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(), 'optimizer': optimizer.state_dict()} torch.save(checkpoint, latest) @@ -185,6 +191,8 @@ def train( os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch)) # Calculate mAP + if type(model) is nn.parallel.DistributedDataParallel: + model = model.module with torch.no_grad(): P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model) @@ -204,6 +212,9 @@ if __name__ == '__main__': parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels') parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers') + parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,help='url used to set up distributed training') + parser.add_argument('--rank', default=-1, type=int,help='node rank for distributed training') + parser.add_argument('--world-size', default=-1, type=int,help='number of nodes for distributed training') opt = parser.parse_args() print(opt, end='\n\n')