This commit is contained in:
Glenn Jocher 2019-02-20 23:52:36 +01:00
parent a92b6d4d32
commit 58d4826a11
1 changed files with 6 additions and 9 deletions

View File

@ -135,6 +135,10 @@ class YOLOLayer(nn.Module):
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh) # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
p = p.view(bs, self.nA, self.nC + 5, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction p = p.view(bs, self.nA, self.nC + 5, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
# Width and height
wh = p[..., 2:4] # yolo method
# wh = torch.sigmoid(p[..., 2:4]) # power method
# Training # Training
if targets is not None: if targets is not None:
MSELoss = nn.MSELoss() MSELoss = nn.MSELoss()
@ -146,14 +150,6 @@ class YOLOLayer(nn.Module):
p_conf = p[..., 4] # Conf p_conf = p[..., 4] # Conf
p_cls = p[..., 5:] # Class p_cls = p[..., 5:] # Class
# Width and height (yolo method)
wh = p[..., 2:4] # wh
# wh_pixels = torch.exp(wh.data) * self.anchor_wh
# Width and height (power method)
# wh = torch.sigmoid(p[..., 2:4]) # wh
# wh_pixels = ((wh.data * 2) ** 2) * self.anchor_wh
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG) txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
tcls = tcls[mask] tcls = tcls[mask]
@ -206,7 +202,8 @@ class YOLOLayer(nn.Module):
return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t() return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t()
p[..., 0:2] = torch.sigmoid(p[..., 0:2]) + self.grid_xy # xy p[..., 0:2] = torch.sigmoid(p[..., 0:2]) + self.grid_xy # xy
p[..., 2:4] = torch.exp(p[..., 2:4]) * self.anchor_wh # wh p[..., 2:4] = torch.exp(wh) * self.anchor_wh # wh yolo method
# p[..., 2:4] = ((wh * 2) ** 2) * self.anchor_wh # wh power method
p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf
p[..., :4] *= self.stride p[..., :4] *= self.stride