From 55806949704025c479e0904404cc3d0967ae3bad Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 8 May 2019 17:29:23 +0200 Subject: [PATCH] updates --- train.py | 8 ++++---- utils/datasets.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 4adb68eb..67038e85 100644 --- a/train.py +++ b/train.py @@ -143,15 +143,15 @@ def train( model, optimizer = amp.initialize(model, optimizer, opt_level='O1') # Start training - t, t0 = time.time(), time.time() model.hyp = hyp # attach hyperparameters to model model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model_info(model) nb = len(dataloader) results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss n_burnin = min(round(nb / 5 + 1), 1000) # burn-in batches - os.remove('train_batch0.jpg') if os.path.exists('train_batch0.jpg') else None - os.remove('test_batch0.jpg') if os.path.exists('test_batch0.jpg') else None + for f in glob.glob('train_batch*.jpg') + glob.glob('test_batch*.jpg'): + os.remove(f) + t, t0 = time.time(), time.time() for epoch in range(start_epoch, epochs): model.train() print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time')) @@ -282,7 +282,7 @@ if __name__ == '__main__': parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)') parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--transfer', action='store_true', help='transfer learning flag') - parser.add_argument('--num-workers', type=int, default=2, help='number of Pytorch DataLoader workers') + parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers') parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method') parser.add_argument('--rank', default=0, type=int, help='distributed training node rank') parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training') diff --git a/utils/datasets.py b/utils/datasets.py index 147dc077..22b6c4ba 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -318,7 +318,7 @@ def letterbox(img, new_shape=416, color=(127.5, 127.5, 127.5), mode='auto'): top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) - img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_AREA) # resized, no border + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) # resized, no border img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # padded square return img, ratio, dw, dh