This commit is contained in:
Glenn Jocher 2019-04-18 02:13:04 +02:00
parent 5f21139623
commit 9a440cfa15
1 changed files with 2 additions and 1 deletions

View File

@ -280,6 +280,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
def build_targets(model, targets): def build_targets(model, targets):
# targets = [image, class, x, y, w, h] # targets = [image, class, x, y, w, h]
iou_thres = model.hyp['iou_t'] # hyperparameter
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel): if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
model = model.module model = model.module
@ -298,7 +299,7 @@ def build_targets(model, targets):
# reject below threshold ious (OPTIONAL, increases P, lowers R) # reject below threshold ious (OPTIONAL, increases P, lowers R)
reject = True reject = True
if reject: if reject:
j = iou > model.hyp['iou_t'] # hyperparameter j = iou > iou_thres
t, a, gwh = targets[j], a[j], gwh[j] t, a, gwh = targets[j], a[j], gwh[j]
# Indices # Indices