This commit is contained in:
Glenn Jocher 2019-04-01 18:42:54 +02:00
parent b56952d707
commit 4f98fbde78
1 changed files with 4 additions and 4 deletions

View File

@ -267,9 +267,9 @@ def compute_loss(p, targets): # predictions, targets
pi = pi0[b, a, gj, gi] # predictions closest to anchors
tconf[b, a, gj, gi] = 1 # conf
lxy += k * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += k * MSE(pi[..., 2:4], twh[i]) # wh loss
lcls += (k / 4) * CE(pi[..., 5:], tcls[i]) # class_conf loss
lxy += (k * 16) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += (k * 8) * MSE(pi[..., 2:4], twh[i]) # wh loss
lcls += (k * 1) * CE(pi[..., 5:], tcls[i]) # class_conf loss
# pos_weight = FT([gp[i] / min(gp) * 4.])
# BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
@ -303,7 +303,7 @@ def build_targets(model, targets):
# reject below threshold ious (OPTIONAL, increases P, lowers R)
reject = True
if reject:
j = iou > 0.01
j = iou > 0.10
t, a, gwh = targets[j], a[j], gwh[j]
else:
t = targets