ONNX compatibility updates

This commit is contained in:
Glenn Jocher 2018-12-25 13:24:21 +01:00
parent febc55d96a
commit 647e1c6f52
1 changed files with 27 additions and 27 deletions

View File

@ -127,13 +127,12 @@ class YOLOLayer(nn.Module):
self.loss_means = torch.ones(6) self.loss_means = torch.ones(6)
self.tx, self.ty, self.tw, self.th = [], [], [], [] self.tx, self.ty, self.tw, self.th = [], [], [], []
self.yolo_layer = anchor_idxs[0] / nA # 2, 1, 0
def forward(self, p, targets=None, batch_report=False, var=None): def forward(self, p, targets=None, batch_report=False, var=None):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0] # batch size bs = p.shape[0] # batch size
nG = p.shape[2] # number of grid points nG = p.shape[2] # number of grid points
stride = self.img_dim / nG
if p.is_cuda and not self.grid_x.is_cuda: if p.is_cuda and not self.grid_x.is_cuda:
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
@ -143,30 +142,30 @@ class YOLOLayer(nn.Module):
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh) # p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
# Get outputs
x = torch.sigmoid(p[..., 0]) # Center x
y = torch.sigmoid(p[..., 1]) # Center y
p_conf = p[..., 4] # Conf
p_cls = p[..., 5:] # Class
# Width and height (yolo method)
w = p[..., 2] # Width
h = p[..., 3] # Height
width = torch.exp(w.data) * self.anchor_w
height = torch.exp(h.data) * self.anchor_h
# Width and height (power method)
# w = torch.sigmoid(p[..., 2]) # Width
# h = torch.sigmoid(p[..., 3]) # Height
# width = ((w.data * 2) ** 2) * self.anchor_w
# height = ((h.data * 2) ** 2) * self.anchor_h
# Training # Training
if targets is not None: if targets is not None:
MSELoss = nn.MSELoss() MSELoss = nn.MSELoss()
BCEWithLogitsLoss = nn.BCEWithLogitsLoss() BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
CrossEntropyLoss = nn.CrossEntropyLoss() CrossEntropyLoss = nn.CrossEntropyLoss()
# Get outputs
x = torch.sigmoid(p[..., 0]) # Center x
y = torch.sigmoid(p[..., 1]) # Center y
p_conf = p[..., 4] # Conf
p_cls = p[..., 5:] # Class
# Width and height (yolo method)
w = p[..., 2] # Width
h = p[..., 3] # Height
width = torch.exp(w.data) * self.anchor_w
height = torch.exp(h.data) * self.anchor_h
# Width and height (power method)
# w = torch.sigmoid(p[..., 2]) # Width
# h = torch.sigmoid(p[..., 3]) # Height
# width = ((w.data * 2) ** 2) * self.anchor_w
# height = ((h.data * 2) ** 2) * self.anchor_h
p_boxes = None p_boxes = None
if batch_report: if batch_report:
# Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13) # Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13)
@ -239,14 +238,15 @@ class YOLOLayer(nn.Module):
nT, TP, FP, FPe, FN, TC nT, TP, FP, FPe, FN, TC
else: else:
# If not in training phase return predictions stride = self.img_dim / nG
p_boxes = torch.stack((x + self.grid_x, y + self.grid_y, width, height), 4) # xywh p[..., 0] = torch.sigmoid(p[..., 0]) + self.grid_x # x
p[..., 1] = torch.sigmoid(p[..., 1]) + self.grid_y # y
p[..., 2] = torch.exp(p[..., 2]) * self.anchor_w # width
p[..., 3] = torch.exp(p[..., 3]) * self.anchor_h # height
p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf
p[..., :4] *= stride
# output.shape = [1, 3, 13, 13, 85] return p.view(bs, self.nA * nG * nG, 5 + self.nC)
output = torch.cat((p_boxes * stride, torch.sigmoid(p_conf).unsqueeze(4), p_cls), 4)
# returns shape = [1, 507, 85]
return output.data.view(bs, -1, 5 + self.nC)
class Darknet(nn.Module): class Darknet(nn.Module):