Merge the x and y losses into a single xy loss, and the w and h losses into a single wh loss

This commit is contained in:
Glenn Jocher 2019-02-19 19:55:33 +01:00
parent 9df279cded
commit 15bba5a345
3 changed files with 33 additions and 43 deletions

View File

@ -144,52 +144,44 @@ class YOLOLayer(nn.Module):
CrossEntropyLoss = nn.CrossEntropyLoss() CrossEntropyLoss = nn.CrossEntropyLoss()
# Get outputs # Get outputs
x = torch.sigmoid(p[..., 0]) # Center x xy = torch.sigmoid(p[..., 0:2])
y = torch.sigmoid(p[..., 1]) # Center y
p_conf = p[..., 4] # Conf p_conf = p[..., 4] # Conf
p_cls = p[..., 5:] # Class p_cls = p[..., 5:] # Class
# Width and height (yolo method) # Width and height (yolo method)
w = p[..., 2] # Width wh = p[..., 2:4] # wh
h = p[..., 3] # Height # wh_pixels = torch.exp(wh.data) * self.anchor_wh
# width = torch.exp(w.data) * self.anchor_w
# height = torch.exp(h.data) * self.anchor_h
# Width and height (power method) # Width and height (power method)
# w = torch.sigmoid(p[..., 2]) # Width # wh = torch.sigmoid(p[..., 2:4]) # wh
# h = torch.sigmoid(p[..., 3]) # Height # wh_pixels = ((wh.data * 2) ** 2) * self.anchor_wh
# width = ((w.data * 2) ** 2) * self.anchor_w
# height = ((h.data * 2) ** 2) * self.anchor_h
tx, ty, tw, th, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG) txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
tcls = tcls[mask] tcls = tcls[mask]
if x.is_cuda: if xy.is_cuda:
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda() txy, tw, th, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda()
# Compute losses # Compute losses
nT = sum([len(x) for x in targets]) # number of targets nT = sum([len(x) for x in targets]) # number of targets
nM = mask.sum().float() # number of anchors (assigned to targets) nM = mask.sum().float() # number of anchors (assigned to targets)
nB = len(targets) # batch size k = nM / bs
k = nM / nB
if nM > 0: if nM > 0:
lx = k * MSELoss(x[mask], tx[mask]) lxy = k * MSELoss(xy[mask], txy[mask])
ly = k * MSELoss(y[mask], ty[mask]) lwh = k * MSELoss(wh[mask], twh[mask])
lw = k * MSELoss(w[mask], tw[mask])
lh = k * MSELoss(h[mask], th[mask])
lcls = (k / 4) * CrossEntropyLoss(p_cls[mask], torch.argmax(tcls, 1)) lcls = (k / 4) * CrossEntropyLoss(p_cls[mask], torch.argmax(tcls, 1))
# lcls = (k * 10) * BCEWithLogitsLoss(p_cls[mask], tcls.float()) # lcls = (k * 10) * BCEWithLogitsLoss(p_cls[mask], tcls.float())
else: else:
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0])
lconf = (k * 64) * BCEWithLogitsLoss(p_conf, mask.float()) lconf = (k * 64) * BCEWithLogitsLoss(p_conf, mask.float())
# Sum loss components # Sum loss components
loss = lx + ly + lw + lh + lconf + lcls loss = lxy + lwh + lconf + lcls
return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), nT return loss, loss.item(), lxy.item(), lwh.item(), lconf.item(), lcls.item(), nT
else: else:
if ONNX_EXPORT: if ONNX_EXPORT:
@ -235,7 +227,7 @@ class Darknet(nn.Module):
self.module_defs[0]['height'] = img_size self.module_defs[0]['height'] = img_size
self.hyperparams, self.module_list = create_modules(self.module_defs) self.hyperparams, self.module_list = create_modules(self.module_defs)
self.img_size = img_size self.img_size = img_size
self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT'] self.loss_names = ['loss', 'xy', 'wh', 'conf', 'cls', 'nT']
self.losses = [] self.losses = []
def forward(self, x, targets=None, var=0): def forward(self, x, targets=None, var=0):

View File

