equal layer weights

This commit is contained in:
Glenn Jocher 2019-07-12 12:23:17 +02:00
parent 5886200401
commit bd9789aa00
2 changed files with 17 additions and 27 deletions

View File

@ -348,14 +348,14 @@ if __name__ == '__main__':
# Mutate
init_seeds(seed=int(time.time()))
s = [.15, .15, .15, .15, .15, .15, .15, .15, 0, 0, 0, 0] # fractional sigmas
s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .15, .15, .15] # fractional sigmas
for i, k in enumerate(hyp.keys()):
x = (np.random.randn(1) * s[i] + 1) ** 2.0 # plt.hist(x.ravel(), 300)
hyp[k] *= float(x) # vary by 20% 1sigma
# Clip to limits
keys = ['lr0', 'iou_t', 'momentum', 'weight_decay']
limits = [(1e-4, 1e-2), (0, 0.70), (0.70, 0.98), (0, 0.01)]
limits = [(1e-4, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.01)]
for k, v in zip(keys, limits):
hyp[k] = np.clip(hyp[k], v[0], v[1])

View File

@ -274,28 +274,24 @@ def wh_iou(box1, box2):
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0])
txy, twh, tcls, tbox, indices, anchor_vec, nc = build_targets(model, targets)
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
h = model.hyp # hyperparameters
# Define criteria
MSE = nn.MSELoss(reduction='sum')
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction='sum')
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction='sum')
# CE = nn.CrossEntropyLoss() # (weight=model.class_weights)
MSE = nn.MSELoss()
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
CE = nn.CrossEntropyLoss() # (weight=model.class_weights)
# Compute losses
bs = p[0].shape[0] # batch size
k = 3 * bs / 64 # loss gain
nt, ng = 0, 0 # number of targets, number grid points
k = bs / 64 # loss gain
for i, pi0 in enumerate(p): # layer i predictions, i
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi0[..., 0]) # target obj
ng += tobj.numel()
nb = len(b)
# Compute losses
if nb: # number of targets
nt += nb
if len(b): # number of targets
pi = pi0[b, a, gj, gi] # predictions closest to anchors
tobj[b, a, gj, gi] = 1.0 # obj
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
@ -303,27 +299,21 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
if giou_loss:
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
lxy += (1.0 - giou).sum() # giou loss
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
else:
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += MSE(pi[..., 2:4], twh[i]) # wh yolo loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
tclsm = torch.zeros_like(pi[..., 5:])
tclsm[range(nb), tcls[i]] = 1.0
lcls += BCEcls(pi[..., 5:], tclsm) # cls loss (BCE)
# lcls += CE(pi[..., 5:], tcls[i]) # cls loss (CE)
# tclsm = torch.zeros_like(pi[..., 5:])
# tclsm[range(len(b)), tcls[i]] = 1.0
# lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # cls loss (BCE)
lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # cls loss (CE)
# Append targets to text file
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
lobj += BCEobj(pi0[..., 4], tobj) # obj loss
lxy *= (k * h['giou']) / nt
lwh *= (k * h['wh']) / nt
lcls *= (k * h['cls']) / (nt * nc)
lobj *= (k * h['obj']) / ng
lobj += (k * h['obj']) * BCEobj(pi0[..., 4], tobj) # obj loss
loss = lxy + lwh + lobj + lcls
return loss, torch.cat((lxy, lwh, lobj, lcls, loss)).detach()
@ -385,7 +375,7 @@ def build_targets(model, targets):
if c.shape[0]:
assert c.max() <= layer.nc, 'Target classes exceed model classes'
return txy, twh, tcls, tbox, indices, anchor_vec, layer.nc
return txy, twh, tcls, tbox, indices, anchor_vec
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):