This commit is contained in:
Glenn Jocher 2018-11-22 14:54:52 +01:00
parent 959c67b4ed
commit b9d87be318
3 changed files with 13 additions and 13 deletions

View File

@ -103,7 +103,7 @@ class YOLOLayer(nn.Module):
self.loss_means = torch.zeros(6)
def forward(self, p, targets=None, requestPrecision=False):
def forward(self, p, targets=None, batch_report=False):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0] # batch size
@ -145,7 +145,7 @@ class YOLOLayer(nn.Module):
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
CrossEntropyLoss = nn.CrossEntropyLoss()
if requestPrecision:
if batch_report:
gx = self.grid_x[:, :, :nG, :nG]
gy = self.grid_y[:, :, :nG, :nG]
pred_boxes[..., 0] = x.data + gx - width / 2
@ -155,7 +155,7 @@ class YOLOLayer(nn.Module):
tx, ty, tw, th, mask, tcls, TP, FP, FN, TC = \
build_targets(pred_boxes, pred_conf, pred_cls, targets, self.scaled_anchors, self.nA, self.nC, nG,
requestPrecision)
batch_report)
tcls = tcls[mask]
if x.is_cuda:
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
@ -195,7 +195,7 @@ class YOLOLayer(nn.Module):
# Sum False Positives from unassigned anchors
FPe = torch.zeros(self.nC)
if requestPrecision:
if batch_report:
i = torch.sigmoid(pred_conf[~mask]) > 0.5
if i.sum() > 0:
FP_classes = torch.argmax(pred_cls[~mask][i], 1)
@ -227,7 +227,7 @@ class Darknet(nn.Module):
self.img_size = img_size
self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT', 'TP', 'FP', 'FPe', 'FN', 'TC']
def forward(self, x, targets=None, requestPrecision=False):
def forward(self, x, targets=None, batch_report=False):
is_training = targets is not None
output = []
self.losses = defaultdict(float)
@ -245,7 +245,7 @@ class Darknet(nn.Module):
elif module_def['type'] == 'yolo':
# Train phase: get loss
if is_training:
x, *losses = module[0](x, targets, requestPrecision)
x, *losses = module[0](x, targets, batch_report)
for name, loss in zip(self.loss_names, losses):
self.losses[name] += loss
# Test phase: Get detections
@ -258,7 +258,7 @@ class Darknet(nn.Module):
self.losses['TP'] = 0
self.losses['FP'] = 0
self.losses['FN'] = 0
if is_training and requestPrecision:
if is_training and batch_report:
self.losses['TC'] /= 3 # target category
metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN

View File

@ -12,6 +12,7 @@ parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each image dimension')
parser.add_argument('-resume', default=False, help='resume training flag')
parser.add_argument('-batch_report', default=False, help='report TP, FP, FN, P and R per batch (slower)')
opt = parser.parse_args()
print(opt)
@ -125,8 +126,7 @@ def main(opt):
g['lr'] = lr
# Compute loss, compute gradient, update parameters
precision_per_batch = False
loss = model(imgs.to(device), targets, requestPrecision=precision_per_batch)
loss = model(imgs.to(device), targets, batch_report=opt.batch_report)
loss.backward()
# accumulated_batches = 1 # accumulate gradient for 4 batches before stepping optimizer
@ -139,7 +139,7 @@ def main(opt):
for key, val in model.losses.items():
rloss[key] = (rloss[key] * ui + val) / (ui + 1)
if precision_per_batch:
if opt.batch_report:
TP, FP, FN = metrics
metrics += model.losses['metrics']

View File

@ -192,7 +192,7 @@ def bbox_iou(box1, box2, x1y1x2y2=True):
return inter_area / (b1_area + b2_area - inter_area + 1e-16)
def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG, requestPrecision):
def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG, batch_report):
"""
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
"""
@ -214,7 +214,7 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
if nTb == 0:
continue
t = target[b]
if requestPrecision:
if batch_report:
FN[b, :nTb] = 1
# Convert to position relative to box
@ -273,7 +273,7 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
tcls[b, a, gj, gi, tc] = 1
tconf[b, a, gj, gi] = 1
if requestPrecision:
if batch_report:
# predicted classes and confidence
tb = torch.cat((gx - gw / 2, gy - gh / 2, gx + gw / 2, gy + gh / 2)).view(4, -1).t() # target boxes
pcls = torch.argmax(pred_cls[b, a, gj, gi], 1).cpu()