diff --git a/models.py b/models.py index f254b9c8..e851d29c 100755 --- a/models.py +++ b/models.py @@ -194,12 +194,12 @@ class YOLOLayer(nn.Module): loss = lx + ly + lw + lh + lconf + lcls # Sum False Positives from unassigned anchors - i = torch.sigmoid(pred_conf[~mask]) > 0.5 - if i.sum() > 0: - FP_classes = torch.argmax(pred_cls[~mask][i], 1) - FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu() # extra FPs - else: - FPe = torch.zeros(self.nC) + FPe = torch.zeros(self.nC) + if requestPrecision: + i = torch.sigmoid(pred_conf[~mask]) > 0.5 + if i.sum() > 0: + FP_classes = torch.argmax(pred_cls[~mask][i], 1) + FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu() # extra FPs return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \ nT, TP, FP, FPe, FN, TC @@ -254,7 +254,7 @@ class Darknet(nn.Module): output.append(x) layer_outputs.append(x) - if is_training: + if is_training and requestPrecision: self.losses['nT'] /= 3 self.losses['TC'] /= 3 # target category metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN diff --git a/utils/utils.py b/utils/utils.py index b4933d96..51ea5913 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -214,7 +214,8 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG if nTb == 0: continue t = target[b] - FN[b, :nTb] = 1 + if requestPrecision: + FN[b, :nTb] = 1 # Convert to position relative to box TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG