Update train.py

Solve the multi-GPU training problem by replacing nn.DataParallel with torch.nn.parallel.DistributedDataParallel and sharding the dataloader with a DistributedSampler.
perry0418 2019-03-25 14:56:38 +08:00 committed by GitHub
parent 648ed20717
commit 386835d7ca
1 changed file with 20 additions and 9 deletions
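
For context, the diff below follows the standard torch.distributed recipe: initialize a process group, wrap the model in DistributedDataParallel, and shard the dataset across ranks with a DistributedSampler. Here is a minimal, self-contained sketch of that pattern; the gloo backend, loopback address, toy model, toy dataset and SGD settings are illustrative placeholders, not values taken from train.py.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    # Single process on one node: rank 0 in a world of size 1 (illustrative values).
    dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:23456', world_size=1, rank=0)

    model = nn.Linear(10, 2)  # stand-in for the Darknet model built in train.py
    model = nn.parallel.DistributedDataParallel(model)

    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))  # toy data
    sampler = DistributedSampler(dataset)  # gives each rank its own shard of the dataset
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    for x, y in loader:
        loss = nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()  # DDP averages gradients across ranks during backward
        optimizer.step()

    dist.destroy_process_group()

if __name__ == '__main__':
    main()

In the commit itself, the rank, world size, init URL and backend come from command-line arguments, and the DistributedSampler/DistributedDataParallel path is only taken when more than one GPU is visible.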


@@ -7,6 +7,7 @@ import test # Import test.py to get mAP after each epoch
from models import *
from utils.datasets import *
from utils.utils import *
import torch.distributed as dist
def train(
@@ -39,11 +40,7 @@ def train(
# Optimizer
lr0 = 0.001 # initial learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9)
# Dataloader
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9, weight_decay=0.0005)
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
@@ -63,9 +60,18 @@ def train(
elif cfg.endswith('yolov3-tiny.cfg'):
    cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
# Initialize for distributed training
if torch.cuda.device_count() > 1:
    print('WARNING: MultiGPU Issue: https://github.com/ultralytics/yolov3/issues/146')
    model = nn.DataParallel(model)
    dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank)
    model = torch.nn.parallel.DistributedDataParallel(model)
# Dataloader
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
if torch.cuda.device_count() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
    train_sampler = None
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler)
# Transfer learning (train only YOLO layers)
# for i, (name, p) in enumerate(model.named_parameters()):
@@ -172,7 +178,7 @@ def train(
# Save latest checkpoint
checkpoint = {'epoch': epoch,
'best_loss': best_loss,
'model': model.module.state_dict() if type(model) is nn.DataParallel else model.state_dict(),
'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest)
@@ -185,6 +191,8 @@ def train(
os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch))
# Calculate mAP
if type(model) is nn.parallel.DistributedDataParallel:
    model = model.module
with torch.no_grad():
    P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model)
@@ -204,6 +212,10 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='URL used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')  # needed by opt.dist_backend above
parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
opt = parser.parse_args()
print(opt, end='\n\n')
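
With the new arguments in place, a single-node run could be launched along these lines (the loopback address and port are illustrative; the remaining training options keep their defaults):

python train.py --rank 0 --world-size 1 --dist-url tcp://127.0.0.1:23456

The -1 defaults for --rank and --world-size are placeholders, so distributed runs need explicit values; the multicast address in the --dist-url default likewise usually needs to be replaced with an address the participating nodes can actually reach.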