ONNX compatibility updates

This commit is contained in:
Glenn Jocher 2018-12-25 13:24:21 +01:00
parent febc55d96a
commit 647e1c6f52
1 changed files with 27 additions and 27 deletions

View File

@ -127,13 +127,12 @@ class YOLOLayer(nn.Module):
self.loss_means = torch.ones(6)
self.tx, self.ty, self.tw, self.th = [], [], [], []
self.yolo_layer = anchor_idxs[0] / nA # 2, 1, 0
def forward(self, p, targets=None, batch_report=False, var=None):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0] # batch size
nG = p.shape[2] # number of grid points
stride = self.img_dim / nG
if p.is_cuda and not self.grid_x.is_cuda:
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
@ -143,6 +142,12 @@ class YOLOLayer(nn.Module):
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
# Training
if targets is not None:
MSELoss = nn.MSELoss()
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
CrossEntropyLoss = nn.CrossEntropyLoss()
# Get outputs
x = torch.sigmoid(p[..., 0]) # Center x
y = torch.sigmoid(p[..., 1]) # Center y
@ -161,12 +166,6 @@ class YOLOLayer(nn.Module):
# width = ((w.data * 2) ** 2) * self.anchor_w
# height = ((h.data * 2) ** 2) * self.anchor_h
# Training
if targets is not None:
MSELoss = nn.MSELoss()
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
CrossEntropyLoss = nn.CrossEntropyLoss()
p_boxes = None
if batch_report:
# Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13)
@ -239,14 +238,15 @@ class YOLOLayer(nn.Module):
nT, TP, FP, FPe, FN, TC
else:
# If not in training phase return predictions
p_boxes = torch.stack((x + self.grid_x, y + self.grid_y, width, height), 4) # xywh
stride = self.img_dim / nG
p[..., 0] = torch.sigmoid(p[..., 0]) + self.grid_x # x
p[..., 1] = torch.sigmoid(p[..., 1]) + self.grid_y # y
p[..., 2] = torch.exp(p[..., 2]) * self.anchor_w # width
p[..., 3] = torch.exp(p[..., 3]) * self.anchor_h # height
p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf
p[..., :4] *= stride
# output.shape = [1, 3, 13, 13, 85]
output = torch.cat((p_boxes * stride, torch.sigmoid(p_conf).unsqueeze(4), p_cls), 4)
# returns shape = [1, 507, 85]
return output.data.view(bs, -1, 5 + self.nC)
return p.view(bs, self.nA * nG * nG, 5 + self.nC)
class Darknet(nn.Module):