parent
648ed20717
commit
386835d7ca
29
train.py
29
train.py
|
@ -7,6 +7,7 @@ import test # Import test.py to get mAP after each epoch
|
|||
from models import *
|
||||
from utils.datasets import *
|
||||
from utils.utils import *
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def train(
|
||||
|
@ -39,11 +40,7 @@ def train(
|
|||
|
||||
# Optimizer
|
||||
lr0 = 0.001 # initial learning rate
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9)
|
||||
|
||||
# Dataloader
|
||||
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9,weight_decay = 0.0005)
|
||||
|
||||
cutoff = -1 # backbone reaches to cutoff layer
|
||||
start_epoch = 0
|
||||
|
@ -62,10 +59,19 @@ def train(
|
|||
cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74')
|
||||
elif cfg.endswith('yolov3-tiny.cfg'):
|
||||
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
||||
|
||||
|
||||
#initialize for distributed training
|
||||
if torch.cuda.device_count() > 1:
|
||||
print('WARNING: MultiGPU Issue: https://github.com/ultralytics/yolov3/issues/146')
|
||||
model = nn.DataParallel(model)
|
||||
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=args.rank)
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
|
||||
# Dataloader
|
||||
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
|
||||
if torch.cuda.device_count() > 1:
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
else:
|
||||
train_sampler=None
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,sampler=train_sampler)
|
||||
|
||||
# Transfer learning (train only YOLO layers)
|
||||
# for i, (name, p) in enumerate(model.named_parameters()):
|
||||
|
@ -172,7 +178,7 @@ def train(
|
|||
# Save latest checkpoint
|
||||
checkpoint = {'epoch': epoch,
|
||||
'best_loss': best_loss,
|
||||
'model': model.module.state_dict() if type(model) is nn.DataParallel else model.state_dict(),
|
||||
'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||
'optimizer': optimizer.state_dict()}
|
||||
torch.save(checkpoint, latest)
|
||||
|
||||
|
@ -185,6 +191,8 @@ def train(
|
|||
os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch))
|
||||
|
||||
# Calculate mAP
|
||||
if type(model) is nn.parallel.DistributedDataParallel:
|
||||
model = model.module
|
||||
with torch.no_grad():
|
||||
P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model)
|
||||
|
||||
|
@ -204,6 +212,9 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
|
||||
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
||||
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
|
||||
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,help='url used to set up distributed training')
|
||||
parser.add_argument('--rank', default=-1, type=int,help='node rank for distributed training')
|
||||
parser.add_argument('--world-size', default=-1, type=int,help='number of nodes for distributed training')
|
||||
opt = parser.parse_args()
|
||||
print(opt, end='\n\n')
|
||||
|
||||
|
|
Loading…
Reference in New Issue