Adam to SGD with burn-in

This commit is contained in:
Glenn Jocher 2018-09-20 18:03:19 +02:00
parent 1cfde4aba8
commit a722601ef6
5 changed files with 48 additions and 35 deletions

View File

@ -18,7 +18,7 @@ parser.add_argument('-txt_out', type=bool, default=False)
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file') parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
parser.add_argument('-conf_thres', type=float, default=0.98, help='object confidence threshold') parser.add_argument('-conf_thres', type=float, default=0.80, help='object confidence threshold')
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression') parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
parser.add_argument('-batch_size', type=int, default=1, help='size of the batches') parser.add_argument('-batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each image dimension') parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each image dimension')
@ -33,7 +33,6 @@ def detect(opt):
# Load model # Load model
model = Darknet(opt.cfg, opt.img_size) model = Darknet(opt.cfg, opt.img_size)
#weights_path = 'checkpoints/yolov3.weights'
weights_path = 'checkpoints/yolov3.pt' weights_path = 'checkpoints/yolov3.pt'
if weights_path.endswith('.weights'): # saved in darknet format if weights_path.endswith('.weights'): # saved in darknet format
load_weights(model, weights_path) load_weights(model, weights_path)

View File

@ -100,7 +100,7 @@ class YOLOLayer(nn.Module):
self.anchor_w = self.scaled_anchors[:, 0:1].view((1, nA, 1, 1)) self.anchor_w = self.scaled_anchors[:, 0:1].view((1, nA, 1, 1))
self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1)) self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1))
def forward(self, p, targets=None, requestPrecision=False, epoch=None): def forward(self, p, targets=None, requestPrecision=False):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0] # batch size bs = p.shape[0] # batch size
@ -117,10 +117,18 @@ class YOLOLayer(nn.Module):
# Get outputs # Get outputs
x = torch.sigmoid(p[..., 0]) # Center x x = torch.sigmoid(p[..., 0]) # Center x
y = torch.sigmoid(p[..., 1]) # Center y y = torch.sigmoid(p[..., 1]) # Center y
w = p[..., 2] # Width
h = p[..., 3] # Height # Width and height (yolo method)
width = torch.exp(w.data) * self.anchor_w # w = p[..., 2] # Width
height = torch.exp(h.data) * self.anchor_h # h = p[..., 3] # Height
# width = torch.exp(w.data) * self.anchor_w
# height = torch.exp(h.data) * self.anchor_h
# Width and height (power method)
w = torch.sigmoid(p[..., 2]) # Width
h = torch.sigmoid(p[..., 3]) # Height
width = ((w.data * 2) ** 2) * self.anchor_w
height = ((h.data * 2) ** 2) * self.anchor_h
# Add offset and scale with anchors (in grid space, i.e. 0-13) # Add offset and scale with anchors (in grid space, i.e. 0-13)
pred_boxes = FT(bs, self.nA, nG, nG, 4) pred_boxes = FT(bs, self.nA, nG, nG, 4)
@ -151,6 +159,7 @@ class YOLOLayer(nn.Module):
# Mask outputs to ignore non-existing objects (but keep confidence predictions) # Mask outputs to ignore non-existing objects (but keep confidence predictions)
nM = mask.sum().float() nM = mask.sum().float()
batch_size = len(targets)
nT = sum([len(x) for x in targets]) nT = sum([len(x) for x in targets])
if nM > 0: if nM > 0:
lx = 5 * MSELoss(x[mask], tx[mask]) lx = 5 * MSELoss(x[mask], tx[mask])
@ -166,7 +175,7 @@ class YOLOLayer(nn.Module):
lconf += 0.5 * nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float()) lconf += 0.5 * nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float())
loss = lx + ly + lw + lh + lconf + lcls loss = (lx + ly + lw + lh + lconf + lcls) / batch_size
# Sum False Positives from unnasigned anchors # Sum False Positives from unnasigned anchors
i = torch.sigmoid(pred_conf[~mask]) > 0.99 i = torch.sigmoid(pred_conf[~mask]) > 0.99
@ -202,7 +211,7 @@ class Darknet(nn.Module):
self.img_size = img_size self.img_size = img_size
self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT', 'TP', 'FP', 'FPe', 'FN', 'TC'] self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT', 'TP', 'FP', 'FPe', 'FN', 'TC']
def forward(self, x, targets=None, requestPrecision=False, epoch=None): def forward(self, x, targets=None, requestPrecision=False):
is_training = targets is not None is_training = targets is not None
output = [] output = []
self.losses = defaultdict(float) self.losses = defaultdict(float)
@ -220,7 +229,7 @@ class Darknet(nn.Module):
elif module_def['type'] == 'yolo': elif module_def['type'] == 'yolo':
# Train phase: get loss # Train phase: get loss
if is_training: if is_training:
x, *losses = module[0](x, targets, requestPrecision, epoch) x, *losses = module[0](x, targets, requestPrecision)
for name, loss in zip(self.loss_names, losses): for name, loss in zip(self.loss_names, losses):
self.losses[name] += loss self.losses[name] += loss
# Test phase: Get detections # Test phase: Get detections

View File

