Add collate_fn() to DataLoader (#163)

Multi-GPU update with a custom collate function that allows a variable-size target vector per image, removing the need to pad targets.
Glenn Jocher 2019-03-25 14:59:38 +01:00 committed by GitHub
parent 49ae0a55b1
commit cd51e1137b
9 changed files with 217 additions and 194 deletions
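In outline, the change replaces the fixed-size (100, 6) zero-padded label tensor with per-image (n, 6) label tensors that a custom collate function concatenates at batch time. A minimal sketch of the idea, mirroring the `collate_fn` added to utils/datasets.py below:

```python
import torch

def collate_fn(batch):
    # batch is a list of (img, labels, path, shape) tuples from __getitem__;
    # labels is an (n_i, 6) tensor [image_index, class, x, y, w, h], n_i varying per image
    imgs, labels, paths, shapes = list(zip(*batch))
    for i, l in enumerate(labels):
        l[:, 0] = i  # stamp the batch index so build_targets() can map targets to images
    return torch.stack(imgs, 0), torch.cat(labels, 0), paths, shapes
```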

README.md

@@ -17,7 +17,7 @@ This directory contains python software and an iOS App developed by Ultralytics
# Description
The https://github.com/ultralytics/yolov3 repo contains inference and training code for YOLOv3 in PyTorch. The code works on Linux, MacOS and Windows. Training is done on the COCO dataset by default: https://cocodataset.org/#home. **Credit to Joseph Redmon for YOLO** (https://pjreddie.com/darknet/yolo/) and to **Erik Lindernoren for the PyTorch implementation** this work is based on (https://github.com/eriklindernoren/PyTorch-YOLOv3).
The https://github.com/ultralytics/yolov3 repo contains inference and training code for YOLOv3 in PyTorch. The code works on Linux, MacOS and Windows. Training is done on the COCO dataset by default: https://cocodataset.org/#home. **Credit to Joseph Redmon for YOLO:** https://pjreddie.com/darknet/yolo/.
# Requirements
@@ -26,6 +26,7 @@ Python 3.7 or later with the following `pip3 install -U -r requirements.txt` pac
- `numpy`
- `torch >= 1.0.0`
- `opencv-python`
- `tqdm`
# Tutorials
@@ -64,17 +65,22 @@ HS**V** Intensity | +/- 50%
## Speed
https://cloud.google.com/deep-learning-vm/
**Machine type:** n1-standard-8 (8 vCPUs, 30 GB memory)
**CPU platform:** Intel Skylake
**GPUs:** 1-4 x NVIDIA Tesla P100
**GPUs:** 1-4x P100 ($0.493/hr), 1-8x V100 ($0.803/hr)
**HDD:** 100 GB SSD
**Dataset:** COCO train 2014
GPUs | `batch_size` | speed | COCO epoch
--- |---| --- | ---
(P100) | (images) | (s/batch) | (min/epoch)
1 | 16 | 0.39s | 48min
2 | 32 | 0.48s | 29min
4 | 64 | 0.65s | 20min
GPUs | `batch_size` | batch time | epoch time | epoch cost
--- |---| --- | --- | ---
<i></i> | (images) | (s/batch) | (min/epoch) | ($/epoch)
1 P100 | 16 | 0.39s | 48min | $0.39
2 P100 | 32 | 0.48s | 29min | $0.47
4 P100 | 64 | 0.65s | 20min | $0.65
1 V100 | 16 | 0.25s | 31min | $0.41
2 V100 | 32 | 0.29s | 18min | $0.48
4 V100 | 64 | 0.41s | 13min | $0.70
8 V100 | 128 | 0.49s | 7min | $0.80
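The epoch cost column is consistent with GPU count × per-GPU hourly rate × epoch duration; a quick sanity check (this helper is illustrative, not part of the repo):

```python
def epoch_cost(n_gpus, usd_per_gpu_hr, epoch_min):
    # cost of one COCO epoch = GPUs x $/GPU-hr x hours
    return n_gpus * usd_per_gpu_hr * epoch_min / 60

print('%.2f' % epoch_cost(1, 0.493, 48))  # 0.39 -> matches the 1 P100 row
print('%.2f' % epoch_cost(4, 0.803, 13))  # 0.70 -> matches the 4 V100 row
```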
# Inference

detect.py

@@ -1,7 +1,5 @@
import argparse
import shutil
import time
from pathlib import Path
from sys import platform
from models import *
@@ -32,9 +30,9 @@ def detect(
# Load weights
if weights.endswith('.pt'): # pytorch format
if weights.endswith('yolov3.pt') and not os.path.exists(weights):
if (platform == 'darwin') or (platform == 'linux'):
if platform in ('darwin', 'linux'): # linux/macos
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights)
model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
model.load_state_dict(torch.load(weights, map_location=device)['model'])
else: # darknet format
_ = load_darknet_weights(model, weights)
@@ -49,15 +47,15 @@ def detect(
# Get classes and colors
classes = load_classes(parse_data_cfg('cfg/coco.data')['names'])
colors = [[random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)] for _ in range(len(classes))]
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
for i, (path, img, im0) in enumerate(dataloader):
t = time.time()
save_path = str(Path(output) / Path(path).name)
if webcam:
print('webcam frame %g: ' % (i + 1), end='')
else:
print('image %g/%g %s: ' % (i + 1, len(dataloader), path), end='')
save_path = str(Path(output) / Path(path).name)
# Get detections
img = torch.from_numpy(img).unsqueeze(0).to(device)
@@ -81,18 +79,16 @@ def detect(
print('%g %ss' % (n, classes[int(c)]), end=', ')
# Draw bounding boxes and labels of detections
for x1, y1, x2, y2, conf, cls_conf, cls in detections:
for *xyxy, conf, cls_conf, cls in detections:
if save_txt: # Write to file
with open(save_path + '.txt', 'a') as file:
file.write('%g %g %g %g %g %g\n' %
(x1, y1, x2, y2, cls, cls_conf * conf))
file.write(('%g ' * 6 + '\n') % (*xyxy, cls, cls_conf * conf))
# Add bbox to the image
label = '%s %.2f' % (classes[int(cls)], conf)
plot_one_box([x1, y1, x2, y2], im0, label=label, color=colors[int(cls)])
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
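The `*xyxy` target is Python 3 extended unpacking: the first four detection values (the box corners) are gathered into one list, so the same variable feeds both the label-file writer and plot_one_box. For example:

```python
*xyxy, conf, cls_conf, cls = [100.0, 120.0, 200.0, 240.0, 0.93, 0.88, 16.0]
# xyxy == [100.0, 120.0, 200.0, 240.0]; conf, cls_conf, cls take the last three values
```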
dt = time.time() - t
print('Done. (%.3fs)' % dt)
print('Done. (%.3fs)' % (time.time() - t))
if save_images: # Save generated image with detections
cv2.imwrite(save_path, im0)
@@ -100,7 +96,7 @@ def detect(
if webcam: # Show live webcam
cv2.imshow(weights, im0)
if save_images and (platform == 'darwin'): # linux/macos
if save_images and platform == 'darwin': # macos
os.system('open ' + output + ' ' + save_path)

requirements.txt

@@ -4,3 +4,5 @@ numpy
opencv-python
torch >= 1.0.0
matplotlib
pycocotools
tqdm

test.py

@@ -1,7 +1,6 @@
import argparse
import json
import time
from pathlib import Path
from torch.utils.data import DataLoader
@@ -24,41 +23,44 @@ def test(
):
device = torch_utils.select_device()
# Configure run
data_cfg_dict = parse_data_cfg(data_cfg)
nC = int(data_cfg_dict['classes']) # number of classes (80 for COCO)
test_path = data_cfg_dict['valid']
if model is None:
# Initialize model
model = Darknet(cfg, img_size)
model = Darknet(cfg, img_size).to(device)
# Load weights
if weights.endswith('.pt'): # pytorch format
model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
model.load_state_dict(torch.load(weights, map_location=device)['model'])
else: # darknet format
_ = load_darknet_weights(model, weights)
model.to(device).eval()
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
# Configure run
data_cfg = parse_data_cfg(data_cfg)
nC = int(data_cfg['classes']) # number of classes (80 for COCO)
test_path = data_cfg['valid']
# Dataloader
dataset = LoadImagesAndLabels(test_path, img_size=img_size)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=4,
pin_memory=False,
collate_fn=dataset.collate_fn)
model.eval()
mean_mAP, mean_R, mean_P, seen = 0.0, 0.0, 0.0, 0
print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP'))
mP, mR, mAPs, TP, jdict = [], [], [], [], []
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
coco91class = coco80_to_coco91_class()
for imgs, targets, paths, shapes in dataloader:
# Unpad and collate targets
for j, t in enumerate(targets):
t[:, 0] = j
targets = torch.cat([t[t[:, 5].nonzero()] for t in targets], 0).squeeze(1)
targets = targets.to(device)
for imgs, targets, paths, shapes in tqdm(dataloader):
t = time.time()
output = model(imgs.to(device))
targets = targets.to(device)
imgs = imgs.to(device)
output = model(imgs)
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
# Compute average precision for each sample
@@ -78,7 +80,7 @@ def test(
if save_json:
# [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
box = detections[:, :4].clone() # xyxy
scale_coords(img_size, box, (shapes[0][si], shapes[1][si])) # to original shape
scale_coords(img_size, box, shapes[si]) # to original shape
box = xyxy2xywh(box) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
@@ -134,13 +136,13 @@ def test(
mean_R = np.mean(mR)
mean_mAP = np.mean(mAPs)
# Print image mAP and running mean mAP
print(('%11s%11s' + '%11.3g' * 4 + 's') %
(seen, len(dataset), mean_P, mean_R, mean_mAP, time.time() - t))
# Print mAP per class
print('\nmAP Per Class:')
for i, c in enumerate(load_classes(data_cfg_dict['names'])):
for i, c in enumerate(load_classes(data_cfg['names'])):
if AP_accum_count[i]:
print('%15s: %-.4f' % (c, AP_accum[i] / (AP_accum_count[i])))
@@ -191,4 +193,5 @@ if __name__ == '__main__':
opt.iou_thres,
opt.conf_thres,
opt.nms_thres,
opt.save_json)
opt.save_json
)

train.py

@@ -1,6 +1,7 @@
import argparse
import time
import torch.distributed as dist
from torch.utils.data import DataLoader
import test # Import test.py to get mAP after each epoch
@@ -20,7 +21,7 @@ def train(
accumulate=1,
multi_scale=False,
freeze_backbone=False,
num_workers=0
num_workers=4
):
weights = 'weights' + os.sep
latest = weights + 'latest.pt'
@@ -40,7 +41,7 @@ def train(
# Optimizer
lr0 = 0.001 # initial learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9,weight_decay = 0.0005)
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9, weight_decay=0.0005)
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
@@ -65,52 +66,55 @@ def train(
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
model = torch.nn.parallel.DistributedDataParallel(model)
# Dataloader
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
if torch.cuda.device_count() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
train_sampler=None
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,sampler=train_sampler)
# Transfer learning (train only YOLO layers)
# for i, (name, p) in enumerate(model.named_parameters()):
# p.requires_grad = True if (p.shape[0] == 255) else False
# Set scheduler
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
# Set scheduler (reduce lr at epoch 250)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1, last_epoch=start_epoch - 1)
# Dataset
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
# Initialize distributed training
if torch.cuda.device_count() > 1:
dist.init_process_group(backend=opt.backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank)
model = torch.nn.parallel.DistributedDataParallel(model)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
sampler = None
# Dataloader
dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
pin_memory=False,
collate_fn=dataset.collate_fn,
sampler=sampler)
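One caveat (an observation, not addressed in this diff): PyTorch's DataLoader treats `sampler` and `shuffle=True` as mutually exclusive and raises a ValueError when both are set, so in the multi-GPU branch where a DistributedSampler is created one would expect the call above to guard the shuffle flag, e.g.:

```python
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        shuffle=(sampler is None),  # shuffle only when no DistributedSampler
                        pin_memory=False,
                        collate_fn=dataset.collate_fn,
                        sampler=sampler)
```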
# Start training
t0 = time.time()
nB = len(dataloader)
t = time.time()
model_info(model)
n_burnin = min(round(len(dataloader) / 5 + 1), 1000) # burn-in batches
for epoch in range(epochs):
n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches
for epoch in range(start_epoch, epochs):
model.train()
epoch += start_epoch
print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
print(('\n%8s%12s' + '%10s' * 7) % (
'Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
# Update scheduler (automatic)
# scheduler.step()
# Update scheduler (manual)
lr = lr0 / 10 if epoch > 250 else lr0
for x in optimizer.param_groups:
x['lr'] = lr
# Update scheduler
scheduler.step()
# Freeze backbone at epoch 0, unfreeze at epoch 1
if freeze_backbone and epoch < 2:
for i, (name, p) in enumerate(model.named_parameters()):
for name, p in model.named_parameters():
if int(name.split('.')[1]) < cutoff: # if layer < 75
p.requires_grad = False if (epoch == 0) else True
p.requires_grad = False if epoch == 0 else True
rloss = defaultdict(float)
mloss = defaultdict(float) # mean loss
for i, (imgs, targets, _, _) in enumerate(dataloader):
# Unpad and collate targets
for j, t in enumerate(targets):
t[:, 0] = j
targets = torch.cat([t[t[:, 5].nonzero()] for t in targets], 0).squeeze(1)
imgs = imgs.to(device)
targets = targets.to(device)
nT = len(targets)
if nT == 0: # if no targets continue
@@ -119,25 +123,26 @@ def train(
# Plot images with bounding boxes
plot_images = False
if plot_images:
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
fig = plt.figure(figsize=(10, 10))
for ip in range(batch_size):
labels = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy() * img_size
plt.subplot(4, 4, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0))
plt.plot(labels[:, [0, 2, 2, 0, 0]].T, labels[:, [1, 1, 3, 3, 1]].T, '.-')
plt.axis('off')
fig.tight_layout()
fig.savefig('batch_%g.jpg' % i, dpi=fig.dpi)
# SGD burn-in
if (epoch == 0) and (i <= n_burnin):
if epoch == 0 and i <= n_burnin:
lr = lr0 * (i / n_burnin) ** 4
for x in optimizer.param_groups:
x['lr'] = lr
# Run model
pred = model(imgs.to(device))
pred = model(imgs)
# Build targets
target_list = build_targets(model, targets.to(device), pred)
target_list = build_targets(model, targets)
# Compute loss
loss, loss_dict = compute_loss(pred, target_list)
@@ -146,21 +151,19 @@ def train(
loss.backward()
# Accumulate gradient for x batches before optimizing
if (i + 1) % accumulate == 0 or (i + 1) == len(dataloader):
if (i + 1) % accumulate == 0 or (i + 1) == nB:
optimizer.step()
optimizer.zero_grad()
# Running epoch-means of tracked metrics
for key, val in loss_dict.items():
rloss[key] = (rloss[key] * i + val) / (i + 1)
mloss[key] = (mloss[key] * i + val) / (i + 1)
s = ('%8s%12s' + '%10.3g' * 7) % (
'%g/%g' % (epoch, epochs - 1),
'%g/%g' % (i, len(dataloader) - 1),
rloss['xy'], rloss['wh'], rloss['conf'],
rloss['cls'], rloss['total'],
nT, time.time() - t0)
t0 = time.time()
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, nB - 1),
mloss['xy'], mloss['wh'], mloss['conf'], mloss['cls'],
mloss['total'], nT, time.time() - t)
t = time.time()
print(s)
# Multi-Scale training (320 - 608 pixels) every 10 batches
@@ -169,8 +172,8 @@ def train(
print('multi_scale img_size = %g' % dataset.img_size)
# Update best loss
if rloss['total'] < best_loss:
best_loss = rloss['total']
if mloss['total'] < best_loss:
best_loss = mloss['total']
# Save training results
save = True
@@ -178,23 +181,24 @@ def train(
# Save latest checkpoint
checkpoint = {'epoch': epoch,
'best_loss': best_loss,
'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest)
# Save best checkpoint
if best_loss == rloss['total']:
if best_loss == mloss['total']:
os.system('cp ' + latest + ' ' + best)
# Save backup weights every 5 epochs (optional)
if (epoch > 0) and (epoch % 5 == 0):
os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch))
if epoch > 0 and epoch % 5 == 0:
os.system('cp ' + latest + ' ' + weights + 'backup%g.pt' % epoch)
# Calculate mAP
if type(model) is nn.parallel.DistributedDataParallel:
model = model.module
with torch.no_grad():
P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, model=model)
P, R, mAP = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
# Write epoch results
with open('results.txt', 'a') as file:
@@ -212,10 +216,10 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str,help='url used to set up distributed training')
parser.add_argument('--rank', default=0, type=int,help='node rank for distributed training')
parser.add_argument('--world-size', default=1, type=int,help='number of nodes for distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,help='distributed backend')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method')
parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
parser.add_argument('--backend', default='nccl', type=str, help='distributed backend')
opt = parser.parse_args()
print(opt, end='\n\n')

utils/datasets.py

@@ -2,11 +2,14 @@ import glob
import math
import os
import random
import shutil
from pathlib import Path
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from utils.utils import xyxy2xywh
@@ -97,7 +100,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
assert len(self.img_files) > 0, 'No images found in %s' % path
self.img_size = img_size
self.augment = augment
self.label_files = [x.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
self.label_files = [x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt')
for x in self.img_files]
def __len__(self):
@@ -136,58 +139,61 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img, ratio, padw, padh = letterbox(img, height=self.img_size)
# Load labels
labels = []
if os.path.isfile(label_path):
with open(label_path, 'r') as file:
lines = file.read().splitlines()
x = np.array([x.split() for x in lines], dtype=np.float32)
if x.size is 0:
# Empty labels file
labels = np.array([])
else:
if x.size > 0:
# Normalized xywh to pixel xyxy format
labels = x.copy()
labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw
labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh
labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw
labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh
else:
labels = np.array([])
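As a worked example of the conversion above (illustrative numbers): a 640x480 image letterboxed to 416x416 gives ratio = 416/640 = 0.65, padw = 0 and padh = 52, so a normalized label (x, y, w, h) = (0.5, 0.5, 0.25, 0.5) maps to pixel corners (156, 130, 260, 286):

```python
import numpy as np

ratio, w, h, padw, padh = 0.65, 640, 480, 0, 52  # illustrative letterbox geometry
x = np.array([[0, 0.5, 0.5, 0.25, 0.5]], dtype=np.float32)  # class, x, y, w, h (normalized)
labels = x.copy()
labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw  # x1 = 156
labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh  # y1 = 130
labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw  # x2 = 260
labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh  # y2 = 286
```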
# Augment image and labels
if self.augment:
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))
img, labels = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))
nL = len(labels)
if nL > 0:
nL = len(labels) # number of labels
if nL:
# convert xyxy to xywh
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size
if self.augment:
# random left-right flip
lr_flip = True
if lr_flip & (random.random() > 0.5):
if lr_flip and random.random() > 0.5:
img = np.fliplr(img)
if nL > 0:
if nL:
labels[:, 1] = 1 - labels[:, 1]
# random up-down flip
ud_flip = False
if ud_flip & (random.random() > 0.5):
if ud_flip and random.random() > 0.5:
img = np.flipud(img)
if nL > 0:
if nL:
labels[:, 2] = 1 - labels[:, 2]
labels_out = np.zeros((100, 6), dtype=np.float32)
if nL > 0:
labels_out[:nL, 1:] = labels # max 100 labels per image
labels_out = torch.zeros((nL, 6))
if nL:
labels_out[:, 1:] = torch.from_numpy(labels)
# Normalize
img = img[:, :, ::-1].transpose(2, 0, 1) # list to np.array and BGR to RGB
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img, dtype=np.float32) # uint8 to float32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
return torch.from_numpy(img), torch.from_numpy(labels_out), img_path, (h, w)
return torch.from_numpy(img), labels_out, img_path, (h, w)
@staticmethod
def collate_fn(batch):
img, label, path, hw = list(zip(*batch)) # transposed
for i, l in enumerate(label):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, hw
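With this collate_fn, a batch of B images yields stacked imgs of shape (B, 3, img_size, img_size) and a single concatenated targets tensor of shape (sum n_i, 6) whose first column is the image index consumed by build_targets(). Illustrative usage:

```python
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
imgs, targets, paths, shapes = next(iter(loader))
# e.g. images with 2, 0, 3, 1 labels -> targets.shape == (6, 6),
# targets[:, 0] == tensor([0., 0., 2., 2., 2., 3.])
```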
def letterbox(img, height=416, color=(127.5, 127.5, 127.5)): # resize a rectangular image to a padded square
@@ -203,11 +209,13 @@ def letterbox(img, height=416, color=(127.5, 127.5, 127.5)): # resize a rectang
return img, ratio, dw, dh
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
def random_affine(img, targets=(), degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
borderValue=(127.5, 127.5, 127.5)):
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
if targets is None:
targets = []
border = 0 # width of added border (optional)
height = max(img.shape[0], img.shape[1]) + border * 2
@@ -233,52 +241,61 @@ def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scal
borderValue=borderValue) # BGR order borderValue
# Return warped points also
if targets is not None:
if len(targets) > 0:
n = targets.shape[0]
points = targets[:, 1:5].copy()
area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])
# warp points
xy = np.ones((n * 4, 3))
xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
xy = (xy @ M.T)[:, :2].reshape(n, 8)
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# apply angle-based reduction
radians = a * math.pi / 180
reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
x = (xy[:, 2] + xy[:, 0]) / 2
y = (xy[:, 3] + xy[:, 1]) / 2
w = (xy[:, 2] - xy[:, 0]) * reduction
h = (xy[:, 3] - xy[:, 1]) * reduction
xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
# reject warped points outside of image
np.clip(xy, 0, height, out=xy)
w = xy[:, 2] - xy[:, 0]
h = xy[:, 3] - xy[:, 1]
area = w * h
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)
targets = targets[i]
targets[:, 1:5] = xy[i]
return imw, targets, M
else:
return imw
return imw, targets
def convert_tif2bmp(p='../xview/val_images_bmp'):
import glob
import cv2
files = sorted(glob.glob('%s/*.tif' % p))
for i, f in enumerate(files):
print('%g/%g' % (i + 1, len(files)))
cv2.imwrite(f.replace('.tif', '.bmp'), cv2.imread(f))
os.system('rm -rf ' + f)
def convert_images2bmp():
# cv2.imread() jpg at 230 img/s, *.bmp at 400 img/s
for path in ['../coco/images/val2014/', '../coco/images/train2014/']:
folder = os.sep + Path(path).name
output = path.replace(folder, folder + 'bmp')
if os.path.exists(output):
shutil.rmtree(output) # delete output folder
os.makedirs(output) # make new output folder
for f in tqdm(glob.glob('%s*.jpg' % path)):
save_name = f.replace('.jpg', '.bmp').replace(folder, folder + 'bmp')
cv2.imwrite(save_name, cv2.imread(f))
for label_path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
with open(label_path, 'r') as file:
lines = file.read()
lines = lines.replace('2014/', '2014bmp/').replace('.jpg', '.bmp').replace(
'/Users/glennjocher/PycharmProjects/', '../')
with open(label_path.replace('5k', '5k_bmp'), 'w') as file:
file.write(lines)

utils/gcp.sh

@@ -3,13 +3,14 @@
# New VM
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
bash yolov3/data/get_coco_dataset.sh
bash yolov3/weights/download_yolov3_weights.sh
sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3
sudo shutdown
# Train
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
cp -r weights yolov3
cd yolov3 && python3 train.py --batch-size 16 --epochs 1
cd yolov3 && python3 train.py --batch-size 48 --epochs 1
sudo shutdown
# Resume
@@ -20,11 +21,17 @@ python3 detect.py
# Clone a branch
sudo rm -rf yolov3 && git clone -b multi_gpu --depth 1 https://github.com/ultralytics/yolov3
cp -r weights yolov3
cd yolov3 && python3 train.py --batch-size 48 --epochs 1
sudo shutdown
# Git pull branch
git pull https://github.com/ultralytics/yolov3 multi_gpu
# Test
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3
cd yolov3 && python3 test.py --save-json --conf-thres 0.005
cd yolov3 && python3 test.py --save-json --conf-thres 0.001 --img-size 416
# Test Darknet training
python3 test.py --img_size 416 --weights ../darknet/backup/yolov3.backup
@@ -33,7 +40,7 @@ python3 test.py --img_size 416 --weights ../darknet/backup/yolov3.backup
wget https://storage.googleapis.com/ultralytics/yolov3.pt -O weights/latest.pt
# Copy latest.pt to bucket
gsutil cp yolov3/weights/latest.pt gs://ultralytics
gsutil cp yolov3/weights/latest1gpu.pt gs://ultralytics
# Copy latest.pt from bucket
gsutil cp gs://ultralytics/latest.pt yolov3/weights/latest.pt

utils/utils.py

@@ -95,7 +95,7 @@ def weights_init_normal(m):
def xyxy2xywh(x):
# Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h]
y = torch.zeros_like(x) if x.dtype is torch.float32 else np.zeros_like(x)
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2
y[:, 1] = (x[:, 1] + x[:, 3]) / 2
y[:, 2] = x[:, 2] - x[:, 0]
@@ -105,7 +105,7 @@ def xyxy2xywh(x):
def xywh2xyxy(x):
# Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2]
y = torch.zeros_like(x) if x.dtype is torch.float32 else np.zeros_like(x)
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
y[:, 0] = (x[:, 0] - x[:, 2] / 2)
y[:, 1] = (x[:, 1] - x[:, 3] / 2)
y[:, 2] = (x[:, 0] + x[:, 2] / 2)
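The switch to isinstance fixes type dispatch for tensors that are not float32: under the old test, `x.dtype is torch.float32` is False for e.g. a float64 tensor, which would fall through to np.zeros_like. A quick illustration:

```python
import torch

x = torch.rand(3, 4, dtype=torch.float64)
print(x.dtype is torch.float32)      # False -> old check falls through to np.zeros_like
print(isinstance(x, torch.Tensor))   # True  -> new check keeps the output a torch tensor
```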
@@ -251,7 +251,7 @@ def wh_iou(box1, box2):
def compute_loss(p, targets): # predictions, targets
FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor
loss, lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
txy, twh, tcls, tconf, indices = targets
txy, twh, tcls, indices = targets
MSE = nn.MSELoss()
CE = nn.CrossEntropyLoss()
BCE = nn.BCEWithLogitsLoss()
@@ -260,18 +260,21 @@ def compute_loss(p, targets): # predictions, targets
# gp = [x.numel() for x in tconf] # grid points
for i, pi0 in enumerate(p): # layer i predictions, i
b, a, gj, gi = indices[i] # image, anchor, gridx, gridy
tconf = torch.zeros_like(pi0[..., 0]) # conf
# Compute losses
k = 1 # nT / bs
if len(b) > 0:
pi = pi0[b, a, gj, gi] # predictions closest to anchors
tconf[b, a, gj, gi] = 1 # conf
lxy += k * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy
lwh += k * MSE(pi[..., 2:4], twh[i]) # wh
lcls += (k / 4) * CE(pi[..., 5:], tcls[i])
# pos_weight = FT([gp[i] / min(gp) * 4.])
# BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
lconf += (k * 64) * BCE(pi0[..., 4], tconf[i])
lconf += (k * 64) * BCE(pi0[..., 4], tconf)
loss = lxy + lwh + lconf + lcls
# Add to dictionary
@@ -283,15 +286,13 @@ def compute_loss(p, targets): # predictions, targets
return loss, d
def build_targets(model, targets, pred):
def build_targets(model, targets):
# targets = [image, class, x, y, w, h]
if isinstance(model, nn.parallel.DistributedDataParallel):
model = model.module
yolo_layers = get_yolo_layers(model)
# anchors = closest_anchor(model, targets) # [layer, anchor, i, j]
txy, twh, tcls, tconf, indices = [], [], [], [], []
for i, layer in enumerate(yolo_layers):
txy, twh, tcls, indices = [], [], [], []
for i, layer in enumerate(get_yolo_layers(model)):
nG = model.module_list[layer][0].nG # grid size
anchor_vec = model.module_list[layer][0].anchor_vec
@@ -324,12 +325,7 @@ def build_targets(model, targets, pred):
# Class
tcls.append(c)
# Conf
tci = torch.zeros_like(pred[i][..., 0])
tci[b, a, gj, gi] = 1 # conf
tconf.append(tci)
return txy, twh, tcls, tconf, indices
return txy, twh, tcls, indices
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
@@ -439,15 +435,6 @@ def get_yolo_layers(model):
return [i for i, x in enumerate(bool_vec) if x] # [82, 94, 106] for yolov3
def return_torch_unique_index(u, uv):
n = uv.shape[1] # number of columns
first_unique = torch.zeros(n, device=u.device).long()
for j in range(n):
first_unique[j] = (uv[:, j:j + 1] == u).all(0).nonzero()[0]
return first_unique
def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
# Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)
a = torch.load(filename, map_location='cpu')
@@ -480,10 +467,9 @@ def plot_results(start=0):
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
# from utils.utils import *; plot_results()
plt.figure(figsize=(14, 7))
fig = plt.figure(figsize=(14, 7))
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
files = sorted(glob.glob('results*.txt'))
for f in files:
for f in sorted(glob.glob('results*.txt')):
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP
x = range(1, results.shape[1])
for i in range(8):
@@ -492,3 +478,4 @@ def plot_results(start=0):
plt.title(s[i])
if i == 0:
plt.legend()
fig.tight_layout()

weights/download_yolov3_weights.sh

@@ -18,3 +18,4 @@ wget -c https://pjreddie.com/media/files/darknet53.conv.74
# ./darknet partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.15 15
# mv yolov3-tiny.conv.15 ../
cd ..