From 94669fb704be871c7c3a4cfb00152cb8dd717f08 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 8 Jul 2019 15:24:20 +0200 Subject: [PATCH] updates --- train.py | 5 ++--- utils/utils.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index bbbfec05..f980d69b 100644 --- a/train.py +++ b/train.py @@ -40,7 +40,6 @@ def train( latest = weights + 'latest.pt' best = weights + 'best.pt' device = torch_utils.select_device() - img_size_test = img_size # image size for testing multi_scale = not opt.single_scale if multi_scale: @@ -140,7 +139,7 @@ def train( if mixed_precision: try: from apex import amp - model, optimizer = amp.initialize(model, optimizer, opt_level='O1') + model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0) except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259 mixed_precision = False @@ -232,7 +231,7 @@ def train( # Calculate mAP (always test final epoch, skip first 5 if opt.nosave) if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1: with torch.no_grad(): - results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=img_size_test, model=model, + results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=opt.img_size, model=model, conf_thres=0.1) # Write epoch results diff --git a/utils/utils.py b/utils/utils.py index 91b81f16..53d82f8f 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -430,7 +430,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): pred = pred[(-pred[:, 4]).argsort()] det_max = [] - nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental) + nms_style = 'SOFT' # 'OR' (default), 'AND', 'MERGE' (experimental) for c in pred[:, -1].unique(): dc = pred[pred[:, -1] == c] # select class c n = len(dc) @@ -486,6 +486,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes dc = dc[1:] dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences + dc = dc[dc[:, 4] > nms_thres] # new line per https://github.com/ultralytics/yolov3/issues/362 if len(det_max): det_max = torch.cat(det_max) # concatenate