Add collate_fn() to DataLoader (#163)

Multi-GPU update with custom collate function to allow variable size target vector per image without needing to pad targets.
This commit is contained in:
Glenn Jocher 2019-03-25 14:59:38 +01:00 committed by GitHub
parent 49ae0a55b1
commit cd51e1137b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 217 additions and 194 deletions

View File

@ -17,7 +17,7 @@ This directory contains python software and an iOS App developed by Ultralytics
# Description # Description
The https://github.com/ultralytics/yolov3 repo contains inference and training code for YOLOv3 in PyTorch. The code works on Linux, MacOS and Windows. Training is done on the COCO dataset by default: https://cocodataset.org/#home. **Credit to Joseph Redmon for YOLO** (https://pjreddie.com/darknet/yolo/) and to **Erik Lindernoren for the PyTorch implementation** this work is based on (https://github.com/eriklindernoren/PyTorch-YOLOv3). The https://github.com/ultralytics/yolov3 repo contains inference and training code for YOLOv3 in PyTorch. The code works on Linux, MacOS and Windows. Training is done on the COCO dataset by default: https://cocodataset.org/#home. **Credit to Joseph Redmon for YOLO: ** https://pjreddie.com/darknet/yolo/.
# Requirements # Requirements
@ -26,6 +26,7 @@ Python 3.7 or later with the following `pip3 install -U -r requirements.txt` pac
- `numpy` - `numpy`
- `torch >= 1.0.0` - `torch >= 1.0.0`
- `opencv-python` - `opencv-python`
- `tqdm`
# Tutorials # Tutorials
@ -66,15 +67,20 @@ HS**V** Intensity | +/- 50%
https://cloud.google.com/deep-learning-vm/ https://cloud.google.com/deep-learning-vm/
**Machine type:** n1-standard-8 (8 vCPUs, 30 GB memory) **Machine type:** n1-standard-8 (8 vCPUs, 30 GB memory)
**CPU platform:** Intel Skylake **CPU platform:** Intel Skylake
**GPUs:** 1-4 x NVIDIA Tesla P100 **GPUs:** 1-4x P100 ($0.493/hr), 1-8x V100 ($0.803/hr)
**HDD:** 100 GB SSD **HDD:** 100 GB SSD
**Dataset:** COCO train 2014
GPUs | `batch_size` | speed | COCO epoch GPUs | `batch_size` | batch time | epoch time | epoch cost
--- |---| --- | --- --- |---| --- | --- | ---
(P100) | (images) | (s/batch) | (min/epoch) <i></i> | (images) | (s/batch) | |
1 | 16 | 0.39s | 48min 1 P100 | 16 | 0.39s | 48min | $0.39
2 | 32 | 0.48s | 29min 2 P100 | 32 | 0.48s | 29min | $0.47
4 | 64 | 0.65s | 20min 4 P100 | 64 | 0.65s | 20min | $0.65
1 V100 | 16 | 0.25s | 31min | $0.41
2 V100 | 32 | 0.29s | 18min | $0.48
4 V100 | 64 | 0.41s | 13min | $0.70
8 V100 | 128 | 0.49s | 7min | $0.80
# Inference # Inference

View File

@ -1,7 +1,5 @@
import argparse import argparse
import shutil
import time import time
from pathlib import Path
from sys import platform from sys import platform
from models import * from models import *
@ -32,9 +30,9 @@ def detect(
# Load weights # Load weights
if weights.endswith('.pt'): # pytorch format if weights.endswith('.pt'): # pytorch format
if weights.endswith('yolov3.pt') and not os.path.exists(weights): if weights.endswith('yolov3.pt') and not os.path.exists(weights):
if (platform == 'darwin') or (platform == 'linux'): if platform in ('darwin', 'linux'): # linux/macos
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights) os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights)
model.load_state_dict(torch.load(weights, map_location='cpu')['model']) model.load_state_dict(torch.load(weights, map_location=device)['model'])
else: # darknet format else: # darknet format
_ = load_darknet_weights(model, weights) _ = load_darknet_weights(model, weights)
@ -49,15 +47,15 @@ def detect(
# Get classes and colors # Get classes and colors
classes = load_classes(parse_data_cfg('cfg/coco.data')['names']) classes = load_classes(parse_data_cfg('cfg/coco.data')['names'])
colors = [[random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)] for _ in range(len(classes))] colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
for i, (path, img, im0) in enumerate(dataloader): for i, (path, img, im0) in enumerate(dataloader):
t = time.time() t = time.time()
save_path = str(Path(output) / Path(path).name)
if webcam: if webcam:
print('webcam frame %g: ' % (i + 1), end='') print('webcam frame %g: ' % (i + 1), end='')
else: else:
print('image %g/%g %s: ' % (i + 1, len(dataloader), path), end='') print('image %g/%g %s: ' % (i + 1, len(dataloader), path), end='')
save_path = str(Path(output) / Path(path).name)
# Get detections # Get detections
img = torch.from_numpy(img).unsqueeze(0).to(device) img = torch.from_numpy(img).unsqueeze(0).to(device)
@ -81,18 +79,16 @@ def detect(
print('%g %ss' % (n, classes[int(c)]), end=', ') print('%g %ss' % (n, classes[int(c)]), end=', ')
# Draw bounding boxes and labels of detections # Draw bounding boxes and labels of detections
for x1, y1, x2, y2, conf, cls_conf, cls in detections: for *xyxy, conf, cls_conf, cls in detections:
if save_txt: # Write to file if save_txt: # Write to file
with open(save_path + '.txt', 'a') as file: with open(save_path + '.txt', 'a') as file:
file.write('%g %g %g %g %g %g\n' % file.write(('%g ' * 6 + '\n') % (*xyxy, cls, cls_conf * conf))
(x1, y1, x2, y2, cls, cls_conf * conf))
# Add bbox to the image # Add bbox to the image
label = '%s %.2f' % (classes[int(cls)], conf) label = '%s %.2f' % (classes[int(cls)], conf)
plot_one_box([x1, y1, x2, y2], im0, label=label, color=colors[int(cls)]) plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
dt = time.time() - t print('Done. (%.3fs)' % (time.time() - t))
print('Done. (%.3fs)' % dt)
if save_images: # Save generated image with detections if save_images: # Save generated image with detections
cv2.imwrite(save_path, im0) cv2.imwrite(save_path, im0)
@ -100,7 +96,7 @@ def detect(
if webcam: # Show live webcam if webcam: # Show live webcam
cv2.imshow(weights, im0) cv2.imshow(weights, im0)
if save_images and (platform == 'darwin'): # linux/macos if save_images and platform == 'darwin': # macos
os.system('open ' + output + ' ' + save_path) os.system('open ' + output + ' ' + save_path)

View File

@ -4,3 +4,5 @@ numpy
opencv-python opencv-python
torch >= 1.0.0 torch >= 1.0.0
matplotlib matplotlib
pycocotools
tqdm

45
test.py
View File

@ -1,7 +1,6 @@
import argparse import argparse
import json import json
import time import time
from pathlib import Path
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -24,41 +23,44 @@ def test(
): ):
device = torch_utils.select_device() device = torch_utils.select_device()
# Configure run
data_cfg_dict = parse_data_cfg(data_cfg)
nC = int(data_cfg_dict['classes']) # number of classes (80 for COCO)
test_path = data_cfg_dict['valid']
if model is None: if model is None:
# Initialize model # Initialize model
model = Darknet(cfg, img_size) model = Darknet(cfg, img_size).to(device)
# Load weights # Load weights
if weights.endswith('.pt'): # pytorch format if weights.endswith('.pt'): # pytorch format
model.load_state_dict(torch.load(weights, map_location='cpu')['model']) model.load_state_dict(torch.load(weights, map_location=device)['model'])
else: # darknet format else: # darknet format
_ = load_darknet_weights(model, weights) _ = load_darknet_weights(model, weights)
model.to(device).eval() if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
# Configure run
data_cfg = parse_data_cfg(data_cfg)
nC = int(data_cfg['classes']) # number of classes (80 for COCO)
test_path = data_cfg['valid']
# Dataloader # Dataloader
dataset = LoadImagesAndLabels(test_path, img_size=img_size) dataset = LoadImagesAndLabels(test_path, img_size=img_size)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4) dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=4,
pin_memory=False,
collate_fn=dataset.collate_fn)
model.eval()
mean_mAP, mean_R, mean_P, seen = 0.0, 0.0, 0.0, 0 mean_mAP, mean_R, mean_P, seen = 0.0, 0.0, 0.0, 0
print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP')) print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP'))
mP, mR, mAPs, TP, jdict = [], [], [], [], [] mP, mR, mAPs, TP, jdict = [], [], [], [], []
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC) AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
coco91class = coco80_to_coco91_class() coco91class = coco80_to_coco91_class()
for imgs, targets, paths, shapes in dataloader: for imgs, targets, paths, shapes in tqdm(dataloader):
# Unpad and collate targets
for j, t in enumerate(targets):
t[:, 0] = j
targets = torch.cat([t[t[:, 5].nonzero()] for t in targets], 0).squeeze(1)
targets = targets.to(device)
t = time.time() t = time.time()
output = model(imgs.to(device)) targets = targets.to(device)
imgs = imgs.to(device)
output = model(imgs)
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres) output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
# Compute average precision for each sample # Compute average precision for each sample
@ -78,7 +80,7 @@ def test(
if save_json: if save_json:
# [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
box = detections[:, :4].clone() # xyxy box = detections[:, :4].clone() # xyxy
scale_coords(img_size, box, (shapes[0][si], shapes[1][si])) # to original shape scale_coords(img_size, box, shapes[si]) # to original shape
box = xyxy2xywh(box) # xywh box = xyxy2xywh(box) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
@ -140,7 +142,7 @@ def test(
# Print mAP per class # Print mAP per class
print('\nmAP Per Class:') print('\nmAP Per Class:')
for i, c in enumerate(load_classes(data_cfg_dict['names'])): for i, c in enumerate(load_classes(data_cfg['names'])):
if AP_accum_count[i]: if AP_accum_count[i]:
print('%15s: %-.4f' % (c, AP_accum[i] / (AP_accum_count[i]))) print('%15s: %-.4f' % (c, AP_accum[i] / (AP_accum_count[i])))
@ -191,4 +193,5 @@ if __name__ == '__main__':
opt.iou_thres, opt.iou_thres,
opt.conf_thres, opt.conf_thres,
opt.nms_thres, opt.nms_thres,
opt.save_json) opt.save_json
)

