This commit is contained in:
Glenn Jocher 2019-04-02 18:04:04 +02:00
parent d526ce0d11
commit 1457f66419
5 changed files with 48 additions and 45 deletions

View File

@ -1,6 +0,0 @@
classes=80
train=../coco/coco_1cls.txt
valid=../coco/coco_1cls.txt
names=data/coco.names
backup=backup/
eval=coco

View File

@ -1,6 +0,0 @@
classes=80
train=../coco/coco_1img.txt
valid=../coco/coco_1img.txt
names=data/coco.names
backup=backup/
eval=coco

View File

@ -20,7 +20,9 @@ def train(
accumulate=1,
multi_scale=False,
freeze_backbone=False,
num_workers=4
num_workers=4,
transfer=False # Transfer learning (train only YOLO layers)
):
weights = 'weights' + os.sep
latest = weights + 'latest.pt'
@ -46,14 +48,26 @@ def train(
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
best_loss = float('inf')
yl = get_yolo_layers(model) # yolo layers
nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255)
if resume: # Load previously saved PyTorch model
checkpoint = torch.load(latest, map_location=device) # load checkpoint
model.load_state_dict(checkpoint['model'])
start_epoch = checkpoint['epoch'] + 1
if checkpoint['optimizer'] is not None:
optimizer.load_state_dict(checkpoint['optimizer'])
best_loss = checkpoint['best_loss']
del checkpoint
if transfer: # Transfer learning
chkpt = torch.load(weights + 'yolov3-tiny.pt', map_location=device)
model.load_state_dict(
{k: v for k, v in chkpt['model'].items() if (int(k.split('.')[1]) + 1) not in yl}, strict=False)
for (name, p) in model.named_parameters():
p.requires_grad = True if p.shape[0] == nf else False
else: # resume from latest.pt
chkpt = torch.load(latest, map_location=device) # load checkpoint
model.load_state_dict(chkpt['model'])
start_epoch = chkpt['epoch'] + 1
if chkpt['optimizer'] is not None:
optimizer.load_state_dict(chkpt['optimizer'])
best_loss = chkpt['best_loss']
del chkpt
else: # Initialize model with backbone (optional)
if cfg.endswith('yolov3.cfg'):
@ -61,10 +75,6 @@ def train(
elif cfg.endswith('yolov3-tiny.cfg'):
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
# Transfer learning (train only YOLO layers)
# for (name, p) in model.named_parameters():
# p.requires_grad = True if p.shape[0] == 255 else False
# Set scheduler (reduce lr at epoch 250)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1, last_epoch=start_epoch - 1)
@ -174,22 +184,22 @@ def train(
save = True
if save:
# Save latest checkpoint
checkpoint = {'epoch': epoch,
chkpt = {'epoch': epoch,
'best_loss': best_loss,
'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest)
torch.save(chkpt, latest)
# Save best checkpoint
if best_loss == mloss['total']:
torch.save(checkpoint, best)
torch.save(chkpt, best)
# Save backup weights every 10 epochs (optional)
# Save backup every 10 epochs (optional)
if epoch > 0 and epoch % 10 == 0:
torch.save(checkpoint, weights + 'backup%g.pt' % epoch)
torch.save(chkpt, weights + 'backup%g.pt' % epoch)
del checkpoint
del chkpt
# Calculate mAP
with torch.no_grad():
@ -210,6 +220,7 @@ if __name__ == '__main__':
parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608')
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--transfer', action='store_true', help='transfer learning flag')
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method')
parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
@ -224,7 +235,8 @@ if __name__ == '__main__':
opt.cfg,
opt.data_cfg,
img_size=opt.img_size,
resume=opt.resume,
resume=True or opt.resume or opt.transfer,
transfer=True or opt.transfer,
epochs=opt.epochs,
batch_size=opt.batch_size,
accumulate=opt.accumulate,

View File

@ -45,11 +45,14 @@ wget https://storage.googleapis.com/ultralytics/yolov3/best_v1_0.pt -O weights/b
# Debug/Development
sudo rm -rf yolov3
# git clone https://github.com/ultralytics/yolov3 # master
git clone -b map_update --depth 1 https://github.com/ultralytics/yolov3 yolov3 # branch
git clone https://github.com/ultralytics/yolov3 # master
# git clone -b hyperparameter_search --depth 1 https://github.com/ultralytics/yolov3 hyperparameter_search # branch
cp -r weights yolov3
cp -r cocoapi/PythonAPI/pycocotools yolov3
cd yolov3
#git pull https://github.com/ultralytics/yolov3 map_update # branch
python3 test.py --img-size 320
git pull https://github.com/ultralytics/yolov3 #hyperparameter_search # branch
python3 train.py --data-cfg data/coco_1cls.data
python3 train.py --data-cfg data/coco_1img.data

View File

@ -295,12 +295,11 @@ def build_targets(model, targets):
txy, twh, tcls, indices = [], [], [], []
for i, layer in enumerate(get_yolo_layers(model)):
nG = model.module_list[layer][0].nG # grid size
anchor_vec = model.module_list[layer][0].anchor_vec
layer = model.module_list[layer][0]
# iou of targets-anchors
gwh = targets[:, 4:6] * nG
iou = [wh_iou(x, gwh) for x in anchor_vec]
gwh = targets[:, 4:6] * layer.nG
iou = [wh_iou(x, gwh) for x in layer.anchor_vec]
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor
# reject below threshold ious (OPTIONAL, increases P, lowers R)
@ -313,7 +312,7 @@ def build_targets(model, targets):
# Indices
b, c = t[:, :2].long().t() # target image, class
gxy = t[:, 2:4] * nG
gxy = t[:, 2:4] * layer.nG
gi, gj = gxy.long().t() # grid_i, grid_j
indices.append((b, a, gj, gi))
@ -321,11 +320,12 @@ def build_targets(model, targets):
txy.append(gxy - gxy.floor())
# Width and height
twh.append(torch.log(gwh / anchor_vec[a])) # yolo method
twh.append(torch.log(gwh / layer.anchor_vec[a])) # yolo method
# twh.append(torch.sqrt(gwh / anchor_vec[a]) / 2) # power method
# Class
tcls.append(c)
assert c.max() <= layer.nC, 'Target classes exceed model classes'
return txy, twh, tcls, indices