xy and wh losses respectively merged
This commit is contained in:
parent
3eb49be263
commit
0772ebf7c9
|
@ -231,18 +231,16 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||||
continue
|
continue
|
||||||
t = target[b]
|
t = target[b]
|
||||||
|
|
||||||
# Convert to position relative to box
|
gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG
|
||||||
gx, gy, gw, gh = t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
|
||||||
|
|
||||||
# Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors)
|
# Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors)
|
||||||
gi = torch.clamp(gx.long(), min=0, max=nG - 1)
|
gi, gj = torch.clamp(gxy.long(), min=0, max=nG - 1).t()
|
||||||
gj = torch.clamp(gy.long(), min=0, max=nG - 1)
|
|
||||||
|
|
||||||
# iou of targets-anchors (using wh only)
|
# iou of targets-anchors (using wh only)
|
||||||
box1 = t[:, 3:5] * nG
|
box1 = gwh
|
||||||
box2 = anchor_wh.unsqueeze(1)
|
box2 = anchor_wh.unsqueeze(1)
|
||||||
inter_area = torch.min(box1, box2).prod(2)
|
inter_area = torch.min(box1, box2).prod(2)
|
||||||
iou = inter_area / (gw * gh + box2.prod(2) - inter_area + 1e-16)
|
iou = inter_area / (box1.prod(1) + box2.prod(2) - inter_area + 1e-16)
|
||||||
|
|
||||||
# Select best iou_pred and anchor
|
# Select best iou_pred and anchor
|
||||||
iou_best, a = iou.max(0) # best anchor [0-2] for each target
|
iou_best, a = iou.max(0) # best anchor [0-2] for each target
|
||||||
|
@ -269,17 +267,14 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||||
if iou_best < 0.10:
|
if iou_best < 0.10:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tc, gx, gy, gwh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3:5] * nG
|
tc, gxy, gwh = t[:, 0].long(), t[:, 1:3] * nG, t[:, 3:5] * nG
|
||||||
|
|
||||||
# Coordinates
|
# XY coordinates
|
||||||
txy[b, a, gj, gi, 0] = gx - gi.float()
|
txy[b, a, gj, gi] = gxy - gxy.floor()
|
||||||
txy[b, a, gj, gi, 1] = gy - gj.float()
|
|
||||||
|
|
||||||
# Width and height (yolo method)
|
# Width and height
|
||||||
twh[b, a, gj, gi] = torch.log(gwh / anchor_wh[a])
|
twh[b, a, gj, gi] = torch.log(gwh / anchor_wh[a]) # yolo method
|
||||||
|
# twh[b, a, gj, gi] = torch.sqrt(gwh / anchor_wh[a]) / 2 # power method
|
||||||
# Width and height (power method)
|
|
||||||
# twh[b, a, gj, gi] = torch.sqrt(gwh / anchor_wh[a]) / 2
|
|
||||||
|
|
||||||
# One-hot encoding of label
|
# One-hot encoding of label
|
||||||
tcls[b, a, gj, gi, tc] = 1
|
tcls[b, a, gj, gi, tc] = 1
|
||||||
|
|
Loading…
Reference in New Issue