@ -7,7 +7,7 @@ parser = argparse.ArgumentParser()
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch') parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file') parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file') parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
parser.add_argument('-weights_path', type=str, default='checkpoints/yolov3.weights', help='path to weights file') parser.add_argument('-weights_path', type=str, default='checkpoints/yolov3.pt', help='path to weights file')
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file') parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
parser.add_argument('-iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected') parser.add_argument('-iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
parser.add_argument('-conf_thres', type=float, default=0.5, help='object confidence threshold') parser.add_argument('-conf_thres', type=float, default=0.5, help='object confidence threshold')
@ -106,7 +106,6 @@ for batch_i, (imgs, targets) in enumerate(dataloader):
correct.append(0) correct.append(0)
# Compute Average Precision (AP) per class # Compute Average Precision (AP) per class
# target_cls = annotations[:, 0] if annotations.size(0) > 1 else annotations[0]
AP = ap_per_class(tp=correct, conf=detections[:, 4], pred_cls=detections[:, 6], target_cls=target_cls) AP = ap_per_class(tp=correct, conf=detections[:, 4], pred_cls=detections[:, 6], target_cls=target_cls)
# Compute mean AP for this image # Compute mean AP for this image

View File

@ -65,9 +65,8 @@ def main(opt):
# p.requires_grad = False # p.requires_grad = False
# Set optimizer # Set optimizer
# optimizer = torch.optim.SGD(model.parameters(), lr=.001, momentum=.9, weight_decay=5e-4, nesterov=True)
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters())) # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(model.parameters()) optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()))
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1 start_epoch = checkpoint['epoch'] + 1
@ -79,12 +78,12 @@ def main(opt):
print('Using ', torch.cuda.device_count(), ' GPUs') print('Using ', torch.cuda.device_count(), ' GPUs')
model = nn.DataParallel(model) model = nn.DataParallel(model)
model.to(device).train() model.to(device).train()
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=5e-4) # Set optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=.9, weight_decay=5e-4, nesterov=True)
# Set scheduler # Set scheduler
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 24, eta_min=0.00001, last_epoch=-1)
# y = 0.001 * exp(-0.00921 * x) # 1e-4 @ 250, 1e-5 @ 500
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99082, last_epoch=start_epoch - 1) # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99082, last_epoch=start_epoch - 1)
modelinfo(model) modelinfo(model)
@ -94,35 +93,40 @@ def main(opt):
for epoch in range(opt.epochs): for epoch in range(opt.epochs):
epoch += start_epoch epoch += start_epoch
# Multi-Scale Training # Multi-Scale YOLO Training
# img_size = random.choice(range(10, 20)) * 32 # img_size = random.choice(range(10, 20)) * 32 # 320 - 608 pixels
# dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=img_size, augment=True) # dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=img_size, augment=True)
# print('Running this epoch with image size %g' % img_size) # print('Running this epoch with image size %g' % img_size)
# Update scheduler # Update scheduler (automatic)
# if epoch % 25 == 0:
# scheduler.last_epoch = -1 # for cosine annealing, restart every 25 epochs
# scheduler.step() # scheduler.step()
# if epoch <= 100:
# Update scheduler (manual)
# for g in optimizer.param_groups: # for g in optimizer.param_groups:
# g['lr'] = 0.0005 * (0.992 ** epoch) # 1/10 th every 250 epochs # g['lr'] = 1e-3 * (g ** epoch) # 1/10th every [30, 50, 100, 250] epochs using g = [.926, .955, .977, .992]
# g['lr'] = 0.001 * (0.9773 ** epoch) # 1/10 th every 100 epochs
# g['lr'] = 0.0005 * (0.955 ** epoch) # 1/10 th every 50 epochs
# g['lr'] = 0.0005 * (0.926 ** epoch) # 1/10 th every 30 epochs
ui = -1 ui = -1
rloss = defaultdict(float) # running loss rloss = defaultdict(float) # running loss
metrics = torch.zeros(4, num_classes) metrics = torch.zeros(4, num_classes)
for i, (imgs, targets) in enumerate(dataloader): for i, (imgs, targets) in enumerate(dataloader):
if sum([len(x) for x in targets]) < 1: # if no targets continue if sum([len(x) for x in targets]) < 1: # if no targets continue
continue continue
loss = model(imgs.to(device), targets, requestPrecision=True, epoch=epoch) # SGD burn-in
if (epoch == 0) & (i <= 1000):
power = 4
lr = 1e-3 * (i / 1000) ** power
for g in optimizer.param_groups:
g['lr'] = lr
# print('SGD Burn-In LR = %9.5g' % lr, end='')
# Compute loss, compute gradient, update parameters
loss = model(imgs.to(device), targets, requestPrecision=True)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Compute running epoch-means of tracked metrics
ui += 1 ui += 1
metrics += model.losses['metrics'] metrics += model.losses['metrics']
for key, val in model.losses.items(): for key, val in model.losses.items():

View File

@ -262,12 +262,14 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
# Coordinates # Coordinates
tx[b, a, gj, gi] = gx - gi.float() tx[b, a, gj, gi] = gx - gi.float()
ty[b, a, gj, gi] = gy - gj.float() ty[b, a, gj, gi] = gy - gj.float()
# Width and height (sqrt method)
# tw[b, a, gj, gi] = torch.sqrt(gw / anchor_wh[a, 0]) / 2 # Width and height (power method)
# th[b, a, gj, gi] = torch.sqrt(gh / anchor_wh[a, 1]) / 2 tw[b, a, gj, gi] = torch.sqrt(gw / anchor_wh[a, 0]) / 2
th[b, a, gj, gi] = torch.sqrt(gh / anchor_wh[a, 1]) / 2
# Width and height (yolov3 method) # Width and height (yolov3 method)
tw[b, a, gj, gi] = torch.log(gw / anchor_wh[a, 0] + 1e-16) # tw[b, a, gj, gi] = torch.log(gw / anchor_wh[a, 0] + 1e-16)
th[b, a, gj, gi] = torch.log(gh / anchor_wh[a, 1] + 1e-16) # th[b, a, gj, gi] = torch.log(gh / anchor_wh[a, 1] + 1e-16)
# One-hot encoding of label # One-hot encoding of label
tcls[b, a, gj, gi, tc] = 1 tcls[b, a, gj, gi, tc] = 1