parent 648ed20717
commit 386835d7ca

train.py (27 changed lines)
@@ -7,6 +7,7 @@ import test  # Import test.py to get mAP after each epoch
 from models import *
 from utils.datasets import *
 from utils.utils import *
+import torch.distributed as dist


 def train(
@@ -39,11 +40,7 @@ def train(

     # Optimizer
     lr0 = 0.001  # initial learning rate
-    optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9)
-
-    # Dataloader
-    dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9, weight_decay=0.0005)

     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
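The only functional change to the optimizer is weight_decay=0.0005, which turns on L2 regularization: PyTorch's SGD adds weight_decay * param to each parameter's gradient before the momentum update. The Dataloader block is not deleted outright; it reappears in the next hunk, after the distributed setup, so the sampler can be chosen first. A minimal standalone sketch of the new optimizer construction (the toy parameter below is an assumption, not part of the commit):

    import torch

    # Toy illustration of the commit's optimizer settings: with weight_decay,
    # SGD computes grad <- grad + weight_decay * param before the momentum step.
    params = [torch.nn.Parameter(torch.randn(10))]
    optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)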
@@ -63,9 +60,18 @@ def train(
     elif cfg.endswith('yolov3-tiny.cfg'):
         cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')

+    # Initialize for distributed training
     if torch.cuda.device_count() > 1:
-        print('WARNING: MultiGPU Issue: https://github.com/ultralytics/yolov3/issues/146')
-        model = nn.DataParallel(model)
+        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank)
+        model = torch.nn.parallel.DistributedDataParallel(model)

+    # Dataloader
+    dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
+    if torch.cuda.device_count() > 1:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+    else:
+        train_sampler = None
+    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler)
+
     # Transfer learning (train only YOLO layers)
     # for i, (name, p) in enumerate(model.named_parameters()):
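This hunk swaps nn.DataParallel (and its linked multi-GPU warning) for DistributedDataParallel, and rebuilds the dataloader with a DistributedSampler so that each process trains on a disjoint shard of the dataset. A self-contained sketch of the same pattern, assuming the env:// init method and a stand-in dataset (both assumptions, not from the commit):

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    def make_distributed_loader(rank, world_size, batch_size):
        # env:// reads MASTER_ADDR/MASTER_PORT from the environment
        # (the commit passes a tcp:// URL instead).
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=world_size, rank=rank)
        dataset = TensorDataset(torch.randn(100, 3, 416, 416))  # stand-in for LoadImagesAndLabels
        # The sampler gives each rank a disjoint shard and handles shuffling,
        # so the DataLoader itself must not shuffle.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

In multi-epoch training one would normally also call sampler.set_epoch(epoch) at the top of each epoch so the shuffle order changes between epochs; the commit does not, so every epoch sees the same order.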
@@ -172,7 +178,7 @@ def train(
         # Save latest checkpoint
         checkpoint = {'epoch': epoch,
                       'best_loss': best_loss,
-                      'model': model.module.state_dict() if type(model) is nn.DataParallel else model.state_dict(),
+                      'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(),
                       'optimizer': optimizer.state_dict()}
         torch.save(checkpoint, latest)

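Storing model.module.state_dict() when the model is wrapped keeps the checkpoint free of the 'module.' key prefix, so it loads into a bare model either way. A hedged loading sketch (the Darknet constructor and file path are assumptions based on the surrounding repository):

    import torch

    checkpoint = torch.load('weights/latest.pt', map_location='cpu')
    model = Darknet(cfg, img_size)  # assumed model class from models.py
    model.load_state_dict(checkpoint['model'])  # keys carry no 'module.' prefix
    start_epoch = checkpoint['epoch'] + 1

One caveat: type(model) is nn.parallel.DistributedDataParallel is an exact-type check; isinstance(...) would also cover subclasses and reads more idiomatically.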
@@ -185,6 +191,8 @@ def train(
             os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch))

         # Calculate mAP
+        if type(model) is nn.parallel.DistributedDataParallel:
+            model = model.module
         with torch.no_grad():
             P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model)

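Unwrapping to model.module hands test.test the plain underlying module. Because this rebinds the name model, later epochs continue on the unwrapped module; a non-destructive variant (hypothetical, not what the commit does) keeps the wrapper for training:

    # Evaluate the underlying module without discarding the DDP wrapper.
    eval_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    with torch.no_grad():
        P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size,
                              img_size=img_size, model=eval_model)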
@@ -204,6 +212,9 @@ if __name__ == '__main__':
     parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
     parser.add_argument('--resume', action='store_true', help='resume training flag')
     parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
+    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training')
+    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
+    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
     opt = parser.parse_args()
     print(opt, end='\n\n')

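Note that the init_process_group call earlier reads opt.dist_backend, which none of the three new arguments defines; a matching flag (a guessed default, not part of the commit) would look like:

    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend, e.g. nccl or gloo')

Each node would then be launched with its own rank, for example (hypothetical two-node run, master address elided):

    python train.py --world-size 2 --rank 0 --dist-url tcp://<master-ip>:23456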