reapply yolo width and height

This commit is contained in:
Glenn Jocher 2018-09-23 22:41:36 +02:00
parent cf9b4cfa52
commit 5d402ad31a
2 changed files with 14 additions and 14 deletions

View File

@ -119,16 +119,16 @@ class YOLOLayer(nn.Module):
y = torch.sigmoid(p[..., 1]) # Center y y = torch.sigmoid(p[..., 1]) # Center y
# Width and height (yolo method) # Width and height (yolo method)
# w = p[..., 2] # Width w = p[..., 2] # Width
# h = p[..., 3] # Height h = p[..., 3] # Height
# width = torch.exp(w.data) * self.anchor_w width = torch.exp(w.data) * self.anchor_w
# height = torch.exp(h.data) * self.anchor_h height = torch.exp(h.data) * self.anchor_h
# Width and height (power method) # Width and height (power method)
w = torch.sigmoid(p[..., 2]) # Width # w = torch.sigmoid(p[..., 2]) # Width
h = torch.sigmoid(p[..., 3]) # Height # h = torch.sigmoid(p[..., 3]) # Height
width = ((w.data * 2) ** 2) * self.anchor_w # width = ((w.data * 2) ** 2) * self.anchor_w
height = ((h.data * 2) ** 2) * self.anchor_h # height = ((h.data * 2) ** 2) * self.anchor_h
# Add offset and scale with anchors (in grid space, i.e. 0-13) # Add offset and scale with anchors (in grid space, i.e. 0-13)
pred_boxes = FT(bs, self.nA, nG, nG, 4) pred_boxes = FT(bs, self.nA, nG, nG, 4)

View File

@ -263,13 +263,13 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
tx[b, a, gj, gi] = gx - gi.float() tx[b, a, gj, gi] = gx - gi.float()
ty[b, a, gj, gi] = gy - gj.float() ty[b, a, gj, gi] = gy - gj.float()
# Width and height (power method) # Width and height (yolo method)
tw[b, a, gj, gi] = torch.sqrt(gw / anchor_wh[a, 0]) / 2 tw[b, a, gj, gi] = torch.log(gw / anchor_wh[a, 0] + 1e-16)
th[b, a, gj, gi] = torch.sqrt(gh / anchor_wh[a, 1]) / 2 th[b, a, gj, gi] = torch.log(gh / anchor_wh[a, 1] + 1e-16)
# Width and height (yolov3 method) # Width and height (power method)
# tw[b, a, gj, gi] = torch.log(gw / anchor_wh[a, 0] + 1e-16) # tw[b, a, gj, gi] = torch.sqrt(gw / anchor_wh[a, 0]) / 2
# th[b, a, gj, gi] = torch.log(gh / anchor_wh[a, 1] + 1e-16) # th[b, a, gj, gi] = torch.sqrt(gh / anchor_wh[a, 1]) / 2
# One-hot encoding of label # One-hot encoding of label
tcls[b, a, gj, gi, tc] = 1 tcls[b, a, gj, gi, tc] = 1