diff --git a/models.py b/models.py
index 23f48d00..6e8b6253 100755
--- a/models.py
+++ b/models.py
@@ -144,52 +144,44 @@ class YOLOLayer(nn.Module):
             CrossEntropyLoss = nn.CrossEntropyLoss()
 
             # Get outputs
-            x = torch.sigmoid(p[..., 0])  # Center x
-            y = torch.sigmoid(p[..., 1])  # Center y
+            xy = torch.sigmoid(p[..., 0:2])
             p_conf = p[..., 4]  # Conf
             p_cls = p[..., 5:]  # Class
 
             # Width and height (yolo method)
-            w = p[..., 2]  # Width
-            h = p[..., 3]  # Height
-            # width = torch.exp(w.data) * self.anchor_w
-            # height = torch.exp(h.data) * self.anchor_h
+            wh = p[..., 2:4]  # wh
+            # wh_pixels = torch.exp(wh.data) * self.anchor_wh
 
             # Width and height (power method)
-            # w = torch.sigmoid(p[..., 2])  # Width
-            # h = torch.sigmoid(p[..., 3])  # Height
-            # width = ((w.data * 2) ** 2) * self.anchor_w
-            # height = ((h.data * 2) ** 2) * self.anchor_h
+            # wh = torch.sigmoid(p[..., 2:4])  # wh
+            # wh_pixels = ((wh.data * 2) ** 2) * self.anchor_wh
 
-            tx, ty, tw, th, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
+            txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
 
             tcls = tcls[mask]
-            if x.is_cuda:
-                tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
+            if xy.is_cuda:
+                txy, twh, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda()
 
             # Compute losses
             nT = sum([len(x) for x in targets])  # number of targets
             nM = mask.sum().float()  # number of anchors (assigned to targets)
-            nB = len(targets)  # batch size
-            k = nM / nB
+            k = nM / bs
             if nM > 0:
-                lx = k * MSELoss(x[mask], tx[mask])
-                ly = k * MSELoss(y[mask], ty[mask])
-                lw = k * MSELoss(w[mask], tw[mask])
-                lh = k * MSELoss(h[mask], th[mask])
+                lxy = k * MSELoss(xy[mask], txy[mask])
+                lwh = k * MSELoss(wh[mask], twh[mask])
 
                 lcls = (k / 4) * CrossEntropyLoss(p_cls[mask], torch.argmax(tcls, 1))
                 # lcls = (k * 10) * BCEWithLogitsLoss(p_cls[mask], tcls.float())
             else:
                 FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
-                lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
+                lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0])
 
             lconf = (k * 64) * BCEWithLogitsLoss(p_conf, mask.float())
 
             # Sum loss components
-            loss = lx + ly + lw + lh + lconf + lcls
+            loss = lxy + lwh + lconf + lcls
 
-            return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), nT
+            return loss, loss.item(), lxy.item(), lwh.item(), lconf.item(), lcls.item(), nT
 
         else:
             if ONNX_EXPORT:
@@ -235,7 +227,7 @@ class Darknet(nn.Module):
         self.module_defs[0]['height'] = img_size
         self.hyperparams, self.module_list = create_modules(self.module_defs)
         self.img_size = img_size
-        self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT']
+        self.loss_names = ['loss', 'xy', 'wh', 'conf', 'cls', 'nT']
         self.losses = []
 
     def forward(self, x, targets=None, var=0):
diff --git a/train.py b/train.py
index 5f4018c4..d4919cbc 100644
--- a/train.py
+++ b/train.py
@@ -87,8 +87,8 @@ def train(
     for epoch in range(epochs):
         epoch += start_epoch
 
-        print(('%8s%12s' + '%10s' * 9) % (
-            'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'nTargets', 'time'))
+        print(('%8s%12s' + '%10s' * 7) % (
+            'Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
 
         # Update scheduler (automatic)
         # scheduler.step()
@@ -139,9 +139,9 @@ def train(
             for key, val in model.losses.items():
                 rloss[key] = (rloss[key] * ui + val) / (ui + 1)
 
-            s = ('%8s%12s' + '%10.3g' * 9) % (
-                '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
-                rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
+            s = ('%8s%12s' + '%10.3g' * 7) % (
+                '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['xy'],
+                rloss['wh'], rloss['conf'], rloss['cls'],
                 rloss['loss'], model.losses['nT'], time.time() - t0)
             t0 = time.time()
             print(s)
diff --git a/utils/utils.py b/utils/utils.py
index 3086dc5e..56314dfd 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -220,10 +220,8 @@ def build_targets(target, anchor_wh, nA, nC, nG):
     """
     nB = len(target)  # number of images in batch
    nT = [len(x) for x in target]
-    tx = torch.zeros(nB, nA, nG, nG)  # batch size, anchors, grid size
-    ty = torch.zeros(nB, nA, nG, nG)
-    tw = torch.zeros(nB, nA, nG, nG)
-    th = torch.zeros(nB, nA, nG, nG)
+    txy = torch.zeros(nB, nA, nG, nG, 2)  # batch size, anchors, grid size
+    twh = torch.zeros(nB, nA, nG, nG, 2)
     tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
     tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0)  # nC = number of classes
 
@@ -274,22 +272,22 @@ def build_targets(target, anchor_wh, nA, nC, nG):
         tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
 
         # Coordinates
-        tx[b, a, gj, gi] = gx - gi.float()
-        ty[b, a, gj, gi] = gy - gj.float()
+        txy[b, a, gj, gi, 0] = gx - gi.float()
+        txy[b, a, gj, gi, 1] = gy - gj.float()
 
         # Width and height (yolo method)
-        tw[b, a, gj, gi] = torch.log(gw / anchor_wh[a, 0])
-        th[b, a, gj, gi] = torch.log(gh / anchor_wh[a, 1])
+        twh[b, a, gj, gi, 0] = torch.log(gw / anchor_wh[a, 0])
+        twh[b, a, gj, gi, 1] = torch.log(gh / anchor_wh[a, 1])
 
         # Width and height (power method)
-        # tw[b, a, gj, gi] = torch.sqrt(gw / anchor_wh[a, 0]) / 2
-        # th[b, a, gj, gi] = torch.sqrt(gh / anchor_wh[a, 1]) / 2
+        # twh[b, a, gj, gi, 0] = torch.sqrt(gw / anchor_wh[a, 0]) / 2
+        # twh[b, a, gj, gi, 1] = torch.sqrt(gh / anchor_wh[a, 1]) / 2
 
         # One-hot encoding of label
         tcls[b, a, gj, gi, tc] = 1
         tconf[b, a, gj, gi] = 1
 
-    return tx, ty, tw, th, tconf, tcls
+    return txy, twh, tconf, tcls
 
 
 def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
@@ -447,13 +445,13 @@ def plot_results():
     # import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt')
 
     plt.figure(figsize=(16, 8))
-    s = ['X', 'Y', 'Width', 'Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision']
+    s = ['XY', 'Width-Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision']
     files = sorted(glob.glob('results*.txt'))
     for f in files:
-        results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 11, 12, 13]).T  # column 13 is mAP
+        results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 11, 12, 13]).T  # column 11 is mAP
         n = results.shape[1]
-        for i in range(10):
-            plt.subplot(2, 5, i + 1)
+        for i in range(8):
+            plt.subplot(2, 4, i + 1)
             plt.plot(range(1, n), results[i, 1:], marker='.', label=f)
             plt.title(s[i])
             if i == 0:
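Below is a minimal, standalone sketch (not part of the patch) of the merged target layout this change introduces: txy and twh keep x/y and w/h in a trailing dimension of size 2, so a single masked MSELoss call replaces the four per-coordinate calls. Names and shapes mirror the diff; the dummy sizes, random predictions, and boolean mask are illustrative only (the patch itself builds the mask as the ByteTensor tconf returned by build_targets).

    import torch
    import torch.nn as nn

    nB, nA, nG = 2, 3, 13                                 # example batch, anchor and grid sizes
    txy = torch.zeros(nB, nA, nG, nG, 2)                  # x/y offsets share the last dimension
    twh = torch.zeros(nB, nA, nG, nG, 2)                  # w/h share the last dimension
    mask = torch.zeros(nB, nA, nG, nG, dtype=torch.bool)
    mask[0, 0, 5, 5] = True                               # pretend one anchor was assigned a target

    xy = torch.sigmoid(torch.randn(nB, nA, nG, nG, 2))    # predicted x/y offsets in [0, 1]
    wh = torch.randn(nB, nA, nG, nG, 2)                   # predicted log-space w/h

    MSELoss = nn.MSELoss()
    lxy = MSELoss(xy[mask], txy[mask])                    # mask selects (n_assigned, 2) rows
    lwh = MSELoss(wh[mask], twh[mask])
    print(lxy.item(), lwh.item())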