updates
This commit is contained in:
parent
959c67b4ed
commit
b9d87be318
14
models.py
14
models.py
|
@ -103,7 +103,7 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
self.loss_means = torch.zeros(6)
|
||||
|
||||
def forward(self, p, targets=None, requestPrecision=False):
|
||||
def forward(self, p, targets=None, batch_report=False):
|
||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||
|
||||
bs = p.shape[0] # batch size
|
||||
|
@ -145,7 +145,7 @@ class YOLOLayer(nn.Module):
|
|||
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
|
||||
CrossEntropyLoss = nn.CrossEntropyLoss()
|
||||
|
||||
if requestPrecision:
|
||||
if batch_report:
|
||||
gx = self.grid_x[:, :, :nG, :nG]
|
||||
gy = self.grid_y[:, :, :nG, :nG]
|
||||
pred_boxes[..., 0] = x.data + gx - width / 2
|
||||
|
@ -155,7 +155,7 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
tx, ty, tw, th, mask, tcls, TP, FP, FN, TC = \
|
||||
build_targets(pred_boxes, pred_conf, pred_cls, targets, self.scaled_anchors, self.nA, self.nC, nG,
|
||||
requestPrecision)
|
||||
batch_report)
|
||||
tcls = tcls[mask]
|
||||
if x.is_cuda:
|
||||
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
|
||||
|
@ -195,7 +195,7 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
# Sum False Positives from unassigned anchors
|
||||
FPe = torch.zeros(self.nC)
|
||||
if requestPrecision:
|
||||
if batch_report:
|
||||
i = torch.sigmoid(pred_conf[~mask]) > 0.5
|
||||
if i.sum() > 0:
|
||||
FP_classes = torch.argmax(pred_cls[~mask][i], 1)
|
||||
|
@ -227,7 +227,7 @@ class Darknet(nn.Module):
|
|||
self.img_size = img_size
|
||||
self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT', 'TP', 'FP', 'FPe', 'FN', 'TC']
|
||||
|
||||
def forward(self, x, targets=None, requestPrecision=False):
|
||||
def forward(self, x, targets=None, batch_report=False):
|
||||
is_training = targets is not None
|
||||
output = []
|
||||
self.losses = defaultdict(float)
|
||||
|
@ -245,7 +245,7 @@ class Darknet(nn.Module):
|
|||
elif module_def['type'] == 'yolo':
|
||||
# Train phase: get loss
|
||||
if is_training:
|
||||
x, *losses = module[0](x, targets, requestPrecision)
|
||||
x, *losses = module[0](x, targets, batch_report)
|
||||
for name, loss in zip(self.loss_names, losses):
|
||||
self.losses[name] += loss
|
||||
# Test phase: Get detections
|
||||
|
@ -258,7 +258,7 @@ class Darknet(nn.Module):
|
|||
self.losses['TP'] = 0
|
||||
self.losses['FP'] = 0
|
||||
self.losses['FN'] = 0
|
||||
if is_training and requestPrecision:
|
||||
if is_training and batch_report:
|
||||
self.losses['TC'] /= 3 # target category
|
||||
metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN
|
||||
|
||||
|
|
6
train.py
6
train.py
|
@ -12,6 +12,7 @@ parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help
|
|||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
||||
parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each image dimension')
|
||||
parser.add_argument('-resume', default=False, help='resume training flag')
|
||||
parser.add_argument('-batch_report', default=False, help='report TP, FP, FN, P and R per batch (slower)')
|
||||
opt = parser.parse_args()
|
||||
print(opt)
|
||||
|
||||
|
@ -125,8 +126,7 @@ def main(opt):
|
|||
g['lr'] = lr
|
||||
|
||||
# Compute loss, compute gradient, update parameters
|
||||
precision_per_batch = False
|
||||
loss = model(imgs.to(device), targets, requestPrecision=precision_per_batch)
|
||||
loss = model(imgs.to(device), targets, batch_report=opt.batch_report)
|
||||
loss.backward()
|
||||
|
||||
# accumulated_batches = 1 # accumulate gradient for 4 batches before stepping optimizer
|
||||
|
@ -139,7 +139,7 @@ def main(opt):
|
|||
for key, val in model.losses.items():
|
||||
rloss[key] = (rloss[key] * ui + val) / (ui + 1)
|
||||
|
||||
if precision_per_batch:
|
||||
if opt.batch_report:
|
||||
TP, FP, FN = metrics
|
||||
metrics += model.losses['metrics']
|
||||
|
||||
|
|
|
@ -192,7 +192,7 @@ def bbox_iou(box1, box2, x1y1x2y2=True):
|
|||
return inter_area / (b1_area + b2_area - inter_area + 1e-16)
|
||||
|
||||
|
||||
def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG, requestPrecision):
|
||||
def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG, batch_report):
|
||||
"""
|
||||
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
|
||||
"""
|
||||
|
@ -214,7 +214,7 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
|
|||
if nTb == 0:
|
||||
continue
|
||||
t = target[b]
|
||||
if requestPrecision:
|
||||
if batch_report:
|
||||
FN[b, :nTb] = 1
|
||||
|
||||
# Convert to position relative to box
|
||||
|
@ -273,7 +273,7 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
|
|||
tcls[b, a, gj, gi, tc] = 1
|
||||
tconf[b, a, gj, gi] = 1
|
||||
|
||||
if requestPrecision:
|
||||
if batch_report:
|
||||
# predicted classes and confidence
|
||||
tb = torch.cat((gx - gw / 2, gy - gh / 2, gx + gw / 2, gy + gh / 2)).view(4, -1).t() # target boxes
|
||||
pcls = torch.argmax(pred_cls[b, a, gj, gi], 1).cpu()
|
||||
|
|
Loading…
Reference in New Issue