This commit is contained in:
Glenn Jocher 2019-08-05 17:45:32 +02:00
parent 1613d1c396
commit 9a9224cfe6
1 changed files with 2 additions and 3 deletions

View File

@ -333,11 +333,10 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
def build_targets(model, targets):
# targets = [image, class, x, y, w, h]
iou_thres = model.hyp['iou_t'] # hyperparameter
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
nt = len(targets)
txy, twh, tcls, tbox, indices, av = [], [], [], [], [], []
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
for i in model.yolo_layers:
# get number of grid points and anchor vec for this yolo layer
if multi_gpu:
@ -364,7 +363,7 @@ def build_targets(model, targets):
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R)
reject = True
if reject:
j = iou > iou_thres
j = iou > model.hyp['iou_t'] # iou threshold hyperparameter
t, a, gwh = t[j], a[j], gwh[j]
# Indices