This commit is contained in:
Glenn Jocher 2019-04-17 15:52:51 +02:00
parent 9c5524ba82
commit 0b8a28e3dd
4 changed files with 52 additions and 32 deletions

View File

@ -68,11 +68,8 @@ def test(
# Run model
inf_out, train_out = model(imgs) # inference and training outputs
# Build targets
target_list = build_targets(model, targets)
# Compute loss
loss_i, _ = compute_loss(train_out, target_list)
loss_i, _ = compute_loss(train_out, targets, model)
loss += loss_i.item()
# Run NMS

View File

@ -2,6 +2,7 @@ import argparse
import time
import torch.distributed as dist
import torch.optim as optim
from torch.utils.data import DataLoader
import test # Import test.py to get mAP after each epoch
@ -41,9 +42,21 @@ def train(
# Initialize model
model = Darknet(cfg, img_size).to(device)
# Initialize hyperparameters
hyp = {'k': 8.4875, # loss multiple
'xy': 0.079756, # xy loss fraction
'wh': 0.010461, # wh loss fraction
'cls': 0.02105, # cls loss fraction
'conf': 0.88873, # conf loss fraction
'iou_t': 0.1, # iou target-anchor training threshold
'lr0': 0.001, # initial learning rate
'lrf': -2., # final learning rate = lr0 * (10 ** lrf)
'momentum': 0.9, # SGD momentum
'weight_decay': 0.0005, # optimizer weight decay
}
# Optimizer
lr0 = 0.001 # initial learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=0.9, weight_decay=0.0005)
optimizer = optim.SGD(model.parameters(), lr=hyp['lr0'], momentum=hyp['momentum'], weight_decay=hyp['weight_decay'])
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
@ -74,8 +87,11 @@ def train(
cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74')
# Scheduler (reduce lr at epochs 218, 245, i.e. batches 400k, 450k)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[218, 245], gamma=0.1,
last_epoch=start_epoch - 1)
# lf = lambda x: 1 - x / epochs # linear ramp to zero
# lf = lambda x: 10 ** (-2 * x / epochs) # exp ramp to lr0 * 1e-2
# lf = lambda x: 1 - 10 ** (-2 * (1 - x / epochs)) # inv exp ramp to lr0 * 1e-2
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf, last_epoch=start_epoch - 1)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[218, 245], gamma=0.1, last_epoch=start_epoch - 1)
# Dataset
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
@ -105,9 +121,10 @@ def train(
# Start training
t = time.time()
model.hyp = hyp # attach hyperparameters to model
model_info(model)
nB = len(dataloader)
n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches
nb = len(dataloader)
n_burnin = min(round(nb / 5 + 1), 1000) # burn-in batches
os.remove('train_batch0.jpg') if os.path.exists('train_batch0.jpg') else None
os.remove('test_batch0.jpg') if os.path.exists('test_batch0.jpg') else None
for epoch in range(start_epoch, epochs):
@ -137,18 +154,15 @@ def train(
# SGD burn-in
if epoch == 0 and i <= n_burnin:
lr = lr0 * (i / n_burnin) ** 4
lr = hyp['lr0'] * (i / n_burnin) ** 4
for x in optimizer.param_groups:
x['lr'] = lr
# Run model
pred = model(imgs)
# Build targets
target_list = build_targets(model, targets)
# Compute loss
loss, loss_items = compute_loss(pred, target_list)
loss, loss_items = compute_loss(pred, targets, model)
# Compute gradient
if mixed_precision:
@ -158,7 +172,7 @@ def train(
loss.backward()
# Accumulate gradient for x batches before optimizing
if (i + 1) % accumulate == 0 or (i + 1) == nB:
if (i + 1) % accumulate == 0 or (i + 1) == nb:
optimizer.step()
optimizer.zero_grad()
@ -168,7 +182,7 @@ def train(
# Print batch results
s = ('%8s%12s' + '%10.3g' * 7) % (
'%g/%g' % (epoch, epochs - 1),
'%g/%g' % (i, nB - 1), *mloss, nt, time.time() - t)
'%g/%g' % (i, nb - 1), *mloss, nt, time.time() - t)
t = time.time()
print(s)
@ -182,7 +196,8 @@ def train(
results = (0, 0, 0, 0, 0)
else:
with torch.no_grad():
results = test.test(cfg, data_cfg, batch_size=batch_size, img_size=img_size, model=model, conf_thres=0.1)
results = test.test(cfg, data_cfg, batch_size=batch_size, img_size=img_size, model=model,
conf_thres=0.1)
# Write epoch results
with open('results.txt', 'a') as file:
@ -235,6 +250,7 @@ if __name__ == '__main__':
parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
parser.add_argument('--backend', default='nccl', type=str, help='distributed backend')
parser.add_argument('--nosave', action='store_true', help='do not save training results')
parser.add_argument('--var', default=0, type=int, help='debug variable')
opt = parser.parse_args()
print(opt, end='\n\n')

View File

@ -50,14 +50,19 @@ git clone https://github.com/ultralytics/yolov3 # master
cp -r weights yolov3
cp -r cocoapi/PythonAPI/pycocotools yolov3
cd yolov3
python3 train.py --nosave --data data/coco_100val.data
python3 train.py --nosave --data data/coco_32img.data --var 4 && mv results.txt results_t2.txt
python3 train.py --nosave --data data/coco_32img.data --var 5 && mv results.txt results_t3.txt
python3 -c "from utils import utils; utils.plot_results()"
gsutil cp results*.txt gs://ultralytics
gsutil cp results.png gs://ultralytics
sudo shutdown
#mv ../utils.py utils
mv ../train.py .
rm results*.txt # WARNING: removes existing results
python3 train.py --nosave --data data/coco_1img.data && mv results.txt results3_1img.txt
python3 train.py --nosave --data data/coco_10img.data && mv results.txt results3_10img.txt
python3 train.py --nosave --data data/coco_100img.data && mv results.txt results3_100img.txt
python3 train.py --nosave --data data/coco_100img.data && mv results.txt results4_100img.txt
python3 train.py --nosave --data data/coco_100img.data --transfer && mv results.txt results3_100imgTL.txt
# python3 train.py --nosave --data data/coco_1000img.data && mv results.txt results_1000img.txt
python3 -c "from utils import utils; utils.plot_results()"

View File

@ -242,35 +242,37 @@ def wh_iou(box1, box2):
return inter_area / union_area # iou
def compute_loss(p, targets): # predictions, targets
def compute_loss(p, targets, model): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
txy, twh, tcls, indices = targets
txy, twh, tcls, indices = build_targets(model, targets)
# Define criteria
MSE = nn.MSELoss()
CE = nn.CrossEntropyLoss()
BCE = nn.BCEWithLogitsLoss()
# Compute losses
h = model.hyp # hyperparameters
bs = p[0].shape[0] # batch size
# gp = [x.numel() for x in tconf] # grid points
k = h['k'] * bs # loss gain
for i, pi0 in enumerate(p): # layer i predictions, i
b, a, gj, gi = indices[i] # image, anchor, gridx, gridy
tconf = torch.zeros_like(pi0[..., 0]) # conf
# Compute losses
k = 8.4875 * bs
if len(b): # number of targets
pi = pi0[b, a, gj, gi] # predictions closest to anchors
tconf[b, a, gj, gi] = 1 # conf
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
lxy += (k * 0.079756) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += (k * 0.010461) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
# lwh += (k * 0.010461) * MSE(torch.sigmoid(pi[..., 2:4]), twh[i]) # wh power loss
lcls += (k * 0.02105) * CE(pi[..., 5:], tcls[i]) # class_conf loss
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # class_conf loss
# pos_weight = ft([gp[i] / min(gp) * 4.])
# BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
lconf += (k * 0.88873) * BCE(pi0[..., 4], tconf) # obj_conf loss
lconf += (k * h['conf']) * BCE(pi0[..., 4], tconf) # obj_conf loss
loss = lxy + lwh + lconf + lcls
return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach()
@ -296,7 +298,7 @@ def build_targets(model, targets):
# reject below threshold ious (OPTIONAL, increases P, lowers R)
reject = True
if reject:
j = iou > 0.10
j = iou > model.hyp['iou_t'] # hyperparameter
t, a, gwh = targets[j], a[j], gwh[j]
# Indices