From 0cc8f2be01456c1e91f160e7fd8f1d830e3332ae Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 9 Oct 2018 19:22:33 +0200 Subject: [PATCH] clean up train.py --- models.py | 9 ++++----- train.py | 26 ++++++++++++++------------ utils/gcp.sh | 13 +++++++------ 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/models.py b/models.py index 2bb8719b..562a5c6a 100755 --- a/models.py +++ b/models.py @@ -139,7 +139,7 @@ class YOLOLayer(nn.Module): if targets is not None: MSELoss = nn.MSELoss(size_average=True) BCEWithLogitsLoss = nn.BCEWithLogitsLoss(size_average=True) - CrossEntropyLoss = nn.CrossEntropyLoss() + # CrossEntropyLoss = nn.CrossEntropyLoss() if requestPrecision: gx = self.grid_x[:, :, :nG, :nG] @@ -176,7 +176,7 @@ class YOLOLayer(nn.Module): lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) # Add confidence loss for background anchors (noobj) - #lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float()) + # lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float()) # Sum loss components loss = lx + ly + lw + lh + lconf + lcls @@ -244,8 +244,8 @@ class Darknet(nn.Module): if is_training: self.losses['nT'] /= 3 - self.losses['TC'] /= 3 - metrics = torch.zeros(4, len(self.losses['FPe'])) # TP, FP, FN, target_count + self.losses['TC'] /= 3 # target category + metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN ui = np.unique(self.losses['TC'])[1:] for i in ui: @@ -253,7 +253,6 @@ class Darknet(nn.Module): metrics[0, i] = (self.losses['TP'][j] > 0).sum().float() # TP metrics[1, i] = (self.losses['FP'][j] > 0).sum().float() # FP metrics[2, i] = (self.losses['FN'][j] == 3).sum().float() # FN - metrics[3] = metrics.sum(0) metrics[1] += self.losses['FPe'] self.losses['TP'] = metrics[0].sum() diff --git a/train.py b/train.py index 82a5188e..a605e186 100644 --- a/train.py +++ b/train.py @@ -87,6 +87,7 @@ def main(opt): modelinfo(model) t0, t1 = time.time(), time.time() + mean_recall, mean_precision = 0, 0 print('%10s' * 16 % ( 'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R', 'nTargets', 'TP', 'FP', 'FN', 'time')) for epoch in range(opt.epochs): @@ -112,7 +113,8 @@ def main(opt): ui = -1 rloss = defaultdict(float) # running loss - metrics = torch.zeros(4, num_classes) + metrics = torch.zeros(3, num_classes) + optimizer.zero_grad() for i, (imgs, targets) in enumerate(dataloader): if sum([len(x) for x in targets]) < 1: # if no targets continue continue @@ -125,37 +127,37 @@ def main(opt): # Compute loss, compute gradient, update parameters loss = model(imgs.to(device), targets, requestPrecision=True) - optimizer.zero_grad() loss.backward() + + # accumulated_batches = 4 # accumulate gradient for 4 batches before stepping optimizer + # if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1): optimizer.step() + optimizer.zero_grad() # Compute running epoch-means of tracked metrics ui += 1 metrics += model.losses['metrics'] + TP, FP, FN = metrics for key, val in model.losses.items(): rloss[key] = (rloss[key] * ui + val) / (ui + 1) # Precision - precision = metrics[0] / (metrics[0] + metrics[1] + 1e-16) - k = (metrics[0] + metrics[1]) > 0 + precision = TP / (TP + FP) + k = (TP + FP) > 0 if k.sum() > 0: mean_precision = precision[k].mean() - else: - mean_precision = 0 # Recall - recall = metrics[0] / (metrics[0] + metrics[2] + 1e-16) - k = (metrics[0] + metrics[2]) > 0 + recall = TP / (TP + FN) + k = (TP + FN) > 0 if k.sum() > 0: mean_recall = recall[k].mean() - else: - mean_recall = 0 s = ('%10s%10s' + '%10.3g' * 14) % ( '%g/%g' % (epoch, opt.epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'], rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'], - rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'], - model.losses['FP'], model.losses['FN'], time.time() - t1) + rloss['loss'], mean_precision, mean_recall, model.losses['nT'], TP.sum(), + FP.sum(), FN.sum(), time.time() - t1) t1 = time.time() print(s) diff --git a/utils/gcp.sh b/utils/gcp.sh index 42d5ee7c..e21e5155 100644 --- a/utils/gcp.sh +++ b/utils/gcp.sh @@ -11,21 +11,22 @@ gsutil cp gs://ultralytics/fresh9_5_e201.pt yolov3/checkpoints python3 detect.py # Test -python3 test.py -img_size 416 -weights_path checkpoints/latest.pt -conf_thresh 0.5 +python3 test.py -img_size 416 -weights_path checkpoints/latest.pt -conf_thres 0.5 # Download and Test sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3 -cd yolov3 -cd checkpoints +cd yolov3/checkpoints wget https://pjreddie.com/media/files/yolov3.weights cd .. python3 test.py -img_size 416 -weights_path checkpoints/backup5.pt -nms_thres 0.45 # Download and Resume sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3 -cd yolov3 -cd checkpoints +cd yolov3/checkpoints wget https://storage.googleapis.com/ultralytics/yolov3.pt cp yolov3.pt latest.pt cd .. -python3 train.py -img_size 416 -epochs 1 -resume 1 +python3 train.py -img_size 416 -batch_size 12 -epochs 1 -resume 1 +python3 test.py -img_size 416 -weights_path checkpoints/latest.pt -conf_thres 0.5 + +