equal layer weights
This commit is contained in:
parent
5886200401
commit
bd9789aa00
4
train.py
4
train.py
|
@ -348,14 +348,14 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# Mutate
|
# Mutate
|
||||||
init_seeds(seed=int(time.time()))
|
init_seeds(seed=int(time.time()))
|
||||||
s = [.15, .15, .15, .15, .15, .15, .15, .15, 0, 0, 0, 0] # fractional sigmas
|
s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .15, .15, .15] # fractional sigmas
|
||||||
for i, k in enumerate(hyp.keys()):
|
for i, k in enumerate(hyp.keys()):
|
||||||
x = (np.random.randn(1) * s[i] + 1) ** 2.0 # plt.hist(x.ravel(), 300)
|
x = (np.random.randn(1) * s[i] + 1) ** 2.0 # plt.hist(x.ravel(), 300)
|
||||||
hyp[k] *= float(x) # vary by 20% 1sigma
|
hyp[k] *= float(x) # vary by 20% 1sigma
|
||||||
|
|
||||||
# Clip to limits
|
# Clip to limits
|
||||||
keys = ['lr0', 'iou_t', 'momentum', 'weight_decay']
|
keys = ['lr0', 'iou_t', 'momentum', 'weight_decay']
|
||||||
limits = [(1e-4, 1e-2), (0, 0.70), (0.70, 0.98), (0, 0.01)]
|
limits = [(1e-4, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.01)]
|
||||||
for k, v in zip(keys, limits):
|
for k, v in zip(keys, limits):
|
||||||
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
||||||
|
|
||||||
|
|
|
@ -274,28 +274,24 @@ def wh_iou(box1, box2):
|
||||||
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
|
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
|
||||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||||
lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0])
|
lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0])
|
||||||
txy, twh, tcls, tbox, indices, anchor_vec, nc = build_targets(model, targets)
|
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||||
h = model.hyp # hyperparameters
|
h = model.hyp # hyperparameters
|
||||||
|
|
||||||
# Define criteria
|
# Define criteria
|
||||||
MSE = nn.MSELoss(reduction='sum')
|
MSE = nn.MSELoss()
|
||||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction='sum')
|
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
||||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction='sum')
|
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
|
||||||
# CE = nn.CrossEntropyLoss() # (weight=model.class_weights)
|
CE = nn.CrossEntropyLoss() # (weight=model.class_weights)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
bs = p[0].shape[0] # batch size
|
bs = p[0].shape[0] # batch size
|
||||||
k = 3 * bs / 64 # loss gain
|
k = bs / 64 # loss gain
|
||||||
nt, ng = 0, 0 # number of targets, number grid points
|
|
||||||
for i, pi0 in enumerate(p): # layer i predictions, i
|
for i, pi0 in enumerate(p): # layer i predictions, i
|
||||||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
||||||
tobj = torch.zeros_like(pi0[..., 0]) # target obj
|
tobj = torch.zeros_like(pi0[..., 0]) # target obj
|
||||||
ng += tobj.numel()
|
|
||||||
nb = len(b)
|
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
if nb: # number of targets
|
if len(b): # number of targets
|
||||||
nt += nb
|
|
||||||
pi = pi0[b, a, gj, gi] # predictions closest to anchors
|
pi = pi0[b, a, gj, gi] # predictions closest to anchors
|
||||||
tobj[b, a, gj, gi] = 1.0 # obj
|
tobj[b, a, gj, gi] = 1.0 # obj
|
||||||
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
|
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
|
||||||
|
@ -303,27 +299,21 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
|
||||||
if giou_loss:
|
if giou_loss:
|
||||||
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
|
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
|
||||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||||
lxy += (1.0 - giou).sum() # giou loss
|
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
|
||||||
else:
|
else:
|
||||||
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
|
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
|
||||||
lwh += MSE(pi[..., 2:4], twh[i]) # wh yolo loss
|
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
|
||||||
|
|
||||||
tclsm = torch.zeros_like(pi[..., 5:])
|
# tclsm = torch.zeros_like(pi[..., 5:])
|
||||||
tclsm[range(nb), tcls[i]] = 1.0
|
# tclsm[range(len(b)), tcls[i]] = 1.0
|
||||||
lcls += BCEcls(pi[..., 5:], tclsm) # cls loss (BCE)
|
# lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # cls loss (BCE)
|
||||||
# lcls += CE(pi[..., 5:], tcls[i]) # cls loss (CE)
|
lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # cls loss (CE)
|
||||||
|
|
||||||
# Append targets to text file
|
# Append targets to text file
|
||||||
# with open('targets.txt', 'a') as file:
|
# with open('targets.txt', 'a') as file:
|
||||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||||
|
|
||||||
lobj += BCEobj(pi0[..., 4], tobj) # obj loss
|
lobj += (k * h['obj']) * BCEobj(pi0[..., 4], tobj) # obj loss
|
||||||
|
|
||||||
lxy *= (k * h['giou']) / nt
|
|
||||||
lwh *= (k * h['wh']) / nt
|
|
||||||
lcls *= (k * h['cls']) / (nt * nc)
|
|
||||||
lobj *= (k * h['obj']) / ng
|
|
||||||
|
|
||||||
loss = lxy + lwh + lobj + lcls
|
loss = lxy + lwh + lobj + lcls
|
||||||
|
|
||||||
return loss, torch.cat((lxy, lwh, lobj, lcls, loss)).detach()
|
return loss, torch.cat((lxy, lwh, lobj, lcls, loss)).detach()
|
||||||
|
@ -385,7 +375,7 @@ def build_targets(model, targets):
|
||||||
if c.shape[0]:
|
if c.shape[0]:
|
||||||
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
||||||
|
|
||||||
return txy, twh, tcls, tbox, indices, anchor_vec, layer.nc
|
return txy, twh, tcls, tbox, indices, anchor_vec
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
|
|
Loading…
Reference in New Issue