114
train.py
View File

@ -1,6 +1,7 @@
import argparse import argparse
import time import time
import torch.distributed as dist
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import test # Import test.py to get mAP after each epoch import test # Import test.py to get mAP after each epoch
@ -20,7 +21,7 @@ def train(
accumulate=1, accumulate=1,
multi_scale=False, multi_scale=False,
freeze_backbone=False, freeze_backbone=False,
num_workers=0 num_workers=4
): ):
weights = 'weights' + os.sep weights = 'weights' + os.sep
latest = weights + 'latest.pt' latest = weights + 'latest.pt'
@ -65,52 +66,55 @@ def train(
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank) dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
model = torch.nn.parallel.DistributedDataParallel(model) model = torch.nn.parallel.DistributedDataParallel(model)
# Dataloader
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
if torch.cuda.device_count() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
train_sampler=None
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,sampler=train_sampler)
# Transfer learning (train only YOLO layers) # Transfer learning (train only YOLO layers)
# for i, (name, p) in enumerate(model.named_parameters()): # for i, (name, p) in enumerate(model.named_parameters()):
# p.requires_grad = True if (p.shape[0] == 255) else False # p.requires_grad = True if (p.shape[0] == 255) else False
# Set scheduler # Set scheduler (reduce lr at epoch 250)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1, last_epoch=start_epoch - 1)
# Dataset
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
# Initialize distributed training
if torch.cuda.device_count() > 1:
dist.init_process_group(backend=opt.backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank)
model = torch.nn.parallel.DistributedDataParallel(model)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
sampler = None
# Dataloader
dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
pin_memory=False,
collate_fn=dataset.collate_fn,
sampler=sampler)
# Start training # Start training
t0 = time.time() nB = len(dataloader)
t = time.time()
model_info(model) model_info(model)
n_burnin = min(round(len(dataloader) / 5 + 1), 1000) # burn-in batches n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches
for epoch in range(epochs): for epoch in range(start_epoch, epochs):
model.train() model.train()
epoch += start_epoch print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
print(('\n%8s%12s' + '%10s' * 7) % ( # Update scheduler
'Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time')) scheduler.step()
# Update scheduler (automatic)
# scheduler.step()
# Update scheduler (manual)
lr = lr0 / 10 if epoch > 250 else lr0
for x in optimizer.param_groups:
x['lr'] = lr
# Freeze backbone at epoch 0, unfreeze at epoch 1 # Freeze backbone at epoch 0, unfreeze at epoch 1
if freeze_backbone and epoch < 2: if freeze_backbone and epoch < 2:
for i, (name, p) in enumerate(model.named_parameters()): for name, p in model.named_parameters():
if int(name.split('.')[1]) < cutoff: # if layer < 75 if int(name.split('.')[1]) < cutoff: # if layer < 75
p.requires_grad = False if (epoch == 0) else True p.requires_grad = False if epoch == 0 else True
rloss = defaultdict(float) mloss = defaultdict(float) # mean loss
for i, (imgs, targets, _, _) in enumerate(dataloader): for i, (imgs, targets, _, _) in enumerate(dataloader):
# Unpad and collate targets imgs = imgs.to(device)
for j, t in enumerate(targets): targets = targets.to(device)
t[:, 0] = j
targets = torch.cat([t[t[:, 5].nonzero()] for t in targets], 0).squeeze(1)
nT = len(targets) nT = len(targets)
if nT == 0: # if no targets continue if nT == 0: # if no targets continue
@ -119,25 +123,26 @@ def train(
# Plot images with bounding boxes # Plot images with bounding boxes
plot_images = False plot_images = False
if plot_images: if plot_images:
import matplotlib.pyplot as plt fig = plt.figure(figsize=(10, 10))
plt.figure(figsize=(10, 10))
for ip in range(batch_size): for ip in range(batch_size):
labels = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy() * img_size labels = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy() * img_size
plt.subplot(4, 4, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0)) plt.subplot(4, 4, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0))
plt.plot(labels[:, [0, 2, 2, 0, 0]].T, labels[:, [1, 1, 3, 3, 1]].T, '.-') plt.plot(labels[:, [0, 2, 2, 0, 0]].T, labels[:, [1, 1, 3, 3, 1]].T, '.-')
plt.axis('off') plt.axis('off')
fig.tight_layout()
fig.savefig('batch_%g.jpg' % i, dpi=fig.dpi)
# SGD burn-in # SGD burn-in
if (epoch == 0) and (i <= n_burnin): if epoch == 0 and i <= n_burnin:
lr = lr0 * (i / n_burnin) ** 4 lr = lr0 * (i / n_burnin) ** 4
for x in optimizer.param_groups: for x in optimizer.param_groups:
x['lr'] = lr x['lr'] = lr
# Run model # Run model
pred = model(imgs.to(device)) pred = model(imgs)
# Build targets # Build targets
target_list = build_targets(model, targets.to(device), pred) target_list = build_targets(model, targets)
# Compute loss # Compute loss
loss, loss_dict = compute_loss(pred, target_list) loss, loss_dict = compute_loss(pred, target_list)
@ -146,21 +151,19 @@ def train(
loss.backward() loss.backward()
# Accumulate gradient for x batches before optimizing # Accumulate gradient for x batches before optimizing
if (i + 1) % accumulate == 0 or (i + 1) == len(dataloader): if (i + 1) % accumulate == 0 or (i + 1) == nB:
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
# Running epoch-means of tracked metrics # Running epoch-means of tracked metrics
for key, val in loss_dict.items(): for key, val in loss_dict.items():
rloss[key] = (rloss[key] * i + val) / (i + 1) mloss[key] = (mloss[key] * i + val) / (i + 1)
s = ('%8s%12s' + '%10.3g' * 7) % ( s = ('%8s%12s' + '%10.3g' * 7) % (
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, nB - 1),
'%g/%g' % (i, len(dataloader) - 1), mloss['xy'], mloss['wh'], mloss['conf'], mloss['cls'],
rloss['xy'], rloss['wh'], rloss['conf'], mloss['total'], nT, time.time() - t)
rloss['cls'], rloss['total'], t = time.time()
nT, time.time() - t0)
t0 = time.time()
print(s) print(s)
# Multi-Scale training (320 - 608 pixels) every 10 batches # Multi-Scale training (320 - 608 pixels) every 10 batches
@ -169,8 +172,8 @@ def train(
print('multi_scale img_size = %g' % dataset.img_size) print('multi_scale img_size = %g' % dataset.img_size)
# Update best loss # Update best loss
if rloss['total'] < best_loss: if mloss['total'] < best_loss:
best_loss = rloss['total'] best_loss = mloss['total']
# Save training results # Save training results
save = True save = True
@ -178,23 +181,24 @@ def train(
# Save latest checkpoint # Save latest checkpoint
checkpoint = {'epoch': epoch, checkpoint = {'epoch': epoch,
'best_loss': best_loss, 'best_loss': best_loss,
'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(), 'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()} 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest) torch.save(checkpoint, latest)
# Save best checkpoint # Save best checkpoint
if best_loss == rloss['total']: if best_loss == mloss['total']:
os.system('cp ' + latest + ' ' + best) os.system('cp ' + latest + ' ' + best)
# Save backup weights every 5 epochs (optional) # Save backup weights every 5 epochs (optional)
if (epoch > 0) and (epoch % 5 == 0): if epoch > 0 and epoch % 5 == 0:
os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch)) os.system('cp ' + latest + ' ' + weights + 'backup%g.pt' % epoch)
# Calculate mAP # Calculate mAP
if type(model) is nn.parallel.DistributedDataParallel: if type(model) is nn.parallel.DistributedDataParallel:
model = model.module model = model.module
with torch.no_grad(): with torch.no_grad():
P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model) P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
# Write epoch results # Write epoch results
with open('results.txt', 'a') as file: with open('results.txt', 'a') as file:
@ -212,10 +216,10 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels') parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers') parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str,help='url used to set up distributed training') parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method')
parser.add_argument('--rank', default=0, type=int,help='node rank for distributed training') parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training') parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,help='distributed backend') parser.add_argument('--backend', default='nccl', type=str, help='distributed backend')
opt = parser.parse_args() opt = parser.parse_args()
print(opt, end='\n\n') print(opt, end='\n\n')

View File

@ -2,11 +2,14 @@ import glob
import math import math
import os import os
import random import random
import shutil
from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm
from utils.utils import xyxy2xywh from utils.utils import xyxy2xywh
@ -97,7 +100,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
assert len(self.img_files) > 0, 'No images found in %s' % path assert len(self.img_files) > 0, 'No images found in %s' % path
self.img_size = img_size self.img_size = img_size
self.augment = augment self.augment = augment
self.label_files = [x.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') self.label_files = [x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt')
for x in self.img_files] for x in self.img_files]
def __len__(self): def __len__(self):
@ -136,58 +139,61 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img, ratio, padw, padh = letterbox(img, height=self.img_size) img, ratio, padw, padh = letterbox(img, height=self.img_size)
# Load labels # Load labels
labels = []
if os.path.isfile(label_path): if os.path.isfile(label_path):
with open(label_path, 'r') as file: with open(label_path, 'r') as file:
lines = file.read().splitlines() lines = file.read().splitlines()
x = np.array([x.split() for x in lines], dtype=np.float32) x = np.array([x.split() for x in lines], dtype=np.float32)
if x.size is 0: if x.size > 0:
# Empty labels file
labels = np.array([])
else:
# Normalized xywh to pixel xyxy format # Normalized xywh to pixel xyxy format
labels = x.copy() labels = x.copy()
labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw
labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh
labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw
labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh
else:
labels = np.array([])
# Augment image and labels # Augment image and labels
if self.augment: if self.augment:
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10)) img, labels = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))
nL = len(labels) nL = len(labels) # number of labels
if nL > 0: if nL:
# convert xyxy to xywh # convert xyxy to xywh
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size
if self.augment: if self.augment:
# random left-right flip # random left-right flip
lr_flip = True lr_flip = True
if lr_flip & (random.random() > 0.5): if lr_flip and random.random() > 0.5:
img = np.fliplr(img) img = np.fliplr(img)
if nL > 0: if nL:
labels[:, 1] = 1 - labels[:, 1] labels[:, 1] = 1 - labels[:, 1]
# random up-down flip # random up-down flip
ud_flip = False ud_flip = False
if ud_flip & (random.random() > 0.5): if ud_flip and random.random() > 0.5:
img = np.flipud(img) img = np.flipud(img)
if nL > 0: if nL:
labels[:, 2] = 1 - labels[:, 2] labels[:, 2] = 1 - labels[:, 2]
labels_out = np.zeros((100, 6), dtype=np.float32) labels_out = torch.zeros((nL, 6))
if nL > 0: if nL:
labels_out[:nL, 1:] = labels # max 100 labels per image labels_out[:, 1:] = torch.from_numpy(labels)
# Normalize # Normalize
img = img[:, :, ::-1].transpose(2, 0, 1) # list to np.array and BGR to RGB img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img, dtype=np.float32) # uint8 to float32 img = np.ascontiguousarray(img, dtype=np.float32) # uint8 to float32
img /= 255.0 # 0 - 255 to 0.0 - 1.0 img /= 255.0 # 0 - 255 to 0.0 - 1.0
return torch.from_numpy(img), torch.from_numpy(labels_out), img_path, (h, w) return torch.from_numpy(img), labels_out, img_path, (h, w)
@staticmethod
def collate_fn(batch):
img, label, path, hw = list(zip(*batch)) # transposed
for i, l in enumerate(label):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, hw
def letterbox(img, height=416, color=(127.5, 127.5, 127.5)): # resize a rectangular image to a padded square def letterbox(img, height=416, color=(127.5, 127.5, 127.5)): # resize a rectangular image to a padded square
@ -203,11 +209,13 @@ def letterbox(img, height=416, color=(127.5, 127.5, 127.5)): # resize a rectang
return img, ratio, dw, dh return img, ratio, dw, dh
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2), def random_affine(img, targets=(), degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
borderValue=(127.5, 127.5, 127.5)): borderValue=(127.5, 127.5, 127.5)):
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4 # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
if targets is None:
targets = []
border = 0 # width of added border (optional) border = 0 # width of added border (optional)
height = max(img.shape[0], img.shape[1]) + border * 2 height = max(img.shape[0], img.shape[1]) + border * 2
@ -233,7 +241,6 @@ def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scal
borderValue=borderValue) # BGR order borderValue borderValue=borderValue) # BGR order borderValue
# Return warped points also # Return warped points also
if targets is not None:
if len(targets) > 0: if len(targets) > 0:
n = targets.shape[0] n = targets.shape[0]
points = targets[:, 1:5].copy() points = targets[:, 1:5].copy()
@ -269,16 +276,26 @@ def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scal
targets = targets[i] targets = targets[i]
targets[:, 1:5] = xy[i] targets[:, 1:5] = xy[i]
return imw, targets, M return imw, targets
else:
return imw
def convert_tif2bmp(p='../xview/val_images_bmp'): def convert_images2bmp():
import glob # cv2.imread() jpg at 230 img/s, *.bmp at 400 img/s
import cv2 for path in ['../coco/images/val2014/', '../coco/images/train2014/']:
files = sorted(glob.glob('%s/*.tif' % p)) folder = os.sep + Path(path).name
for i, f in enumerate(files): output = path.replace(folder, folder + 'bmp')
print('%g/%g' % (i + 1, len(files))) if os.path.exists(output):
cv2.imwrite(f.replace('.tif', '.bmp'), cv2.imread(f)) shutil.rmtree(output) # delete output folder
os.system('rm -rf ' + f) os.makedirs(output) # make new output folder
for f in tqdm(glob.glob('%s*.jpg' % path)):
save_name = f.replace('.jpg', '.bmp').replace(folder, folder + 'bmp')
cv2.imwrite(save_name, cv2.imread(f))
for label_path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
with open(label_path, 'r') as file:
lines = file.read()
lines = lines.replace('2014/', '2014bmp/').replace('.jpg', '.bmp').replace(
'/Users/glennjocher/PycharmProjects/', '../')
with open(label_path.replace('5k', '5k_bmp'), 'w') as file:
file.write(lines)

View File

@ -3,13 +3,14 @@
# New VM # New VM
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3 sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
bash yolov3/data/get_coco_dataset.sh bash yolov3/data/get_coco_dataset.sh
bash yolov3/weights/download_yolov3_weights.sh
sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3 sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3
sudo shutdown sudo shutdown
# Train # Train
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3 sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
cp -r weights yolov3 cp -r weights yolov3
cd yolov3 && python3 train.py --batch-size 16 --epochs 1 cd yolov3 && python3 train.py --batch-size 48 --epochs 1
sudo shutdown sudo shutdown
# Resume # Resume
@ -20,11 +21,17 @@ python3 detect.py
# Clone a branch # Clone a branch
sudo rm -rf yolov3 && git clone -b multi_gpu --depth 1 https://github.com/ultralytics/yolov3 sudo rm -rf yolov3 && git clone -b multi_gpu --depth 1 https://github.com/ultralytics/yolov3
cp -r weights yolov3
cd yolov3 && python3 train.py --batch-size 48 --epochs 1
sudo shutdown
# Git pull branch
git pull https://github.com/ultralytics/yolov3 multi_gpu
# Test # Test
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3 sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3 sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3
cd yolov3 && python3 test.py --save-json --conf-thres 0.005 cd yolov3 && python3 test.py --save-json --conf-thres 0.001 --img-size 416
# Test Darknet training # Test Darknet training
python3 test.py --img_size 416 --weights ../darknet/backup/yolov3.backup python3 test.py --img_size 416 --weights ../darknet/backup/yolov3.backup
@ -33,7 +40,7 @@ python3 test.py --img_size 416 --weights ../darknet/backup/yolov3.backup
wget https://storage.googleapis.com/ultralytics/yolov3.pt -O weights/latest.pt wget https://storage.googleapis.com/ultralytics/yolov3.pt -O weights/latest.pt
# Copy latest.pt to bucket # Copy latest.pt to bucket
gsutil cp yolov3/weights/latest.pt gs://ultralytics gsutil cp yolov3/weights/latest1gpu.pt gs://ultralytics
# Copy latest.pt from bucket # Copy latest.pt from bucket
gsutil cp gs://ultralytics/latest.pt yolov3/weights/latest.pt gsutil cp gs://ultralytics/latest.pt yolov3/weights/latest.pt

View File

@ -95,7 +95,7 @@ def weights_init_normal(m):
def xyxy2xywh(x): def xyxy2xywh(x):
# Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h] # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h]
y = torch.zeros_like(x) if x.dtype is torch.float32 else np.zeros_like(x) y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 y[:, 0] = (x[:, 0] + x[:, 2]) / 2
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 y[:, 1] = (x[:, 1] + x[:, 3]) / 2
y[:, 2] = x[:, 2] - x[:, 0] y[:, 2] = x[:, 2] - x[:, 0]
@ -105,7 +105,7 @@ def xyxy2xywh(x):
def xywh2xyxy(x): def xywh2xyxy(x):
# Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2] # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2]
y = torch.zeros_like(x) if x.dtype is torch.float32 else np.zeros_like(x) y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
y[:, 0] = (x[:, 0] - x[:, 2] / 2) y[:, 0] = (x[:, 0] - x[:, 2] / 2)
y[:, 1] = (x[:, 1] - x[:, 3] / 2) y[:, 1] = (x[:, 1] - x[:, 3] / 2)
y[:, 2] = (x[:, 0] + x[:, 2] / 2) y[:, 2] = (x[:, 0] + x[:, 2] / 2)
@ -251,7 +251,7 @@ def wh_iou(box1, box2):
def compute_loss(p, targets): # predictions, targets def compute_loss(p, targets): # predictions, targets
FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor
loss, lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) loss, lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
txy, twh, tcls, tconf, indices = targets txy, twh, tcls, indices = targets
MSE = nn.MSELoss() MSE = nn.MSELoss()
CE = nn.CrossEntropyLoss() CE = nn.CrossEntropyLoss()
BCE = nn.BCEWithLogitsLoss() BCE = nn.BCEWithLogitsLoss()
@ -260,18 +260,21 @@ def compute_loss(p, targets): # predictions, targets
# gp = [x.numel() for x in tconf] # grid points # gp = [x.numel() for x in tconf] # grid points
for i, pi0 in enumerate(p): # layer i predictions, i for i, pi0 in enumerate(p): # layer i predictions, i
b, a, gj, gi = indices[i] # image, anchor, gridx, gridy b, a, gj, gi = indices[i] # image, anchor, gridx, gridy
tconf = torch.zeros_like(pi0[..., 0]) # conf
# Compute losses # Compute losses
k = 1 # nT / bs k = 1 # nT / bs
if len(b) > 0: if len(b) > 0:
pi = pi0[b, a, gj, gi] # predictions closest to anchors pi = pi0[b, a, gj, gi] # predictions closest to anchors
tconf[b, a, gj, gi] = 1 # conf
lxy += k * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy lxy += k * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy
lwh += k * MSE(pi[..., 2:4], twh[i]) # wh lwh += k * MSE(pi[..., 2:4], twh[i]) # wh
lcls += (k / 4) * CE(pi[..., 5:], tcls[i]) lcls += (k / 4) * CE(pi[..., 5:], tcls[i])
# pos_weight = FT([gp[i] / min(gp) * 4.]) # pos_weight = FT([gp[i] / min(gp) * 4.])
# BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight) # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
lconf += (k * 64) * BCE(pi0[..., 4], tconf[i]) lconf += (k * 64) * BCE(pi0[..., 4], tconf)
loss = lxy + lwh + lconf + lcls loss = lxy + lwh + lconf + lcls
# Add to dictionary # Add to dictionary
@ -283,15 +286,13 @@ def compute_loss(p, targets): # predictions, targets
return loss, d return loss, d
def build_targets(model, targets, pred): def build_targets(model, targets):
# targets = [image, class, x, y, w, h] # targets = [image, class, x, y, w, h]
if isinstance(model, nn.parallel.DistributedDataParallel): if isinstance(model, nn.parallel.DistributedDataParallel):
model = model.module model = model.module
yolo_layers = get_yolo_layers(model)
# anchors = closest_anchor(model, targets) # [layer, anchor, i, j] txy, twh, tcls, indices = [], [], [], []
txy, twh, tcls, tconf, indices = [], [], [], [], [] for i, layer in enumerate(get_yolo_layers(model)):
for i, layer in enumerate(yolo_layers):
nG = model.module_list[layer][0].nG # grid size nG = model.module_list[layer][0].nG # grid size
anchor_vec = model.module_list[layer][0].anchor_vec anchor_vec = model.module_list[layer][0].anchor_vec
@ -324,12 +325,7 @@ def build_targets(model, targets, pred):
# Class # Class
tcls.append(c) tcls.append(c)
# Conf return txy, twh, tcls, indices
tci = torch.zeros_like(pred[i][..., 0])
tci[b, a, gj, gi] = 1 # conf
tconf.append(tci)
return txy, twh, tcls, tconf, indices
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
@ -439,15 +435,6 @@ def get_yolo_layers(model):
return [i for i, x in enumerate(bool_vec) if x] # [82, 94, 106] for yolov3 return [i for i, x in enumerate(bool_vec) if x] # [82, 94, 106] for yolov3
def return_torch_unique_index(u, uv):
n = uv.shape[1] # number of columns
first_unique = torch.zeros(n, device=u.device).long()
for j in range(n):
first_unique[j] = (uv[:, j:j + 1] == u).all(0).nonzero()[0]
return first_unique
def strip_optimizer_from_checkpoint(filename='weights/best.pt'): def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
# Strip optimizer from *.pt files for lighter files (reduced by 2/3 size) # Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)
a = torch.load(filename, map_location='cpu') a = torch.load(filename, map_location='cpu')
@ -480,10 +467,9 @@ def plot_results(start=0):
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt') # import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
# from utils.utils import *; plot_results() # from utils.utils import *; plot_results()
plt.figure(figsize=(14, 7)) fig = plt.figure(figsize=(14, 7))
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP'] s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
files = sorted(glob.glob('results*.txt')) for f in sorted(glob.glob('results*.txt')):
for f in files:
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP
x = range(1, results.shape[1]) x = range(1, results.shape[1])
for i in range(8): for i in range(8):
@ -492,3 +478,4 @@ def plot_results(start=0):
plt.title(s[i]) plt.title(s[i])
if i == 0: if i == 0:
plt.legend() plt.legend()
fig.tight_layout()

View File

@ -18,3 +18,4 @@ wget -c https://pjreddie.com/media/files/darknet53.conv.74
# ./darknet partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.15 15 # ./darknet partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.15 15
# mv yolov3-tiny.conv.15 ../ # mv yolov3-tiny.conv.15 ../
cd ..