ONNX export compatability updates
This commit is contained in:
parent
69963ff1f5
commit
c50df0d1db
|
@ -66,7 +66,9 @@ def detect(
|
|||
|
||||
# Get detections
|
||||
with torch.no_grad():
|
||||
pred = model(torch.from_numpy(img).unsqueeze(0).to(device))
|
||||
img = torch.from_numpy(img).unsqueeze(0).to(device)
|
||||
# pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True,); return # ONNX export
|
||||
pred = model(img)
|
||||
pred = pred[pred[:, :, 4] > conf_thres]
|
||||
|
||||
if len(pred) > 0:
|
||||
|
|
50
models.py
50
models.py
|
@ -133,6 +133,8 @@ class YOLOLayer(nn.Module):
|
|||
# Get outputs
|
||||
x = torch.sigmoid(p[..., 0]) # Center x
|
||||
y = torch.sigmoid(p[..., 1]) # Center y
|
||||
p_conf = p[..., 4] # Conf
|
||||
p_cls = p[..., 5:] # Class
|
||||
|
||||
# Width and height (yolo method)
|
||||
w = p[..., 2] # Width
|
||||
|
@ -146,28 +148,25 @@ class YOLOLayer(nn.Module):
|
|||
# width = ((w.data * 2) ** 2) * self.anchor_w
|
||||
# height = ((h.data * 2) ** 2) * self.anchor_h
|
||||
|
||||
# Add offset and scale with anchors (in grid space, i.e. 0-13)
|
||||
pred_boxes = FT(bs, self.nA, nG, nG, 4)
|
||||
pred_conf = p[..., 4] # Conf
|
||||
pred_cls = p[..., 5:] # Class
|
||||
|
||||
# Training
|
||||
if targets is not None:
|
||||
MSELoss = nn.MSELoss()
|
||||
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
|
||||
CrossEntropyLoss = nn.CrossEntropyLoss()
|
||||
|
||||
p_boxes = None
|
||||
if batch_report:
|
||||
# Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13)
|
||||
gx = self.grid_x[:, :, :nG, :nG]
|
||||
gy = self.grid_y[:, :, :nG, :nG]
|
||||
pred_boxes[..., 0] = x.data + gx - width / 2
|
||||
pred_boxes[..., 1] = y.data + gy - height / 2
|
||||
pred_boxes[..., 2] = x.data + gx + width / 2
|
||||
pred_boxes[..., 3] = y.data + gy + height / 2
|
||||
p_boxes = torch.stack((x.data + gx - width / 2,
|
||||
y.data + gy - height / 2,
|
||||
x.data + gx + width / 2,
|
||||
y.data + gy + height / 2), 4) # x1y1x2y2
|
||||
|
||||
tx, ty, tw, th, mask, tcls, TP, FP, FN, TC = \
|
||||
build_targets(pred_boxes, pred_conf, pred_cls, targets, self.scaled_anchors, self.nA, self.nC, nG,
|
||||
batch_report)
|
||||
build_targets(p_boxes, p_conf, p_cls, targets, self.scaled_anchors, self.nA, self.nC, nG, batch_report)
|
||||
|
||||
tcls = tcls[mask]
|
||||
if x.is_cuda:
|
||||
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
|
||||
|
@ -194,15 +193,15 @@ class YOLOLayer(nn.Module):
|
|||
# import matplotlib.pyplot as plt
|
||||
# plt.hist(self.x)
|
||||
|
||||
# lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float())
|
||||
# lconf = k * BCEWithLogitsLoss(p_conf[mask], mask[mask].float())
|
||||
|
||||
lcls = (k / 4) * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
|
||||
# lcls = (k * 10) * BCEWithLogitsLoss(pred_cls[mask], tcls.float())
|
||||
lcls = (k / 4) * CrossEntropyLoss(p_cls[mask], torch.argmax(tcls, 1))
|
||||
# lcls = (k * 10) * BCEWithLogitsLoss(p_cls[mask], tcls.float())
|
||||
else:
|
||||
lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
|
||||
|
||||
# lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float())
|
||||
lconf = (k * 64) * BCEWithLogitsLoss(pred_conf, mask.float())
|
||||
# lconf += k * BCEWithLogitsLoss(p_conf[~mask], mask[~mask].float())
|
||||
lconf = (k * 64) * BCEWithLogitsLoss(p_conf, mask.float())
|
||||
|
||||
# Sum loss components
|
||||
balance_losses_flag = False
|
||||
|
@ -218,24 +217,23 @@ class YOLOLayer(nn.Module):
|
|||
# Sum False Positives from unassigned anchors
|
||||
FPe = torch.zeros(self.nC)
|
||||
if batch_report:
|
||||
i = torch.sigmoid(pred_conf[~mask]) > 0.5
|
||||
i = torch.sigmoid(p_conf[~mask]) > 0.5
|
||||
if i.sum() > 0:
|
||||
FP_classes = torch.argmax(pred_cls[~mask][i], 1)
|
||||
FP_classes = torch.argmax(p_cls[~mask][i], 1)
|
||||
FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu() # extra FPs
|
||||
|
||||
return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \
|
||||
nT, TP, FP, FPe, FN, TC
|
||||
|
||||
else:
|
||||
pred_boxes[..., 0] = x.data + self.grid_x
|
||||
pred_boxes[..., 1] = y.data + self.grid_y
|
||||
pred_boxes[..., 2] = width
|
||||
pred_boxes[..., 3] = height
|
||||
|
||||
# If not in training phase return predictions
|
||||
output = torch.cat((pred_boxes.view(bs, -1, 4) * stride,
|
||||
torch.sigmoid(pred_conf.view(bs, -1, 1)), pred_cls.view(bs, -1, self.nC)), -1)
|
||||
return output.data
|
||||
p_boxes = torch.stack((x + self.grid_x, y + self.grid_y, width, height), 4) # xywh
|
||||
|
||||
# output.shape = [1, 3, 13, 13, 85]
|
||||
output = torch.cat((p_boxes * stride, torch.sigmoid(p_conf).unsqueeze(4), p_cls), 4)
|
||||
|
||||
# returns shape = [1, 507, 85]
|
||||
return output.data.view(bs, -1, 5 + self.nC)
|
||||
|
||||
|
||||
class Darknet(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue