updates
This commit is contained in:
parent
a92b6d4d32
commit
58d4826a11
15
models.py
15
models.py
|
@ -135,6 +135,10 @@ class YOLOLayer(nn.Module):
|
|||
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
|
||||
p = p.view(bs, self.nA, self.nC + 5, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||
|
||||
# Width and height
|
||||
wh = p[..., 2:4] # yolo method
|
||||
# wh = torch.sigmoid(p[..., 2:4]) # power method
|
||||
|
||||
# Training
|
||||
if targets is not None:
|
||||
MSELoss = nn.MSELoss()
|
||||
|
@ -146,14 +150,6 @@ class YOLOLayer(nn.Module):
|
|||
p_conf = p[..., 4] # Conf
|
||||
p_cls = p[..., 5:] # Class
|
||||
|
||||
# Width and height (yolo method)
|
||||
wh = p[..., 2:4] # wh
|
||||
# wh_pixels = torch.exp(wh.data) * self.anchor_wh
|
||||
|
||||
# Width and height (power method)
|
||||
# wh = torch.sigmoid(p[..., 2:4]) # wh
|
||||
# wh_pixels = ((wh.data * 2) ** 2) * self.anchor_wh
|
||||
|
||||
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
|
||||
|
||||
tcls = tcls[mask]
|
||||
|
@ -206,7 +202,8 @@ class YOLOLayer(nn.Module):
|
|||
return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t()
|
||||
|
||||
p[..., 0:2] = torch.sigmoid(p[..., 0:2]) + self.grid_xy # xy
|
||||
p[..., 2:4] = torch.exp(p[..., 2:4]) * self.anchor_wh # wh
|
||||
p[..., 2:4] = torch.exp(wh) * self.anchor_wh # wh yolo method
|
||||
# p[..., 2:4] = ((wh * 2) ** 2) * self.anchor_wh # wh power method
|
||||
p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf
|
||||
p[..., :4] *= self.stride
|
||||
|
||||
|
|
Loading…
Reference in New Issue