updates
This commit is contained in:
parent
bca423ee43
commit
81b4a7833f
|
@ -219,7 +219,7 @@ def compute_ap(recall, precision):
|
|||
return ap
|
||||
|
||||
|
||||
def bbox_iou(box1, box2, x1y1x2y2=True):
|
||||
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False):
|
||||
# Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
|
||||
box2 = box2.t()
|
||||
|
||||
|
@ -243,7 +243,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True):
|
|||
union_area = ((b1_x2 - b1_x1) * (b1_y2 - b1_y1) + 1e-16) + \
|
||||
(b2_x2 - b2_x1) * (b2_y2 - b2_y1) - inter_area
|
||||
|
||||
return inter_area / union_area # iou
|
||||
iou = inter_area / union_area # iou
|
||||
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
|
||||
c_x1, c_x2 = torch.min(b1_x1, b2_x1), torch.max(b1_x2, b2_x2)
|
||||
c_y1, c_y2 = torch.min(b1_y1, b2_y1), torch.max(b1_y2, b2_y2)
|
||||
c_area = (c_x2 - c_x1) * (c_y2 - c_y1) # convex area
|
||||
return iou - (c_area - union_area) / c_area # GIoU
|
||||
|
||||
return iou
|
||||
|
||||
|
||||
def wh_iou(box1, box2):
|
||||
|
@ -265,8 +272,8 @@ def wh_iou(box1, box2):
|
|||
|
||||
def compute_loss(p, targets, model): # predictions, targets, model
|
||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||
lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
|
||||
txy, twh, tcls, indices = build_targets(model, targets)
|
||||
lxy, lwh, lcls, lconf, lgiou = ft([0]), ft([0]), ft([0]), ft([0]), ft([0])
|
||||
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||
|
||||
# Define criteria
|
||||
MSE = nn.MSELoss()
|
||||
|
@ -287,6 +294,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
tconf[b, a, gj, gi] = 1 # conf
|
||||
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
|
||||
|
||||
# Build GIoU boxes
|
||||
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)
|
||||
|
||||
# lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
|
||||
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
|
||||
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
|
||||
lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # class_conf loss
|
||||
|
@ -306,7 +318,7 @@ def build_targets(model, targets):
|
|||
model = model.module
|
||||
|
||||
nt = len(targets)
|
||||
txy, twh, tcls, indices = [], [], [], []
|
||||
txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], []
|
||||
for i in model.yolo_layers:
|
||||
layer = model.module_list[i][0]
|
||||
|
||||
|
@ -330,7 +342,12 @@ def build_targets(model, targets):
|
|||
indices.append((b, a, gj, gi))
|
||||
|
||||
# XY coordinates
|
||||
txy.append(gxy - gxy.floor())
|
||||
gxy -= gxy.floor()
|
||||
txy.append(gxy)
|
||||
|
||||
# GIoU
|
||||
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
|
||||
anchor_vec.append(layer.anchor_vec[a])
|
||||
|
||||
# Width and height
|
||||
twh.append(torch.log(gwh / layer.anchor_vec[a])) # wh yolo method
|
||||
|
@ -341,7 +358,7 @@ def build_targets(model, targets):
|
|||
if c.shape[0]:
|
||||
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
||||
|
||||
return txy, twh, tcls, indices
|
||||
return txy, twh, tcls, tbox, indices, anchor_vec
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||
|
|
Loading…
Reference in New Issue