From 5843c41dfcd5400a4b084ba9c26680372ecc2747 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 3 Dec 2018 14:05:50 +0100 Subject: [PATCH] add multi_scale support --- models.py | 3 +-- train.py | 11 +++++++---- utils/datasets.py | 6 +++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/models.py b/models.py index d2c263f2..1d105709 100755 --- a/models.py +++ b/models.py @@ -184,15 +184,14 @@ class YOLOLayer(nn.Module): # plt.hist(self.x) # lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float()) - lconf = (k * 64) * BCEWithLogitsLoss(pred_conf, mask.float()) lcls = (k / 4) * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1)) # lcls = (k * 10) * BCEWithLogitsLoss(pred_cls[mask], tcls.float()) else: lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) - # Add confidence loss for background anchors (noobj) # lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float()) + lconf = (k * 64) * BCEWithLogitsLoss(pred_conf, mask.float()) # Sum loss components balance_losses_flag = False diff --git a/train.py b/train.py index 3bd778e3..84932406 100644 --- a/train.py +++ b/train.py @@ -8,15 +8,18 @@ from utils.utils import * parser = argparse.ArgumentParser() parser.add_argument('-epochs', type=int, default=100, help='number of epochs') -parser.add_argument('-batch_size', type=int, default=16, help='size of each image batch') +parser.add_argument('-batch_size', type=int, default=2, help='size of each image batch') parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='data config file path') parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') -parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each image dimension') +parser.add_argument('-multi_scale', default=True, help='random image sizes per batch 320 - 608') +parser.add_argument('-img_size', type=int, default=32 * 13, help='pixels') parser.add_argument('-resume', default=False, help='resume training flag') parser.add_argument('-batch_report', default=False, help='report TP, FP, FN, P and R per batch (slower)') parser.add_argument('-freeze_darknet53', default=False, help='freeze darknet53.conv.74 layers for first epoch') parser.add_argument('-var', type=float, default=0, help='optional test variable') opt = parser.parse_args() +if opt.multi_scale: # pass maximum multi_scale size + opt.img_size = 608 print(opt) # Import test.py to get mAP after each epoch @@ -50,7 +53,8 @@ def main(opt): model = Darknet(opt.cfg, opt.img_size) # Get dataloader - dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=opt.img_size, augment=True) + dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=opt.img_size, + multi_scale=opt.multi_scale, augment=True) lr0 = 0.001 if opt.resume: @@ -217,4 +221,3 @@ def main(opt): if __name__ == '__main__': torch.cuda.empty_cache() main(opt) - torch.cuda.empty_cache() diff --git a/utils/datasets.py b/utils/datasets.py index a356c106..89f3c27e 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -60,7 +60,7 @@ class load_images(): # for inference class load_images_and_labels(): # for training - def __init__(self, path, batch_size=1, img_size=608, augment=False): + def __init__(self, path, batch_size=1, img_size=608, multi_scale=False, augment=False): self.path = path # self.img_files = sorted(glob.glob('%s/*.*' % path)) with open(path, 'r') as file: @@ -79,6 +79,7 @@ class load_images_and_labels(): # for training self.nB = math.ceil(self.nF / batch_size) # number of batches self.batch_size = batch_size self.height = img_size + self.multi_scale = multi_scale self.augment = augment assert self.nB > 0, 'No images found in path %s' % path @@ -100,8 +101,7 @@ class load_images_and_labels(): # for training ia = self.count * self.batch_size ib = min((self.count + 1) * self.batch_size, self.nF) - multi_scale = False - if multi_scale and self.augment: + if self.multi_scale: # Multi-Scale YOLO Training height = random.choice(range(10, 20)) * 32 # 320 - 608 pixels else: