ONNX compatibility updates
This commit is contained in:
parent
febc55d96a
commit
647e1c6f52
30
models.py
30
models.py
|
@ -127,13 +127,12 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
self.loss_means = torch.ones(6)
|
self.loss_means = torch.ones(6)
|
||||||
self.tx, self.ty, self.tw, self.th = [], [], [], []
|
self.tx, self.ty, self.tw, self.th = [], [], [], []
|
||||||
|
self.yolo_layer = anchor_idxs[0] / nA # 2, 1, 0
|
||||||
|
|
||||||
def forward(self, p, targets=None, batch_report=False, var=None):
|
def forward(self, p, targets=None, batch_report=False, var=None):
|
||||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||||
|
|
||||||
bs = p.shape[0] # batch size
|
bs = p.shape[0] # batch size
|
||||||
nG = p.shape[2] # number of grid points
|
nG = p.shape[2] # number of grid points
|
||||||
stride = self.img_dim / nG
|
|
||||||
|
|
||||||
if p.is_cuda and not self.grid_x.is_cuda:
|
if p.is_cuda and not self.grid_x.is_cuda:
|
||||||
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
||||||
|
@ -143,6 +142,12 @@ class YOLOLayer(nn.Module):
|
||||||
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
|
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
|
||||||
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if targets is not None:
|
||||||
|
MSELoss = nn.MSELoss()
|
||||||
|
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
|
||||||
|
CrossEntropyLoss = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
# Get outputs
|
# Get outputs
|
||||||
x = torch.sigmoid(p[..., 0]) # Center x
|
x = torch.sigmoid(p[..., 0]) # Center x
|
||||||
y = torch.sigmoid(p[..., 1]) # Center y
|
y = torch.sigmoid(p[..., 1]) # Center y
|
||||||
|
@ -161,12 +166,6 @@ class YOLOLayer(nn.Module):
|
||||||
# width = ((w.data * 2) ** 2) * self.anchor_w
|
# width = ((w.data * 2) ** 2) * self.anchor_w
|
||||||
# height = ((h.data * 2) ** 2) * self.anchor_h
|
# height = ((h.data * 2) ** 2) * self.anchor_h
|
||||||
|
|
||||||
# Training
|
|
||||||
if targets is not None:
|
|
||||||
MSELoss = nn.MSELoss()
|
|
||||||
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
|
|
||||||
CrossEntropyLoss = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
p_boxes = None
|
p_boxes = None
|
||||||
if batch_report:
|
if batch_report:
|
||||||
# Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13)
|
# Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13)
|
||||||
|
@ -239,14 +238,15 @@ class YOLOLayer(nn.Module):
|
||||||
nT, TP, FP, FPe, FN, TC
|
nT, TP, FP, FPe, FN, TC
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# If not in training phase return predictions
|
stride = self.img_dim / nG
|
||||||
p_boxes = torch.stack((x + self.grid_x, y + self.grid_y, width, height), 4) # xywh
|
p[..., 0] = torch.sigmoid(p[..., 0]) + self.grid_x # x
|
||||||
|
p[..., 1] = torch.sigmoid(p[..., 1]) + self.grid_y # y
|
||||||
|
p[..., 2] = torch.exp(p[..., 2]) * self.anchor_w # width
|
||||||
|
p[..., 3] = torch.exp(p[..., 3]) * self.anchor_h # height
|
||||||
|
p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf
|
||||||
|
p[..., :4] *= stride
|
||||||
|
|
||||||
# output.shape = [1, 3, 13, 13, 85]
|
return p.view(bs, self.nA * nG * nG, 5 + self.nC)
|
||||||
output = torch.cat((p_boxes * stride, torch.sigmoid(p_conf).unsqueeze(4), p_cls), 4)
|
|
||||||
|
|
||||||
# returns shape = [1, 507, 85]
|
|
||||||
return output.data.view(bs, -1, 5 + self.nC)
|
|
||||||
|
|
||||||
|
|
||||||
class Darknet(nn.Module):
|
class Darknet(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue