diff --git a/detect.py b/detect.py
index 6f706b41..1dbd4076 100644
--- a/detect.py
+++ b/detect.py
@@ -36,7 +36,7 @@ def detect(
             os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights)
         model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
     else:  # darknet format
-        load_darknet_weights(model, weights)
+        _ = load_darknet_weights(model, weights)
 
     model.to(device).eval()
 
diff --git a/models.py b/models.py
index a1988d52..38ec55d8 100755
--- a/models.py
+++ b/models.py
@@ -293,11 +293,7 @@ def load_darknet_weights(self, weights, cutoff=-1):
             conv_layer.weight.data.copy_(conv_w)
             ptr += num_w
 
-
-"""
-    @:param path    - path of the new weights file
-    @:param cutoff  - save layers between 0 and cutoff (cutoff = -1 -> all are saved)
-"""
+    return cutoff
 
 
 def save_weights(self, path, cutoff=-1):
diff --git a/test.py b/test.py
index 6ac47f4d..b614eae6 100644
--- a/test.py
+++ b/test.py
@@ -35,7 +35,7 @@ def test(
     if weights.endswith('.pt'):  # pytorch format
         model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
     else:  # darknet format
-        load_darknet_weights(model, weights)
+        _ = load_darknet_weights(model, weights)
 
     model.to(device).eval()
 
diff --git a/train.py b/train.py
index c8ea97c3..208c471f 100644
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@ def train(
         resume=False,
         epochs=100,
         batch_size=16,
-        accumulated_batches=1,
+        accumulate=1,
         multi_scale=False,
         freeze_backbone=False,
 ):
@@ -35,9 +35,9 @@ def train(
     model = Darknet(cfg, img_size)
 
     # Get dataloader
-    dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, multi_scale=multi_scale, augment=True)
+    dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
 
-    lr0 = 0.001
+    lr0 = 0.001  # initial learning rate
     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
     best_loss = float('inf')
@@ -64,14 +64,12 @@ def train(
     else:
         # Initialize model with backbone (optional)
         if cfg.endswith('yolov3.cfg'):
-            load_darknet_weights(model, weights + 'darknet53.conv.74')
-            cutoff = 75
+            cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74')
         elif cfg.endswith('yolov3-tiny.cfg'):
-            load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
-            cutoff = 15
+            cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
 
     # Set optimizer
-    optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9)
 
     if torch.cuda.device_count() > 1:
         model = nn.DataParallel(model)
@@ -94,22 +92,21 @@ def train(
         # Update scheduler (automatic)
         # scheduler.step()
 
-        # Update scheduler (manual) at 0, 54, 61 epochs to 1e-3, 1e-4, 1e-5
+        # Update scheduler (manual)
         if epoch > 250:
             lr = lr0 / 10
         else:
             lr = lr0
-        for g in optimizer.param_groups:
-            g['lr'] = lr
+        for x in optimizer.param_groups:
+            x['lr'] = lr
 
-        # Freeze darknet53.conv.74 for first epoch
-        if freeze_backbone and (epoch < 2):
+        # Freeze backbone at epoch 0, unfreeze at epoch 1
+        if freeze_backbone and epoch < 2:
             for i, (name, p) in enumerate(model.named_parameters()):
                 if int(name.split('.')[1]) < cutoff:  # if layer < 75
                     p.requires_grad = False if (epoch == 0) else True
 
         ui = -1
-        optimizer.zero_grad()
         rloss = defaultdict(float)
         for i, (imgs, targets, _, _) in enumerate(dataloader):
             targets = targets.to(device)
@@ -118,10 +115,10 @@ def train(
                 continue
 
             # SGD burn-in
-            if (epoch == 0) & (i <= n_burnin):
+            if (epoch == 0) and (i <= n_burnin):
                 lr = lr0 * (i / n_burnin) ** 4
-                for g in optimizer.param_groups:
-                    g['lr'] = lr
+                for x in optimizer.param_groups:
+                    x['lr'] = lr
 
             # Run model
             pred = model(imgs.to(device))
@@ -136,7 +133,7 @@ def train(
             loss.backward()
 
             # Accumulate gradient for x batches before optimizing
-            if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
+            if (i + 1) % accumulate == 0 or (i + 1) == len(dataloader):
                 optimizer.step()
                 optimizer.zero_grad()
 
@@ -154,11 +151,17 @@ def train(
             t0 = time.time()
             print(s)
 
+            # Multi-Scale training (320 - 608 pixels) every 10 batches
+            if multi_scale and (i + 1) % 10 == 0:
+                dataloader.img_size = random.choice(range(10, 20)) * 32
+                print('multi_scale img_size = %g' % dataloader.img_size)
+
         # Update best loss
         if rloss['total'] < best_loss:
             best_loss = rloss['total']
-            save = True  # save training results
+            # Save training results
+            save = True
 
         if save:
             # Save latest checkpoint
             checkpoint = {'epoch': epoch,
@@ -172,7 +175,7 @@ def train(
                 os.system('cp ' + latest + ' ' + best)
 
             # Save backup weights every 5 epochs (optional)
-            if (epoch > 0) & (epoch % 5 == 0):
+            if (epoch > 0) and (epoch % 5 == 0):
                 os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch))
 
         # Calculate mAP
@@ -188,7 +191,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--epochs', type=int, default=270, help='number of epochs')
     parser.add_argument('--batch-size', type=int, default=16, help='size of each image batch')
-    parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
+    parser.add_argument('--accumulate', type=int, default=1, help='accumulate gradient x batches before optimizing')
     parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
     parser.add_argument('--data-cfg', type=str, default='cfg/coco.data', help='coco.data file path')
     parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608')
@@ -206,6 +209,6 @@ if __name__ == '__main__':
         resume=opt.resume,
         epochs=opt.epochs,
         batch_size=opt.batch_size,
-        accumulated_batches=opt.accumulated_batches,
+        accumulate=opt.accumulate,
         multi_scale=opt.multi_scale,
     )
diff --git a/utils/datasets.py b/utils/datasets.py
index 05d1f98d..40d935d2 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -90,7 +90,7 @@ class LoadWebcam:  # for inference
 
 
 class LoadImagesAndLabels:  # for training
-    def __init__(self, path, batch_size=1, img_size=608, multi_scale=False, augment=False):
+    def __init__(self, path, batch_size=1, img_size=608, augment=False):
         with open(path, 'r') as file:
             self.img_files = file.readlines()
             self.img_files = [x.replace('\n', '') for x in self.img_files]
@@ -102,8 +102,7 @@ class LoadImagesAndLabels:  # for training
         self.nF = len(self.img_files)  # number of image files
         self.nB = math.ceil(self.nF / batch_size)  # number of batches
         self.batch_size = batch_size
-        self.height = img_size
-        self.multi_scale = multi_scale
+        self.img_size = img_size
         self.augment = augment
 
         assert self.nF > 0, 'No images found in %s' % path
@@ -121,13 +120,6 @@ class LoadImagesAndLabels:  # for training
         ia = self.count * self.batch_size
         ib = min((self.count + 1) * self.batch_size, self.nF)
 
-        if self.multi_scale:
-            # Multi-Scale YOLO Training
-            height = random.choice(range(10, 20)) * 32  # 320 - 608 pixels
-        else:
-            # Fixed-Scale YOLO Training
-            height = self.height
-
         img_all, labels_all, img_paths, img_shapes = [], [], [], []
         for index, files_index in enumerate(range(ia, ib)):
             img_path = self.img_files[self.shuffled_vector[files_index]]
@@ -159,7 +151,7 @@ class LoadImagesAndLabels:  # for training
                 cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
 
             h, w, _ = img.shape
-            img, ratio, padw, padh = letterbox(img, height=height)
+            img, ratio, padw, padh = letterbox(img, height=self.img_size)
 
             # Load labels
             if os.path.isfile(label_path):
@@ -189,7 +181,7 @@ class LoadImagesAndLabels:  # for training
             nL = len(labels)
             if nL > 0:
                 # convert xyxy to xywh
-                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / self.img_size
+                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / self.img_size
 
             if self.augment:
                 # random left-right flip
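
Reviewer note: the accumulate change above only affects when optimizer.step() fires, not how gradients are computed. Below is a minimal, self-contained sketch of that stepping pattern on a toy model and synthetic batches (the nn.Linear model, MSE loss, and list-of-tensors "dataloader" are placeholders, not the repo's Darknet or LoadImagesAndLabels; only the stepping condition mirrors the diff):

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.MSELoss()

    accumulate = 4  # accumulate gradient for x batches before optimizing
    dataloader = [(torch.randn(16, 10), torch.randn(16, 1)) for _ in range(10)]

    optimizer.zero_grad()
    for i, (imgs, targets) in enumerate(dataloader):
        loss = criterion(model(imgs), targets)
        loss.backward()  # gradients keep summing until the next optimizer.step()

        # Step every `accumulate` batches, and always on the final batch
        if (i + 1) % accumulate == 0 or (i + 1) == len(dataloader):
            optimizer.step()
            optimizer.zero_grad()

The effective batch size is roughly batch_size * accumulate; the new final-batch test `(i + 1) == len(dataloader)` is equivalent to the old `i == len(dataloader) - 1`.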
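Similarly, a small sketch of the multi-scale schedule that this patch moves from the dataloader into the train loop: every 10 batches the target resolution is re-drawn from the multiples of 32 between 320 and 608, and the loader picks it up through its img_size attribute when it letterboxes the next batch (the DummyLoader stub below is a hypothetical stand-in for LoadImagesAndLabels):

    import random

    class DummyLoader:
        # Stand-in for LoadImagesAndLabels: only the img_size attribute matters here
        def __init__(self, img_size=608):
            self.img_size = img_size

    dataloader = DummyLoader()
    for i in range(30):  # pretend batch loop
        if (i + 1) % 10 == 0:
            dataloader.img_size = random.choice(range(10, 20)) * 32  # 320 - 608 px
            print('multi_scale img_size = %g' % dataloader.img_size)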