updates
This commit is contained in:
parent
291c3ec9c7
commit
94669fb704
5
train.py
5
train.py
|
@ -40,7 +40,6 @@ def train(
|
||||||
latest = weights + 'latest.pt'
|
latest = weights + 'latest.pt'
|
||||||
best = weights + 'best.pt'
|
best = weights + 'best.pt'
|
||||||
device = torch_utils.select_device()
|
device = torch_utils.select_device()
|
||||||
img_size_test = img_size # image size for testing
|
|
||||||
multi_scale = not opt.single_scale
|
multi_scale = not opt.single_scale
|
||||||
|
|
||||||
if multi_scale:
|
if multi_scale:
|
||||||
|
@ -140,7 +139,7 @@ def train(
|
||||||
if mixed_precision:
|
if mixed_precision:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
|
||||||
except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259
|
except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259
|
||||||
mixed_precision = False
|
mixed_precision = False
|
||||||
|
|
||||||
|
@ -232,7 +231,7 @@ def train(
|
||||||
# Calculate mAP (always test final epoch, skip first 5 if opt.nosave)
|
# Calculate mAP (always test final epoch, skip first 5 if opt.nosave)
|
||||||
if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1:
|
if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=img_size_test, model=model,
|
results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=opt.img_size, model=model,
|
||||||
conf_thres=0.1)
|
conf_thres=0.1)
|
||||||
|
|
||||||
# Write epoch results
|
# Write epoch results
|
||||||
|
|
|
@ -430,7 +430,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
pred = pred[(-pred[:, 4]).argsort()]
|
pred = pred[(-pred[:, 4]).argsort()]
|
||||||
|
|
||||||
det_max = []
|
det_max = []
|
||||||
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
nms_style = 'SOFT' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||||
for c in pred[:, -1].unique():
|
for c in pred[:, -1].unique():
|
||||||
dc = pred[pred[:, -1] == c] # select class c
|
dc = pred[pred[:, -1] == c] # select class c
|
||||||
n = len(dc)
|
n = len(dc)
|
||||||
|
@ -486,6 +486,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||||
dc = dc[1:]
|
dc = dc[1:]
|
||||||
dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences
|
dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences
|
||||||
|
dc = dc[dc[:, 4] > nms_thres] # new line per https://github.com/ultralytics/yolov3/issues/362
|
||||||
|
|
||||||
if len(det_max):
|
if len(det_max):
|
||||||
det_max = torch.cat(det_max) # concatenate
|
det_max = torch.cat(det_max) # concatenate
|
||||||
|
|
Loading…
Reference in New Issue