Merge pull request #160 from perry0418/master

fix the multi gpu training bug: zero map
This commit is contained in:
Glenn Jocher 2019-03-25 11:57:40 +01:00 committed by GitHub
commit 16ad6f9739
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 10 deletions

View File

@ -7,6 +7,7 @@ import test # Import test.py to get mAP after each epoch
from models import *
from utils.datasets import *
from utils.utils import *
import torch.distributed as dist
def train(
@ -39,11 +40,7 @@ def train(
# Optimizer
lr0 = 0.001 # initial learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9)
# Dataloader
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9,weight_decay = 0.0005)
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
@ -62,10 +59,19 @@ def train(
cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74')
elif cfg.endswith('yolov3-tiny.cfg'):
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
#initialize for distributed training
if torch.cuda.device_count() > 1:
print('WARNING: MultiGPU Issue: https://github.com/ultralytics/yolov3/issues/146')
model = nn.DataParallel(model)
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
model = torch.nn.parallel.DistributedDataParallel(model)
# Dataloader
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
if torch.cuda.device_count() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
train_sampler=None
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,sampler=train_sampler)
# Transfer learning (train only YOLO layers)
# for i, (name, p) in enumerate(model.named_parameters()):
@ -172,7 +178,7 @@ def train(
# Save latest checkpoint
checkpoint = {'epoch': epoch,
'best_loss': best_loss,
'model': model.module.state_dict() if type(model) is nn.DataParallel else model.state_dict(),
'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest)
@ -185,6 +191,8 @@ def train(
os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch))
# Calculate mAP
if type(model) is nn.parallel.DistributedDataParallel:
model = model.module
with torch.no_grad():
P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model)
@ -204,6 +212,10 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,help='url used to set up distributed training')
parser.add_argument('--rank', default=-1, type=int,help='node rank for distributed training')
parser.add_argument('--world-size', default=-1, type=int,help='number of nodes for distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,help='distributed backend')
opt = parser.parse_args()
print(opt, end='\n\n')

View File

@ -285,7 +285,7 @@ def compute_loss(p, targets): # predictions, targets
def build_targets(model, targets, pred):
# targets = [image, class, x, y, w, h]
if isinstance(model, nn.DataParallel):
if isinstance(model, nn.parallel.DistributedDataParallel):
model = model.module
yolo_layers = get_yolo_layers(model)