diff --git a/test.py b/test.py index ab122973..5d313efa 100644 --- a/test.py +++ b/test.py @@ -19,7 +19,8 @@ def test(cfg, single_cls=False, augment=False, model=None, - dataloader=None): + dataloader=None, + multi_label=True): # Initialize/load model and set device if model is None: device = torch_utils.select_device(opt.device, batch_size=batch_size) @@ -95,7 +96,7 @@ def test(cfg, # Run NMS t = torch_utils.time_synchronized() - output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres) # nms + output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, multi_label=multi_label) t1 += torch_utils.time_synchronized() - t # Statistics per image diff --git a/train.py b/train.py index 3d4f355e..8ea7ed28 100644 --- a/train.py +++ b/train.py @@ -314,7 +314,8 @@ def train(hyp): model=ema.ema, save_json=final_epoch and is_coco, single_cls=opt.single_cls, - dataloader=testloader) + dataloader=testloader, + multi_label=ni > n_burn) # Write with open(results_file, 'a') as f: