This commit is contained in:
Glenn Jocher 2019-06-12 14:30:40 +02:00
parent bca423ee43
commit 81b4a7833f
1 changed files with 24 additions and 7 deletions

View File

@ -219,7 +219,7 @@ def compute_ap(recall, precision):
return ap
def bbox_iou(box1, box2, x1y1x2y2=True):
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False):
# Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
box2 = box2.t()
@ -243,7 +243,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True):
union_area = ((b1_x2 - b1_x1) * (b1_y2 - b1_y1) + 1e-16) + \
(b2_x2 - b2_x1) * (b2_y2 - b2_y1) - inter_area
return inter_area / union_area # iou
iou = inter_area / union_area # iou
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
c_x1, c_x2 = torch.min(b1_x1, b2_x1), torch.max(b1_x2, b2_x2)
c_y1, c_y2 = torch.min(b1_y1, b2_y1), torch.max(b1_y2, b2_y2)
c_area = (c_x2 - c_x1) * (c_y2 - c_y1) # convex area
return iou - (c_area - union_area) / c_area # GIoU
return iou
def wh_iou(box1, box2):
@ -265,8 +272,8 @@ def wh_iou(box1, box2):
def compute_loss(p, targets, model): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
txy, twh, tcls, indices = build_targets(model, targets)
lxy, lwh, lcls, lconf, lgiou = ft([0]), ft([0]), ft([0]), ft([0]), ft([0])
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
# Define criteria
MSE = nn.MSELoss()
@ -287,6 +294,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
tconf[b, a, gj, gi] = 1 # conf
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
# Build GIoU boxes
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)
# lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # class_conf loss
@ -306,7 +318,7 @@ def build_targets(model, targets):
model = model.module
nt = len(targets)
txy, twh, tcls, indices = [], [], [], []
txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], []
for i in model.yolo_layers:
layer = model.module_list[i][0]
@ -330,7 +342,12 @@ def build_targets(model, targets):
indices.append((b, a, gj, gi))
# XY coordinates
txy.append(gxy - gxy.floor())
gxy -= gxy.floor()
txy.append(gxy)
# GIoU
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
anchor_vec.append(layer.anchor_vec[a])
# Width and height
twh.append(torch.log(gwh / layer.anchor_vec[a])) # wh yolo method
@ -341,7 +358,7 @@ def build_targets(model, targets):
if c.shape[0]:
assert c.max() <= layer.nc, 'Target classes exceed model classes'
return txy, twh, tcls, indices
return txy, twh, tcls, tbox, indices, anchor_vec
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):