@ -87,8 +87,8 @@ def train(
for epoch in range(epochs): for epoch in range(epochs):
epoch += start_epoch epoch += start_epoch
print(('%8s%12s' + '%10s' * 9) % ( print(('%8s%12s' + '%10s' * 7) % (
'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'nTargets', 'time')) 'Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
# Update scheduler (automatic) # Update scheduler (automatic)
# scheduler.step() # scheduler.step()
@ -139,9 +139,9 @@ def train(
for key, val in model.losses.items(): for key, val in model.losses.items():
rloss[key] = (rloss[key] * ui + val) / (ui + 1) rloss[key] = (rloss[key] * ui + val) / (ui + 1)
s = ('%8s%12s' + '%10.3g' * 9) % ( s = ('%8s%12s' + '%10.3g' * 7) % (
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'], '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['xy'],
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'], rloss['wh'], rloss['conf'], rloss['cls'],
rloss['loss'], model.losses['nT'], time.time() - t0) rloss['loss'], model.losses['nT'], time.time() - t0)
t0 = time.time() t0 = time.time()
print(s) print(s)

View File

@ -220,10 +220,8 @@ def build_targets(target, anchor_wh, nA, nC, nG):
""" """
nB = len(target) # number of images in batch nB = len(target) # number of images in batch
nT = [len(x) for x in target] nT = [len(x) for x in target]
tx = torch.zeros(nB, nA, nG, nG) # batch size, anchors, grid size txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size
ty = torch.zeros(nB, nA, nG, nG) twh = torch.zeros(nB, nA, nG, nG, 2)
tw = torch.zeros(nB, nA, nG, nG)
th = torch.zeros(nB, nA, nG, nG)
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0) tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
@ -274,22 +272,22 @@ def build_targets(target, anchor_wh, nA, nC, nG):
tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
# Coordinates # Coordinates
tx[b, a, gj, gi] = gx - gi.float() txy[b, a, gj, gi, 0] = gx - gi.float()
ty[b, a, gj, gi] = gy - gj.float() txy[b, a, gj, gi, 1] = gy - gj.float()
# Width and height (yolo method) # Width and height (yolo method)
tw[b, a, gj, gi] = torch.log(gw / anchor_wh[a, 0]) twh[b, a, gj, gi, 0] = torch.log(gw / anchor_wh[a, 0])
th[b, a, gj, gi] = torch.log(gh / anchor_wh[a, 1]) twh[b, a, gj, gi, 1] = torch.log(gh / anchor_wh[a, 1])
# Width and height (power method) # Width and height (power method)
# tw[b, a, gj, gi] = torch.sqrt(gw / anchor_wh[a, 0]) / 2 # twh[b, a, gj, gi, 0] = torch.sqrt(gw / anchor_wh[a, 0]) / 2
# th[b, a, gj, gi] = torch.sqrt(gh / anchor_wh[a, 1]) / 2 # twh[b, a, gj, gi, 1] = torch.sqrt(gh / anchor_wh[a, 1]) / 2
# One-hot encoding of label # One-hot encoding of label
tcls[b, a, gj, gi, tc] = 1 tcls[b, a, gj, gi, tc] = 1
tconf[b, a, gj, gi] = 1 tconf[b, a, gj, gi] = 1
return tx, ty, tw, th, tconf, tcls return txy, twh, tconf, tcls
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
@ -447,13 +445,13 @@ def plot_results():
# import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt') # import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt')
plt.figure(figsize=(16, 8)) plt.figure(figsize=(16, 8))
s = ['X', 'Y', 'Width', 'Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision'] s = ['XY', 'Width-Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision']
files = sorted(glob.glob('results*.txt')) files = sorted(glob.glob('results*.txt'))
for f in files: for f in files:
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 11, 12, 13]).T # column 13 is mAP results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 11, 12, 13]).T # column 11 is mAP
n = results.shape[1] n = results.shape[1]
for i in range(10): for i in range(8):
plt.subplot(2, 5, i + 1) plt.subplot(2, 4, i + 1)
plt.plot(range(1, n), results[i, 1:], marker='.', label=f) plt.plot(range(1, n), results[i, 1:], marker='.', label=f)
plt.title(s[i]) plt.title(s[i])
if i == 0: if i == 0: