multi_label burnin addition
This commit is contained in:
parent
2cc2b2cf0d
commit
002884ae5e
5
test.py
5
test.py
|
@ -19,7 +19,8 @@ def test(cfg,
|
||||||
single_cls=False,
|
single_cls=False,
|
||||||
augment=False,
|
augment=False,
|
||||||
model=None,
|
model=None,
|
||||||
dataloader=None):
|
dataloader=None,
|
||||||
|
multi_label=True):
|
||||||
# Initialize/load model and set device
|
# Initialize/load model and set device
|
||||||
if model is None:
|
if model is None:
|
||||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
||||||
|
@ -95,7 +96,7 @@ def test(cfg,
|
||||||
|
|
||||||
# Run NMS
|
# Run NMS
|
||||||
t = torch_utils.time_synchronized()
|
t = torch_utils.time_synchronized()
|
||||||
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres) # nms
|
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, multi_label=multi_label)
|
||||||
t1 += torch_utils.time_synchronized() - t
|
t1 += torch_utils.time_synchronized() - t
|
||||||
|
|
||||||
# Statistics per image
|
# Statistics per image
|
||||||
|
|
3
train.py
3
train.py
|
@ -314,7 +314,8 @@ def train(hyp):
|
||||||
model=ema.ema,
|
model=ema.ema,
|
||||||
save_json=final_epoch and is_coco,
|
save_json=final_epoch and is_coco,
|
||||||
single_cls=opt.single_cls,
|
single_cls=opt.single_cls,
|
||||||
dataloader=testloader)
|
dataloader=testloader,
|
||||||
|
multi_label=ni > n_burn)
|
||||||
|
|
||||||
# Write
|
# Write
|
||||||
with open(results_file, 'a') as f:
|
with open(results_file, 'a') as f:
|
||||||
|
|
Loading…
Reference in New Issue