This commit is contained in:
Glenn Jocher 2019-08-17 19:20:39 +02:00
parent 926447e8c4
commit 2d57e5d877
2 changed files with 30 additions and 36 deletions

View File

@ -155,7 +155,7 @@ class YOLOLayer(nn.Module):
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method
io[..., :4] *= self.stride
arc = 'normal' # (normal, uCE uBCE) architecture types
arc = 'normal' # (normal, uCE, uBCE) architecture types
if arc == 'normal':
io[..., 4:] = torch.sigmoid(io[..., 4:])
elif arc == 'uCE':

View File

@ -290,19 +290,19 @@ def wh_iou(box1, box2):
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0])
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
h = model.hyp # hyperparameters
# Define criteria
MSE = nn.MSELoss()
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
# CE = nn.CrossEntropyLoss() # (weight=model.class_weights)
CE = nn.CrossEntropyLoss(weight=model.class_weights)
# Compute losses
bs = p[0].shape[0] # batch size
k = bs / 64 # loss gain
arc = 'normal' # (normal, uCE, uBCE) architecture types
for i, pi0 in enumerate(p): # layer i predictions, i
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi0[..., 0]) # target obj
@ -314,45 +314,46 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
tobj[b, a, gj, gi] = 1.0 # obj
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
# s = 1.5 # scale_xy
pxy = torch.sigmoid(pi[..., 0:2]) # * s - (s - 1) / 2
if giou_loss:
# GIoU
pxy = torch.sigmoid(pi[..., 0:2]) # pxy = pxy * s - (s - 1) / 2, s = 1.5 (scale_xy)
pbox = torch.cat((pxy, torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
else:
lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
lbox += (k * h['giou']) * (1.0 - giou).mean() # giou loss
if model.nc > 1: # cls loss (only if multiple classes)
if arc == 'normal' and model.nc > 1: # cls loss (only if multiple classes)
tclsm = torch.zeros_like(pi[..., 5:])
tclsm[range(nb), tcls[i]] = 1.0
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE
# udm_ce = torch.zeros_like(pi0[..., 0]).long() # unified detection matrix for CE
# udm_ce[b, a, gj, gi] = tcls[i] + 1
# lcls += (k * h['cls']) * CE(pi0[..., 4:].view(-1, model.nc + 1), udm_ce.view(-1)) # unified CE
# udm = torch.zeros_like(pi0[..., 5:]) # unified detection matrix for BCE
# udm[b, a, gj, gi, tcls[i]] = 1.0
# lcls += (k * h['cls']) * BCEcls(pi0[..., 5:], udm) # unified BCE (hyps 200-30)
# Append targets to text file
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
if arc == 'normal':
lobj += (k * h['obj']) * BCEobj(pi0[..., 4], tobj) # obj loss
loss = lxy + lwh + lobj + lcls
return loss, torch.cat((lxy, lwh, lobj, lcls, loss)).detach()
elif arc == 'uCE': # suggest h['cls']=5.
udm_ce = torch.zeros_like(pi0[..., 0]).long() # unified detection matrix for CE
if nb:
udm_ce[b, a, gj, gi] = tcls[i] + 1
lcls += (k * h['cls']) * CE(pi0[..., 4:].view(-1, model.nc + 1), udm_ce.view(-1)) # unified CE
elif arc == 'uBCE':
udm = torch.zeros_like(pi0[..., 5:]) # unified detection matrix for BCE
if nb:
udm[b, a, gj, gi, tcls[i]] = 1.0
lcls += (k * h['cls']) * BCEcls(pi0[..., 5:], udm) # unified BCE (hyps 200-30)
loss = lbox + lobj + lcls
return loss, torch.cat((lbox, ft([0]), lobj, lcls, loss)).detach()
def build_targets(model, targets):
# targets = [image, class, x, y, w, h]
nt = len(targets)
txy, twh, tcls, tbox, indices, av = [], [], [], [], [], []
tcls, tbox, indices, av = [], [], [], []
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
for i in model.yolo_layers:
# get number of grid points and anchor vec for this yolo layer
@ -389,24 +390,17 @@ def build_targets(model, targets):
gi, gj = gxy.long().t() # grid x, y indices
indices.append((b, a, gj, gi))
# XY coordinates
gxy -= gxy.floor()
txy.append(gxy)
# GIoU
gxy -= gxy.floor() # xy
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
av.append(anchor_vec[a]) # anchor vec
# Width and height
twh.append(torch.log(gwh / anchor_vec[a])) # wh yolo method
# twh.append((gwh / anchor_vec[a]) ** (1 / 3) / 2) # wh power method
# Class
tcls.append(c)
if c.shape[0]: # if any targets
assert c.max() <= model.nc, 'Target classes exceed model classes'
return txy, twh, tcls, tbox, indices, av
return tcls, tbox, indices, av
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):