nGT to nT
This commit is contained in:
parent
29fbcb059f
commit
1cfde4aba8
10
models.py
10
models.py
|
@ -2,8 +2,8 @@ from collections import defaultdict
|
|||
|
||||
import torch.nn as nn
|
||||
|
||||
from utils.utils import *
|
||||
from utils.parse_config import *
|
||||
from utils.utils import *
|
||||
|
||||
|
||||
def create_modules(module_defs):
|
||||
|
@ -151,7 +151,7 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
# Mask outputs to ignore non-existing objects (but keep confidence predictions)
|
||||
nM = mask.sum().float()
|
||||
nGT = sum([len(x) for x in targets])
|
||||
nT = sum([len(x) for x in targets])
|
||||
if nM > 0:
|
||||
lx = 5 * MSELoss(x[mask], tx[mask])
|
||||
ly = 5 * MSELoss(y[mask], ty[mask])
|
||||
|
@ -177,7 +177,7 @@ class YOLOLayer(nn.Module):
|
|||
FPe[c] += 1
|
||||
|
||||
return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \
|
||||
nGT, TP, FP, FPe, FN, TC
|
||||
nT, TP, FP, FPe, FN, TC
|
||||
|
||||
else:
|
||||
pred_boxes[..., 0] = x.data + self.grid_x
|
||||
|
@ -200,7 +200,7 @@ class Darknet(nn.Module):
|
|||
self.module_defs[0]['height'] = img_size
|
||||
self.hyperparams, self.module_list = create_modules(self.module_defs)
|
||||
self.img_size = img_size
|
||||
self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nGT', 'TP', 'FP', 'FPe', 'FN', 'TC']
|
||||
self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT', 'TP', 'FP', 'FPe', 'FN', 'TC']
|
||||
|
||||
def forward(self, x, targets=None, requestPrecision=False, epoch=None):
|
||||
is_training = targets is not None
|
||||
|
@ -230,7 +230,7 @@ class Darknet(nn.Module):
|
|||
layer_outputs.append(x)
|
||||
|
||||
if is_training:
|
||||
self.losses['nGT'] /= 3
|
||||
self.losses['nT'] /= 3
|
||||
self.losses['TC'] /= 3
|
||||
metrics = torch.zeros(4, len(self.losses['FPe'])) # TP, FP, FN, target_count
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ def bbox_iou(box1, box2, x1y1x2y2=True):
|
|||
|
||||
def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG, requestPrecision):
|
||||
"""
|
||||
returns nGT, nCorrect, tx, ty, tw, th, tconf, tcls
|
||||
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
|
||||
"""
|
||||
nB = len(target) # target.shape[0]
|
||||
nT = [len(x) for x in target] # torch.argmin(target[:, :, 4], 1) # targets per image
|
||||
|
|
Loading…
Reference in New Issue