From cd51e1137b9dc0a5abca260bd41a4a28359db60f Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 25 Mar 2019 14:59:38 +0100
Subject: [PATCH] Add collate_fn() to DataLoader (#163)

Multi-GPU update with a custom collate function that allows a variable-size target vector per image, without needing to pad targets.
---
 README.md                          |  24 +++--
 detect.py                          |  22 ++--
 requirements.txt                   |   2 +
 test.py                            |  51 ++++++-----
 train.py                           | 118 ++++++++++++------
 utils/datasets.py                  | 139 ++++++++++++++++-------
 utils/gcp.sh                       |  13 ++-
 utils/utils.py                     |  41 +++------
 weights/download_yolov3_weights.sh |   1 +
 9 files changed, 217 insertions(+), 194 deletions(-)

diff --git a/README.md b/README.md
index 5d1261aa..ea6c356e 100755
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ This directory contains python software and an iOS App developed by Ultralytics
 
 # Description
 
-The https://github.com/ultralytics/yolov3 repo contains inference and training code for YOLOv3 in PyTorch. The code works on Linux, MacOS and Windows. Training is done on the COCO dataset by default: https://cocodataset.org/#home. **Credit to Joseph Redmon for YOLO** (https://pjreddie.com/darknet/yolo/) and to **Erik Lindernoren for the PyTorch implementation** this work is based on (https://github.com/eriklindernoren/PyTorch-YOLOv3).
+The https://github.com/ultralytics/yolov3 repo contains inference and training code for YOLOv3 in PyTorch. The code works on Linux, macOS and Windows. Training is done on the COCO dataset by default: https://cocodataset.org/#home. **Credit to Joseph Redmon for YOLO:** https://pjreddie.com/darknet/yolo/.
 
 # Requirements
 
@@ -26,6 +26,7 @@ Python 3.7 or later with the following `pip3 install -U -r requirements.txt` pac
 - `numpy`
 - `torch >= 1.0.0`
 - `opencv-python`
+- `tqdm`
 
 # Tutorials
 
@@ -64,17 +65,22 @@ HS**V** Intensity | +/- 50%
 
 ## Speed
 
 https://cloud.google.com/deep-learning-vm/
-**Machine type:** n1-standard-8 (8 vCPUs, 30 GB memory)
+**Machine type:** n1-standard-8 (8 vCPUs, 30 GB memory)
 **CPU platform:** Intel Skylake
-**GPUs:** 1-4 x NVIDIA Tesla P100
+**GPUs:** 1-4x P100 ($0.493/hr), 1-8x V100 ($0.803/hr)
 **HDD:** 100 GB SSD
+**Dataset:** COCO train 2014
 
-GPUs | `batch_size` | speed | COCO epoch
---- |---| --- | ---
-(P100) | (images) | (s/batch) | (min/epoch)
-1 | 16 | 0.39s | 48min
-2 | 32 | 0.48s | 29min
-4 | 64 | 0.65s | 20min
+GPUs | `batch_size` | batch time | epoch time | epoch cost
+--- |---| --- | --- | ---
+ | (images) | (s/batch) | (min/epoch) | (USD/epoch)
+1 P100 | 16 | 0.39s | 48min | $0.39
+2 P100 | 32 | 0.48s | 29min | $0.47
+4 P100 | 64 | 0.65s | 20min | $0.65
+1 V100 | 16 | 0.25s | 31min | $0.41
+2 V100 | 32 | 0.29s | 18min | $0.48
+4 V100 | 64 | 0.41s | 13min | $0.70
+8 V100 | 128 | 0.49s | 7min | $0.80
 
 # Inference
 
diff --git a/detect.py b/detect.py
index 1dbd4076..32c41c55 100644
--- a/detect.py
+++ b/detect.py
@@ -1,7 +1,5 @@
 import argparse
-import shutil
 import time
-from pathlib import Path
 from sys import platform
 
 from models import *
@@ -32,9 +30,9 @@ def detect(
     # Load weights
     if weights.endswith('.pt'):  # pytorch format
         if weights.endswith('yolov3.pt') and not os.path.exists(weights):
-            if (platform == 'darwin') or (platform == 'linux'):
+            if platform in ('darwin', 'linux'):  # linux/macos
                 os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights)
-        model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
+        model.load_state_dict(torch.load(weights, map_location=device)['model'])
     else:  # darknet format
         _ = load_darknet_weights(model, weights)
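
The `map_location=device` change above loads checkpoint tensors directly onto the active device, so GPU-saved checkpoints also load cleanly on CPU-only machines. A minimal sketch of the pattern, assuming the repo's `Darknet` class and the default weights path:

```python
import torch
from models import Darknet  # repo model class

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Darknet('cfg/yolov3.cfg', 416).to(device)
# map_location remaps saved tensors onto `device` during deserialization
checkpoint = torch.load('weights/yolov3.pt', map_location=device)
model.load_state_dict(checkpoint['model'])
```
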
@@ -49,15 +47,15 @@ def detect(
 
     # Get classes and colors
     classes = load_classes(parse_data_cfg('cfg/coco.data')['names'])
-    colors = [[random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)] for _ in range(len(classes))]
+    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
 
     for i, (path, img, im0) in enumerate(dataloader):
         t = time.time()
+        save_path = str(Path(output) / Path(path).name)
         if webcam:
             print('webcam frame %g: ' % (i + 1), end='')
         else:
             print('image %g/%g %s: ' % (i + 1, len(dataloader), path), end='')
-            save_path = str(Path(output) / Path(path).name)
 
         # Get detections
         img = torch.from_numpy(img).unsqueeze(0).to(device)
@@ -81,18 +79,16 @@ def detect(
                 print('%g %ss' % (n, classes[int(c)]), end=', ')
 
             # Draw bounding boxes and labels of detections
-            for x1, y1, x2, y2, conf, cls_conf, cls in detections:
+            for *xyxy, conf, cls_conf, cls in detections:
                 if save_txt:  # Write to file
                     with open(save_path + '.txt', 'a') as file:
-                        file.write('%g %g %g %g %g %g\n' %
-                                   (x1, y1, x2, y2, cls, cls_conf * conf))
+                        file.write(('%g ' * 6 + '\n') % (*xyxy, cls, cls_conf * conf))
 
                 # Add bbox to the image
                 label = '%s %.2f' % (classes[int(cls)], conf)
-                plot_one_box([x1, y1, x2, y2], im0, label=label, color=colors[int(cls)])
+                plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
 
-        dt = time.time() - t
-        print('Done. (%.3fs)' % dt)
+        print('Done. (%.3fs)' % (time.time() - t))
 
         if save_images:  # Save generated image with detections
             cv2.imwrite(save_path, im0)
 
@@ -100,7 +96,7 @@ def detect(
 
     if webcam:  # Show live webcam
         cv2.imshow(weights, im0)
 
-    if save_images and (platform == 'darwin'):  # linux/macos
+    if save_images and platform == 'darwin':  # macos
         os.system('open ' + output + ' ' + save_path)
diff --git a/requirements.txt b/requirements.txt
index d934964e..f7d21d87 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,5 @@ numpy
 opencv-python
 torch >= 1.0.0
 matplotlib
+pycocotools
+tqdm
diff --git a/test.py b/test.py
index 5d0ca98a..c1cc3d2a 100644
--- a/test.py
+++ b/test.py
@@ -1,7 +1,6 @@
 import argparse
 import json
 import time
-from pathlib import Path
 
 from torch.utils.data import DataLoader
 
@@ -24,41 +23,44 @@ def test(
 ):
     device = torch_utils.select_device()
 
-    # Configure run
-    data_cfg_dict = parse_data_cfg(data_cfg)
-    nC = int(data_cfg_dict['classes'])  # number of classes (80 for COCO)
-    test_path = data_cfg_dict['valid']
-
     if model is None:
         # Initialize model
-        model = Darknet(cfg, img_size)
+        model = Darknet(cfg, img_size).to(device)
 
         # Load weights
         if weights.endswith('.pt'):  # pytorch format
-            model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
+            model.load_state_dict(torch.load(weights, map_location=device)['model'])
         else:  # darknet format
             _ = load_darknet_weights(model, weights)
 
-        model.to(device).eval()
+        if torch.cuda.device_count() > 1:
+            model = nn.DataParallel(model)
+
+    # Configure run
+    data_cfg = parse_data_cfg(data_cfg)
+    nC = int(data_cfg['classes'])  # number of classes (80 for COCO)
+    test_path = data_cfg['valid']
 
     # Dataloader
     dataset = LoadImagesAndLabels(test_path, img_size=img_size)
-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
+    dataloader = DataLoader(dataset,
+                            batch_size=batch_size,
+                            num_workers=4,
+                            pin_memory=False,
+                            collate_fn=dataset.collate_fn)
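
`collate_fn=dataset.collate_fn` is the core of this PR: every image yields a label tensor with a different number of rows, so the default collate (which can only stack equal-shaped tensors) no longer applies and the old fixed-size padding becomes unnecessary. A standalone sketch of the idea, assuming each sample is an `(img, labels)` pair with `labels` of shape `[nL, 6]`:

```python
import torch

def collate_fn(batch):
    # batch: list of (img [3,H,W], labels [nL,6]) tuples; nL varies per image
    imgs, labels = zip(*batch)
    for i, l in enumerate(labels):
        l[:, 0] = i  # column 0 tags each row with its image index in the batch
    # images stack to [B,3,H,W]; labels concatenate to [sum(nL),6]
    return torch.stack(imgs, 0), torch.cat(labels, 0)
```

The repo version (in the `utils/datasets.py` hunk below) does the same while also passing through image paths and original shapes.
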
 
+    model.eval()
     mean_mAP, mean_R, mean_P, seen = 0.0, 0.0, 0.0, 0
     print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP'))
 
     mP, mR, mAPs, TP, jdict = [], [], [], [], []
     AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
     coco91class = coco80_to_coco91_class()
-    for imgs, targets, paths, shapes in dataloader:
-        # Unpad and collate targets
-        for j, t in enumerate(targets):
-            t[:, 0] = j
-        targets = torch.cat([t[t[:, 5].nonzero()] for t in targets], 0).squeeze(1)
-
-        targets = targets.to(device)
+    for imgs, targets, paths, shapes in tqdm(dataloader):
         t = time.time()
-        output = model(imgs.to(device))
+        targets = targets.to(device)
+        imgs = imgs.to(device)
+
+        output = model(imgs)
         output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
 
         # Compute average precision for each sample
@@ -78,7 +80,7 @@ def test(
             if save_json:
                 # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                 box = detections[:, :4].clone()  # xyxy
-                scale_coords(img_size, box, (shapes[0][si], shapes[1][si]))  # to original shape
+                scale_coords(img_size, box, shapes[si])  # to original shape
                 box = xyxy2xywh(box)  # xywh
                 box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
@@ -134,13 +136,13 @@ def test(
         mean_R = np.mean(mR)
         mean_mAP = np.mean(mAPs)
 
-    # Print image mAP and running mean mAP
-    print(('%11s%11s' + '%11.3g' * 4 + 's') %
-          (seen, len(dataset), mean_P, mean_R, mean_mAP, time.time() - t))
+        # Print image mAP and running mean mAP
+        print(('%11s%11s' + '%11.3g' * 4 + 's') %
+              (seen, len(dataset), mean_P, mean_R, mean_mAP, time.time() - t))
 
     # Print mAP per class
     print('\nmAP Per Class:')
-    for i, c in enumerate(load_classes(data_cfg_dict['names'])):
+    for i, c in enumerate(load_classes(data_cfg['names'])):
         if AP_accum_count[i]:
             print('%15s: %-.4f' % (c, AP_accum[i] / (AP_accum_count[i])))
 
@@ -191,4 +193,5 @@ if __name__ == '__main__':
         opt.iou_thres,
         opt.conf_thres,
         opt.nms_thres,
-        opt.save_json)
+        opt.save_json
+    )
diff --git a/train.py b/train.py
index 0d66d674..3e983127 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,7 @@
 import argparse
 import time
 
+import torch.distributed as dist
 from torch.utils.data import DataLoader
 
 import test  # Import test.py to get mAP after each epoch
@@ -20,7 +21,7 @@ def train(
         accumulate=1,
         multi_scale=False,
         freeze_backbone=False,
-        num_workers=0
+        num_workers=4
 ):
     weights = 'weights' + os.sep
     latest = weights + 'latest.pt'
@@ -40,7 +41,7 @@ def train(
 
     # Optimizer
     lr0 = 0.001  # initial learning rate
-    optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9,weight_decay = 0.0005)
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9, weight_decay=0.0005)
 
     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
@@ -65,52 +66,55 @@ def train(
             dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank)
             model = torch.nn.parallel.DistributedDataParallel(model)
 
-    # Dataloader
-    dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
-    if torch.cuda.device_count() > 1:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-    else:
-        train_sampler=None
-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,sampler=train_sampler)
-
     # Transfer learning (train only YOLO layers)
     # for i, (name, p) in enumerate(model.named_parameters()):
     #     p.requires_grad = True if (p.shape[0] == 255) else False
 
-    # Set scheduler
-    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
+    # Set scheduler (reduce lr at epoch 250)
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1, last_epoch=start_epoch - 1)
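
`MultiStepLR` replaces the manual learning-rate logic that previously ran inside the epoch loop: it multiplies the LR by `gamma` once each milestone epoch is reached, and `last_epoch=start_epoch - 1` realigns the schedule when resuming from a checkpoint. A toy illustration with a dummy parameter:

```python
import torch

params = [torch.zeros(1, requires_grad=True)]  # dummy parameter
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1)
for epoch in range(300):
    scheduler.step()  # called at epoch start, per the PyTorch 1.0 idiom used here
    # optimizer.param_groups[0]['lr'] is 0.001 until epoch 250, 0.0001 after
```
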
+
+    # Dataset
+    dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
+
+    # Initialize distributed training
+    if torch.cuda.device_count() > 1:
+        dist.init_process_group(backend=opt.backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank)
+        model = torch.nn.parallel.DistributedDataParallel(model)
+        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+    else:
+        sampler = None
+
+    # Dataloader
+    dataloader = DataLoader(dataset,
+                            batch_size=batch_size,
+                            num_workers=num_workers,
+                            shuffle=True,
+                            pin_memory=False,
+                            collate_fn=dataset.collate_fn,
+                            sampler=sampler)
 
     # Start training
-    t0 = time.time()
+    nB = len(dataloader)
+    t = time.time()
     model_info(model)
-    n_burnin = min(round(len(dataloader) / 5 + 1), 1000)  # burn-in batches
-    for epoch in range(epochs):
+    n_burnin = min(round(nB / 5 + 1), 1000)  # burn-in batches
+    for epoch in range(start_epoch, epochs):
         model.train()
-        epoch += start_epoch
+        print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
 
-        print(('\n%8s%12s' + '%10s' * 7) % (
-            'Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
-
-        # Update scheduler (automatic)
-        # scheduler.step()
-
-        # Update scheduler (manual)
-        lr = lr0 / 10 if epoch > 250 else lr0
-        for x in optimizer.param_groups:
-            x['lr'] = lr
+        # Update scheduler
+        scheduler.step()
 
         # Freeze backbone at epoch 0, unfreeze at epoch 1
         if freeze_backbone and epoch < 2:
-            for i, (name, p) in enumerate(model.named_parameters()):
+            for name, p in model.named_parameters():
                 if int(name.split('.')[1]) < cutoff:  # if layer < 75
-                    p.requires_grad = False if (epoch == 0) else True
+                    p.requires_grad = False if epoch == 0 else True
 
-        rloss = defaultdict(float)
+        mloss = defaultdict(float)  # mean loss
         for i, (imgs, targets, _, _) in enumerate(dataloader):
-            # Unpad and collate targets
-            for j, t in enumerate(targets):
-                t[:, 0] = j
-            targets = torch.cat([t[t[:, 5].nonzero()] for t in targets], 0).squeeze(1)
+            imgs = imgs.to(device)
+            targets = targets.to(device)
 
             nT = len(targets)
             if nT == 0:  # if no targets continue
                 continue
 
             # Plot images with bounding boxes
             plot_images = False
             if plot_images:
-                import matplotlib.pyplot as plt
-                plt.figure(figsize=(10, 10))
+                fig = plt.figure(figsize=(10, 10))
                 for ip in range(batch_size):
                     labels = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).cpu().numpy() * img_size
                     plt.subplot(4, 4, ip + 1).imshow(imgs[ip].cpu().numpy().transpose(1, 2, 0))
                     plt.plot(labels[:, [0, 2, 2, 0, 0]].T, labels[:, [1, 1, 3, 3, 1]].T, '.-')
                     plt.axis('off')
+                fig.tight_layout()
+                fig.savefig('batch_%g.jpg' % i, dpi=fig.dpi)
 
             # SGD burn-in
-            if (epoch == 0) and (i <= n_burnin):
+            if epoch == 0 and i <= n_burnin:
                 lr = lr0 * (i / n_burnin) ** 4
                 for x in optimizer.param_groups:
                     x['lr'] = lr
 
             # Run model
-            pred = model(imgs.to(device))
+            pred = model(imgs)
 
             # Build targets
-            target_list = build_targets(model, targets.to(device), pred)
+            target_list = build_targets(model, targets)
 
             # Compute loss
             loss, loss_dict = compute_loss(pred, target_list)
 
             # Compute gradient
             loss.backward()
 
             # Accumulate gradient for x batches before optimizing
-            if (i + 1) % accumulate == 0 or (i + 1) == len(dataloader):
+            if (i + 1) % accumulate == 0 or (i + 1) == nB:
                 optimizer.step()
                 optimizer.zero_grad()
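
The `accumulate` logic above sums gradients from several consecutive batches in `.grad` before taking a single optimizer step, emulating a batch size of `batch_size * accumulate` (it is a no-op at the default `accumulate=1`). A runnable toy version of the same pattern:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulate = 4  # effective batch size = 4x the loader batch size
optimizer.zero_grad()
for i in range(16):  # 16 toy batches
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients sum into .grad across iterations
    if (i + 1) % accumulate == 0:  # one step per `accumulate` batches
        optimizer.step()
        optimizer.zero_grad()
```
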
 
             # Running epoch-means of tracked metrics
             for key, val in loss_dict.items():
-                rloss[key] = (rloss[key] * i + val) / (i + 1)
+                mloss[key] = (mloss[key] * i + val) / (i + 1)
 
             s = ('%8s%12s' + '%10.3g' * 7) % (
-                '%g/%g' % (epoch, epochs - 1),
-                '%g/%g' % (i, len(dataloader) - 1),
-                rloss['xy'], rloss['wh'], rloss['conf'],
-                rloss['cls'], rloss['total'],
-                nT, time.time() - t0)
-            t0 = time.time()
+                '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, nB - 1),
+                mloss['xy'], mloss['wh'], mloss['conf'], mloss['cls'],
+                mloss['total'], nT, time.time() - t)
+            t = time.time()
             print(s)
 
             # Multi-Scale training (320 - 608 pixels) every 10 batches
             if multi_scale and (i + 1) % 10 == 0:
                 dataset.img_size = random.choice(range(10, 20)) * 32
                 print('multi_scale img_size = %g' % dataset.img_size)
 
         # Update best loss
-        if rloss['total'] < best_loss:
-            best_loss = rloss['total']
+        if mloss['total'] < best_loss:
+            best_loss = mloss['total']
 
         # Save training results
         save = True
         if save:
             # Save latest checkpoint
             checkpoint = {'epoch': epoch,
                           'best_loss': best_loss,
-                          'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(),
+                          'model': model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel)
+                          else model.state_dict(),
                           'optimizer': optimizer.state_dict()}
             torch.save(checkpoint, latest)
 
             # Save best checkpoint
-            if best_loss == rloss['total']:
+            if best_loss == mloss['total']:
                 os.system('cp ' + latest + ' ' + best)
 
             # Save backup weights every 5 epochs (optional)
-            if (epoch > 0) and (epoch % 5 == 0):
-                os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch))
+            if epoch > 0 and epoch % 5 == 0:
+                os.system('cp ' + latest + ' ' + weights + 'backup%g.pt' % epoch)
 
         # Calculate mAP
         if type(model) is nn.parallel.DistributedDataParallel:
             model = model.module
         with torch.no_grad():
-            P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model)
+            P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
 
         # Write epoch results
         with open('results.txt', 'a') as file:
@@ -212,10 +216,10 @@ if __name__ == '__main__':
     parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
     parser.add_argument('--resume', action='store_true', help='resume training flag')
     parser.add_argument('--num-workers', type=int, default=4, help='number of PyTorch DataLoader workers')
-    parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str,help='url used to set up distributed training')
-    parser.add_argument('--rank', default=0, type=int,help='node rank for distributed training')
-    parser.add_argument('--world-size', default=1, type=int,help='number of nodes for distributed training')
-    parser.add_argument('--dist-backend', default='nccl', type=str,help='distributed backend')
+    parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method')
+    parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
+    parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
+    parser.add_argument('--backend', default='nccl', type=str, help='distributed backend')
     opt = parser.parse_args()
     print(opt, end='\n\n')
diff --git a/utils/datasets.py b/utils/datasets.py
index 285e51f9..be127210 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -2,11 +2,14 @@ import glob
 import math
 import os
 import random
+import shutil
+from pathlib import Path
 
 import cv2
 import numpy as np
 import torch
 from torch.utils.data import Dataset
+from tqdm import tqdm
 
 from utils.utils import xyxy2xywh
 
@@ -97,7 +100,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         assert len(self.img_files) > 0, 'No images found in %s' % path
         self.img_size = img_size
         self.augment = augment
-        self.label_files = [x.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
+        self.label_files = [x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt')
                             for x in self.img_files]
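
The comprehension above encodes the dataset layout the loader expects: each image under an `images/` tree has a same-named `.txt` annotation under a parallel `labels/` tree, one `class x_center y_center width height` row per object in normalized 0-1 coordinates. For example (hypothetical file):

```python
img = '../coco/images/train2014/COCO_train2014_000000000009.jpg'
lbl = img.replace('images', 'labels').replace('.jpg', '.txt')
# -> '../coco/labels/train2014/COCO_train2014_000000000009.txt'
```
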
 
     def __len__(self):
         return len(self.img_files)
 
@@ -136,58 +139,61 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             img, ratio, padw, padh = letterbox(img, height=self.img_size)
 
         # Load labels
+        labels = []
         if os.path.isfile(label_path):
             with open(label_path, 'r') as file:
                 lines = file.read().splitlines()
 
             x = np.array([x.split() for x in lines], dtype=np.float32)
-            if x.size is 0:
-                # Empty labels file
-                labels = np.array([])
-            else:
+            if x.size > 0:
                 # Normalized xywh to pixel xyxy format
                 labels = x.copy()
                 labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw
                 labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh
                 labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw
                 labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh
-        else:
-            labels = np.array([])
 
         # Augment image and labels
         if self.augment:
-            img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))
+            img, labels = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))
 
-        nL = len(labels)
-        if nL > 0:
+        nL = len(labels)  # number of labels
+        if nL:
             # convert xyxy to xywh
             labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size
 
         if self.augment:
             # random left-right flip
             lr_flip = True
-            if lr_flip & (random.random() > 0.5):
+            if lr_flip and random.random() > 0.5:
                 img = np.fliplr(img)
-                if nL > 0:
+                if nL:
                     labels[:, 1] = 1 - labels[:, 1]
 
             # random up-down flip
             ud_flip = False
-            if ud_flip & (random.random() > 0.5):
+            if ud_flip and random.random() > 0.5:
                 img = np.flipud(img)
-                if nL > 0:
+                if nL:
                     labels[:, 2] = 1 - labels[:, 2]
 
-        labels_out = np.zeros((100, 6), dtype=np.float32)
-        if nL > 0:
-            labels_out[:nL, 1:] = labels  # max 100 labels per image
+        labels_out = torch.zeros((nL, 6))
+        if nL:
+            labels_out[:, 1:] = torch.from_numpy(labels)
 
         # Normalize
-        img = img[:, :, ::-1].transpose(2, 0, 1)  # list to np.array and BGR to RGB
+        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
         img = np.ascontiguousarray(img, dtype=np.float32)  # uint8 to float32
         img /= 255.0  # 0 - 255 to 0.0 - 1.0
 
-        return torch.from_numpy(img), torch.from_numpy(labels_out), img_path, (h, w)
+        return torch.from_numpy(img), labels_out, img_path, (h, w)
+
+    @staticmethod
+    def collate_fn(batch):
+        img, label, path, hw = list(zip(*batch))  # transposed
+        for i, l in enumerate(label):
+            l[:, 0] = i  # add target image index for build_targets()
+        return torch.stack(img, 0), torch.cat(label, 0), path, hw
 
 
 def letterbox(img, height=416, color=(127.5, 127.5, 127.5)):  # resize a rectangular image to a padded square
@@ -203,11 +209,13 @@ def letterbox(img, height=416, color=(127.5, 127.5, 127.5)):  # resize a rectang
     return img, ratio, dw, dh
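
`letterbox()` itself is untouched by this patch (only its `return` line appears above as context). For reference, a minimal sketch of what it does, assuming the same signature: scale the long side to `height`, pad the short side with gray to a square, and return the scale and padding so labels can be shifted identically:

```python
import cv2

def letterbox_sketch(img, height=416, color=(127.5, 127.5, 127.5)):
    h, w = img.shape[:2]
    ratio = float(height) / max(h, w)  # scale factor for the long side
    new_w, new_h = round(w * ratio), round(h * ratio)
    dw, dh = (height - new_w) / 2, (height - new_h) / 2  # padding per side
    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    img = cv2.copyMakeBorder(img, round(dh - 0.1), round(dh + 0.1),
                             round(dw - 0.1), round(dw + 0.1),
                             cv2.BORDER_CONSTANT, value=color)  # padded square
    return img, ratio, dw, dh
```
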
 
 
-def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
+def random_affine(img, targets=(), degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
                   borderValue=(127.5, 127.5, 127.5)):
     # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
     # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
+    if targets is None:
+        targets = []
 
     border = 0  # width of added border (optional)
     height = max(img.shape[0], img.shape[1]) + border * 2
 
@@ -233,52 +241,61 @@ def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scal
                           borderValue=borderValue)  # BGR order borderValue
 
     # Return warped points also
-    if targets is not None:
-        if len(targets) > 0:
-            n = targets.shape[0]
-            points = targets[:, 1:5].copy()
-            area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])
+    if len(targets) > 0:
+        n = targets.shape[0]
+        points = targets[:, 1:5].copy()
+        area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])
 
-            # warp points
-            xy = np.ones((n * 4, 3))
-            xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
-            xy = (xy @ M.T)[:, :2].reshape(n, 8)
+        # warp points
+        xy = np.ones((n * 4, 3))
+        xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
+        xy = (xy @ M.T)[:, :2].reshape(n, 8)
 
-            # create new boxes
-            x = xy[:, [0, 2, 4, 6]]
-            y = xy[:, [1, 3, 5, 7]]
-            xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+        # create new boxes
+        x = xy[:, [0, 2, 4, 6]]
+        y = xy[:, [1, 3, 5, 7]]
+        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
 
-            # apply angle-based reduction
-            radians = a * math.pi / 180
-            reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
-            x = (xy[:, 2] + xy[:, 0]) / 2
-            y = (xy[:, 3] + xy[:, 1]) / 2
-            w = (xy[:, 2] - xy[:, 0]) * reduction
-            h = (xy[:, 3] - xy[:, 1]) * reduction
-            xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
+        # apply angle-based reduction
+        radians = a * math.pi / 180
+        reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
+        x = (xy[:, 2] + xy[:, 0]) / 2
+        y = (xy[:, 3] + xy[:, 1]) / 2
+        w = (xy[:, 2] - xy[:, 0]) * reduction
+        h = (xy[:, 3] - xy[:, 1]) * reduction
+        xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
 
-            # reject warped points outside of image
-            np.clip(xy, 0, height, out=xy)
-            w = xy[:, 2] - xy[:, 0]
-            h = xy[:, 3] - xy[:, 1]
-            area = w * h
-            ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
-            i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)
+        # reject warped points outside of image
+        np.clip(xy, 0, height, out=xy)
+        w = xy[:, 2] - xy[:, 0]
+        h = xy[:, 3] - xy[:, 1]
+        area = w * h
+        ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
+        i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)
 
-            targets = targets[i]
-            targets[:, 1:5] = xy[i]
+        targets = targets[i]
+        targets[:, 1:5] = xy[i]
 
-        return imw, targets, M
-    else:
-        return imw
+    return imw, targets
 
 
-def convert_tif2bmp(p='../xview/val_images_bmp'):
-    import glob
-    import cv2
-    files = sorted(glob.glob('%s/*.tif' % p))
-    for i, f in enumerate(files):
-        print('%g/%g' % (i + 1, len(files)))
-        cv2.imwrite(f.replace('.tif', '.bmp'), cv2.imread(f))
-        os.system('rm -rf ' + f)
+def convert_images2bmp():
+    # cv2.imread() jpg at 230 img/s, *.bmp at 400 img/s
+    for path in ['../coco/images/val2014/', '../coco/images/train2014/']:
+        folder = os.sep + Path(path).name
+        output = path.replace(folder, folder + 'bmp')
+        if os.path.exists(output):
+            shutil.rmtree(output)  # delete output folder
+        os.makedirs(output)  # make new output folder
+
+        for f in tqdm(glob.glob('%s*.jpg' % path)):
+            save_name = f.replace('.jpg', '.bmp').replace(folder, folder + 'bmp')
+            cv2.imwrite(save_name, cv2.imread(f))
+
+    for label_path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
+        with open(label_path, 'r') as file:
+            lines = file.read()
+        lines = lines.replace('2014/', '2014bmp/').replace('.jpg', '.bmp').replace(
+            '/Users/glennjocher/PycharmProjects/', '../')
+        with open(label_path.replace('5k', '5k_bmp'), 'w') as file:
+            file.write(lines)
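
The `cv2.imread() jpg at 230 img/s, *.bmp at 400 img/s` comment is the motivation for `convert_images2bmp()`: BMP trades disk space for skipping JPEG decoding at training time. A rough way to reproduce the comparison on any image folder (rates will vary with disk and CPU):

```python
import glob
import time

import cv2

def imread_rate(pattern):
    files = glob.glob(pattern)
    t = time.time()
    for f in files:
        cv2.imread(f)
    return len(files) / (time.time() - t)  # images per second

# e.g. imread_rate('../coco/images/val2014/*.jpg') vs. the val2014bmp copies
```
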
diff --git a/utils/gcp.sh b/utils/gcp.sh
index 50274913..a3372631 100755
--- a/utils/gcp.sh
+++ b/utils/gcp.sh
@@ -3,13 +3,14 @@
 # New VM
 sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
 bash yolov3/data/get_coco_dataset.sh
+bash yolov3/weights/download_yolov3_weights.sh
 sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3
 sudo shutdown
 
 # Train
 sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
 cp -r weights yolov3
-cd yolov3 && python3 train.py --batch-size 16 --epochs 1
+cd yolov3 && python3 train.py --batch-size 48 --epochs 1
 sudo shutdown
 
 # Resume
@@ -20,11 +21,17 @@ python3 detect.py
 
 # Clone a branch
 sudo rm -rf yolov3 && git clone -b multi_gpu --depth 1 https://github.com/ultralytics/yolov3
+cp -r weights yolov3
+cd yolov3 && python3 train.py --batch-size 48 --epochs 1
+sudo shutdown
+
+# Git pull branch
+git pull https://github.com/ultralytics/yolov3 multi_gpu
 
 # Test
 sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
 sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3
-cd yolov3 && python3 test.py --save-json --conf-thres 0.005
+cd yolov3 && python3 test.py --save-json --conf-thres 0.001 --img-size 416
 
 # Test Darknet training
 python3 test.py --img_size 416 --weights ../darknet/backup/yolov3.backup
@@ -33,7 +40,7 @@ python3 test.py --img_size 416 --weights ../darknet/backup/yolov3.backup
 wget https://storage.googleapis.com/ultralytics/yolov3.pt -O weights/latest.pt
 
 # Copy latest.pt to bucket
-gsutil cp yolov3/weights/latest.pt gs://ultralytics
+gsutil cp yolov3/weights/latest1gpu.pt gs://ultralytics
 
 # Copy latest.pt from bucket
 gsutil cp gs://ultralytics/latest.pt yolov3/weights/latest.pt
diff --git a/utils/utils.py b/utils/utils.py
index 76b25df2..2233f88f 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -95,7 +95,7 @@ def weights_init_normal(m):
 
 def xyxy2xywh(x):
     # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h]
-    y = torch.zeros_like(x) if x.dtype is torch.float32 else np.zeros_like(x)
+    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
     y[:, 0] = (x[:, 0] + x[:, 2]) / 2
     y[:, 1] = (x[:, 1] + x[:, 3]) / 2
     y[:, 2] = x[:, 2] - x[:, 0]
@@ -105,7 +105,7 @@ def xyxy2xywh(x):
 
 def xywh2xyxy(x):
     # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2]
-    y = torch.zeros_like(x) if x.dtype is torch.float32 else np.zeros_like(x)
+    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
     y[:, 0] = (x[:, 0] - x[:, 2] / 2)
     y[:, 1] = (x[:, 1] - x[:, 3] / 2)
     y[:, 2] = (x[:, 0] + x[:, 2] / 2)
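
The `isinstance` change above makes the container dispatch in both converters depend on the input type rather than its dtype, so non-float32 tensors no longer fall into the NumPy branch. The two functions remain exact inverses, which is easy to sanity-check:

```python
import numpy as np
from utils.utils import xyxy2xywh, xywh2xyxy

boxes = np.array([[2., 2., 6., 10.]])  # one box: x1, y1, x2, y2
assert np.allclose(xywh2xyxy(xyxy2xywh(boxes)), boxes)  # round-trip is exact
```
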
@@ -251,7 +251,7 @@ def wh_iou(box1, box2):
 def compute_loss(p, targets):  # predictions, targets
     FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor
     loss, lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
-    txy, twh, tcls, tconf, indices = targets
+    txy, twh, tcls, indices = targets
     MSE = nn.MSELoss()
     CE = nn.CrossEntropyLoss()
     BCE = nn.BCEWithLogitsLoss()
@@ -260,18 +260,21 @@ def compute_loss(p, targets):  # predictions, targets
 
     # gp = [x.numel() for x in tconf]  # grid points
     for i, pi0 in enumerate(p):  # layer i predictions, i
         b, a, gj, gi = indices[i]  # image, anchor, gridx, gridy
+        tconf = torch.zeros_like(pi0[..., 0])  # conf
 
         # Compute losses
         k = 1  # nT / bs
         if len(b) > 0:
             pi = pi0[b, a, gj, gi]  # predictions closest to anchors
+            tconf[b, a, gj, gi] = 1  # conf
+
             lxy += k * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy
             lwh += k * MSE(pi[..., 2:4], twh[i])  # wh
             lcls += (k / 4) * CE(pi[..., 5:], tcls[i])
 
         # pos_weight = FT([gp[i] / min(gp) * 4.])
         # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
-        lconf += (k * 64) * BCE(pi0[..., 4], tconf[i])
+        lconf += (k * 64) * BCE(pi0[..., 4], tconf)
     loss = lxy + lwh + lconf + lcls
 
     # Add to dictionary
@@ -283,15 +286,13 @@ def compute_loss(p, targets):  # predictions, targets
     return loss, d
 
 
-def build_targets(model, targets, pred):
+def build_targets(model, targets):
     # targets = [image, class, x, y, w, h]
     if isinstance(model, nn.parallel.DistributedDataParallel):
         model = model.module
 
-    yolo_layers = get_yolo_layers(model)
-    # anchors = closest_anchor(model, targets)  # [layer, anchor, i, j]
-    txy, twh, tcls, tconf, indices = [], [], [], [], []
-    for i, layer in enumerate(yolo_layers):
+    txy, twh, tcls, indices = [], [], [], []
+    for i, layer in enumerate(get_yolo_layers(model)):
         nG = model.module_list[layer][0].nG  # grid size
         anchor_vec = model.module_list[layer][0].anchor_vec
 
@@ -324,12 +325,7 @@ def build_targets(model, targets):
         # Class
         tcls.append(c)
 
-        # Conf
-        tci = torch.zeros_like(pred[i][..., 0])
-        tci[b, a, gj, gi] = 1  # conf
-        tconf.append(tci)
-
-    return txy, twh, tcls, tconf, indices
+    return txy, twh, tcls, indices
 
 
 def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
@@ -439,15 +435,6 @@ def get_yolo_layers(model):
     return [i for i, x in enumerate(bool_vec) if x]  # [82, 94, 106] for yolov3
 
 
-def return_torch_unique_index(u, uv):
-    n = uv.shape[1]  # number of columns
-    first_unique = torch.zeros(n, device=u.device).long()
-    for j in range(n):
-        first_unique[j] = (uv[:, j:j + 1] == u).all(0).nonzero()[0]
-
-    return first_unique
-
-
 def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
     # Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)
     a = torch.load(filename, map_location='cpu')
@@ -480,10 +467,9 @@ def plot_results(start=0):
     # import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
     # from utils.utils import *; plot_results()
 
-    plt.figure(figsize=(14, 7))
+    fig = plt.figure(figsize=(14, 7))
     s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
-    files = sorted(glob.glob('results*.txt'))
-    for f in files:
+    for f in sorted(glob.glob('results*.txt')):
         results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T  # column 11 is mAP
         x = range(1, results.shape[1])
         for i in range(8):
@@ -492,3 +478,4 @@ def plot_results(start=0):
             plt.title(s[i])
             if i == 0:
                 plt.legend()
+    fig.tight_layout()
diff --git a/weights/download_yolov3_weights.sh b/weights/download_yolov3_weights.sh
index 0568cb87..fe6213aa 100644
--- a/weights/download_yolov3_weights.sh
+++ b/weights/download_yolov3_weights.sh
@@ -18,3 +18,4 @@ wget -c https://pjreddie.com/media/files/darknet53.conv.74
 
 # ./darknet partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.15 15
 # mv yolov3-tiny.conv.15 ../
+cd ..
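
Taken together, the new pieces wire up as follows: `LoadImagesAndLabels.__getitem__()` emits a per-image `[nL, 6]` label tensor, `collate_fn()` concatenates these with a batch-index column, and `build_targets()`/`compute_loss()` consume the concatenated tensor directly. A usage sketch (the COCO list-file path is illustrative):

```python
from torch.utils.data import DataLoader

from utils.datasets import LoadImagesAndLabels

dataset = LoadImagesAndLabels('../coco/trainvalno5k.txt', img_size=416, augment=True)
dataloader = DataLoader(dataset,
                        batch_size=16,
                        num_workers=4,
                        shuffle=True,
                        collate_fn=dataset.collate_fn)  # variable-size targets, no padding

for imgs, targets, paths, shapes in dataloader:
    # imgs: [16, 3, 416, 416]; targets: [sum(nL), 6] rows of
    # (batch_image_index, class, x_center, y_center, w, h), normalized
    break
```
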