updates
This commit is contained in:
parent
db515a4535
commit
f18f288990
14
models.py
14
models.py
|
@ -194,12 +194,12 @@ class YOLOLayer(nn.Module):
|
||||||
loss = lx + ly + lw + lh + lconf + lcls
|
loss = lx + ly + lw + lh + lconf + lcls
|
||||||
|
|
||||||
# Sum False Positives from unassigned anchors
|
# Sum False Positives from unassigned anchors
|
||||||
i = torch.sigmoid(pred_conf[~mask]) > 0.5
|
FPe = torch.zeros(self.nC)
|
||||||
if i.sum() > 0:
|
if requestPrecision:
|
||||||
FP_classes = torch.argmax(pred_cls[~mask][i], 1)
|
i = torch.sigmoid(pred_conf[~mask]) > 0.5
|
||||||
FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu() # extra FPs
|
if i.sum() > 0:
|
||||||
else:
|
FP_classes = torch.argmax(pred_cls[~mask][i], 1)
|
||||||
FPe = torch.zeros(self.nC)
|
FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu() # extra FPs
|
||||||
|
|
||||||
return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \
|
return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \
|
||||||
nT, TP, FP, FPe, FN, TC
|
nT, TP, FP, FPe, FN, TC
|
||||||
|
@ -254,7 +254,7 @@ class Darknet(nn.Module):
|
||||||
output.append(x)
|
output.append(x)
|
||||||
layer_outputs.append(x)
|
layer_outputs.append(x)
|
||||||
|
|
||||||
if is_training:
|
if is_training and requestPrecision:
|
||||||
self.losses['nT'] /= 3
|
self.losses['nT'] /= 3
|
||||||
self.losses['TC'] /= 3 # target category
|
self.losses['TC'] /= 3 # target category
|
||||||
metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN
|
metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN
|
||||||
|
|
|
@ -214,7 +214,8 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
|
||||||
if nTb == 0:
|
if nTb == 0:
|
||||||
continue
|
continue
|
||||||
t = target[b]
|
t = target[b]
|
||||||
FN[b, :nTb] = 1
|
if requestPrecision:
|
||||||
|
FN[b, :nTb] = 1
|
||||||
|
|
||||||
# Convert to position relative to box
|
# Convert to position relative to box
|
||||||
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||||
|
|
Loading…
Reference in New Issue