Glenn Jocher 2020-01-11 13:11:30 -08:00
parent 1638ab71cd
commit 5cda317902
1 changed file with 8 additions and 22 deletions

@@ -177,27 +177,14 @@ class YOLOLayer(nn.Module):
         elif ONNX_EXPORT:
             # Constants CAN NOT BE BROADCAST, ensure correct shape!
             m = self.na * self.nx * self.ny
-            ngu = self.ng.repeat((1, m, 1))
-            grid_xy = self.grid_xy.repeat((1, self.na, 1, 1, 1)).view(1, m, 2)
-            anchor_wh = self.anchor_wh.repeat((1, 1, self.nx, self.ny, 1)).view(1, m, 2) / ngu
+            grid_xy = self.grid_xy.repeat((1, self.na, 1, 1, 1)).view(m, 2)
+            anchor_wh = self.anchor_wh.repeat((1, 1, self.nx, self.ny, 1)).view(m, 2) / self.ng

             p = p.view(m, self.no)
-            xy = torch.sigmoid(p[:, 0:2]) + grid_xy[0]  # x, y
-            wh = torch.exp(p[:, 2:4]) * anchor_wh[0]  # width, height
-            p_cls = F.softmax(p[:, 5:self.no], 1) * torch.sigmoid(p[:, 4:5])  # SSD-like conf
-            return torch.cat((xy / ngu[0], wh, p_cls), 1).t()
-
-            # p = p.view(1, m, self.no)
-            # xy = torch.sigmoid(p[..., 0:2]) + grid_xy  # x, y
-            # wh = torch.exp(p[..., 2:4]) * anchor_wh  # width, height
-            # p_conf = torch.sigmoid(p[..., 4:5])  # Conf
-            # p_cls = p[..., 5:self.no]
-            # # Broadcasting only supported on first dimension in CoreML. See onnx-coreml/_operators.py
-            # # p_cls = F.softmax(p_cls, 2) * p_conf  # SSD-like conf
-            # p_cls = torch.exp(p_cls).permute((2, 1, 0))
-            # p_cls = p_cls / p_cls.sum(0).unsqueeze(0) * p_conf.permute((2, 1, 0))  # F.softmax() equivalent
-            # p_cls = p_cls.permute(2, 1, 0)
-            # return torch.cat((xy / ngu, wh, p_conf, p_cls), 2).squeeze().t()
+            xy = torch.sigmoid(p[:, 0:2]) + grid_xy  # x, y
+            wh = torch.exp(p[:, 2:4]) * anchor_wh  # width, height
+            p_cls = torch.sigmoid(p[:, 5:self.no]) * torch.sigmoid(p[:, 4:5])  # conf
+            return p_cls, xy / self.ng, wh

         else:  # inference
             # s = 1.5  # scale_xy  (pxy = pxy * s - (s - 1) / 2)
@@ -266,9 +253,8 @@ class Darknet(nn.Module):
         if self.training:
             return output
         elif ONNX_EXPORT:
-            output = torch.cat(output, 1)  # cat 3 layers 85 x (507, 2028, 8112) to 85 x 10647
-            nc = self.module_list[self.yolo_layers[0]].nc  # number of classes
-            return output[4:4 + nc].t(), output[0:4].t()  # ONNX scores, boxes
+            x = [torch.cat(x, 0) for x in zip(*output)]
+            return x[0], torch.cat(x[1:3], 1)  # scores, boxes: 3780x80, 3780x4
         else:
             io, p = list(zip(*output))  # inference output, training output
             return torch.cat(io, 1), p
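
For context, a minimal self-contained sketch (not part of the commit) of what the new ONNX_EXPORT merge in Darknet.forward does: each YOLOLayer now returns a (p_cls, xy, wh) tuple, and the zip/cat pattern stitches the three layers into one scores tensor and one boxes tensor. The class count and per-layer row counts below are illustrative assumptions, borrowed from the old comment's 507/2028/8112 grid counts and 80 classes.

    import torch

    nc = 80  # assumed number of classes (illustrative)
    # Dummy per-layer outputs mimicking YOLOLayer's ONNX_EXPORT return: (p_cls, xy, wh)
    output = [(torch.rand(n, nc), torch.rand(n, 2), torch.rand(n, 2)) for n in (507, 2028, 8112)]

    x = [torch.cat(t, 0) for t in zip(*output)]  # regroup by field: all p_cls, all xy, all wh
    scores = x[0]                 # (10647, 80) class scores scaled by objectness
    boxes = torch.cat(x[1:3], 1)  # (10647, 4) normalized xy and wh side by side

    print(scores.shape, boxes.shape)  # torch.Size([10647, 80]) torch.Size([10647, 4])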