This commit is contained in:
Glenn Jocher 2019-04-05 15:34:42 +02:00
parent 7e82df6edc
commit cb352be02c
5 changed files with 96 additions and 76 deletions

View File

@ -57,7 +57,7 @@ def detect(
if ONNX_EXPORT:
torch.onnx.export(model, img, 'weights/model.onnx', verbose=True)
return
pred = model(img)
pred, _ = model(img)
detections = non_max_suppression(pred, conf_thres, nms_thres)[0]
if detections is not None and len(detections) > 0:

View File

@ -156,16 +156,16 @@ class YOLOLayer(nn.Module):
return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t()
else: # inference
p[..., 0:2] = torch.sigmoid(p[..., 0:2]) + self.grid_xy # xy
p[..., 2:4] = torch.exp(p[..., 2:4]) * self.anchor_wh # wh yolo method
# p[..., 2:4] = ((torch.sigmoid(p[..., 2:4]) * 2) ** 2) * self.anchor_wh # wh power method
p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf
p[..., 5:] = torch.sigmoid(p[..., 5:]) # p_class
# p[..., 5:] = F.softmax(p[..., 5:], dim=4) # p_class
p[..., :4] *= self.stride
io = p.clone() # inference output
io[..., 0:2] = torch.sigmoid(io[..., 0:2]) + self.grid_xy # xy
io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 2) * self.anchor_wh # wh power method
io[..., 4:] = torch.sigmoid(io[..., 4:]) # p_conf, p_cls
# io[..., 5:] = F.softmax(io[..., 5:], dim=4) # p_cls
io[..., :4] *= self.stride
# reshape from [1, 3, 13, 13, 85] to [1, 507, 85]
return p.view(bs, -1, 5 + self.nC)
return io.view(bs, -1, 5 + self.nC), p
class Darknet(nn.Module):
@ -202,11 +202,14 @@ class Darknet(nn.Module):
output.append(x)
layer_outputs.append(x)
if ONNX_EXPORT:
if self.training:
return output
elif ONNX_EXPORT:
output = torch.cat(output, 1) # cat 3 layers 85 x (507, 2028, 8112) to 85 x 10647
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
else:
return output if self.training else torch.cat(output, 1)
io, p = list(zip(*output)) # inference output, training output
return torch.cat(io, 1), p
def get_yolo_layers(model):

69
test.py
View File

@ -39,32 +39,42 @@ def test(
# Configure run
data_cfg = parse_data_cfg(data_cfg)
test_path = data_cfg['valid']
# if (os.sep + 'coco' + os.sep) in test_path: # COCO dataset probable
# save_json = True # use pycocotools
nc = int(data_cfg['classes']) # number of classes
test_path = data_cfg['valid'] # path to test images
names = load_classes(data_cfg['names']) # class names
# Dataloader
dataset = LoadImagesAndLabels(test_path, img_size=img_size)
dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=4,
num_workers=0,
pin_memory=False,
collate_fn=dataset.collate_fn)
model.eval()
seen = 0
print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP'))
mP, mR, mAP, mAPj = 0.0, 0.0, 0.0, 0.0
jdict, tdict, stats, AP, AP_class = [], [], [], [], []
model.eval()
coco91class = coco80_to_coco91_class()
for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc='Calculating mAP')):
print('%15s' * 7 % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP', 'F1'))
loss, p, r, f1, mp, mr, map, mf1 = 0., 0., 0., 0., 0., 0., 0., 0.
jdict, stats, ap, ap_class = [], [], [], []
for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc='Computing mAP')):
targets = targets.to(device)
imgs = imgs.to(device)
output = model(imgs)
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
# Run model
inf_out, train_out = model(imgs) # inference and training outputs
# Per image
# Build targets
target_list = build_targets(model, targets)
# Compute loss
loss_i, _ = compute_loss(train_out, target_list)
loss += loss_i.item()
# Run NMS
output = non_max_suppression(inf_out, conf_thres=conf_thres, nms_thres=nms_thres)
# Statistics per image
for si, pred in enumerate(output):
labels = targets[targets[:, 0] == si, 1:]
correct, detected = [], []
@ -77,7 +87,8 @@ def test(
stats.append((correct, torch.Tensor(), torch.Tensor(), tcls))
continue
if save_json: # add to json pred dictionary
# Append to pycocotools JSON dictionary
if save_json:
# [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
image_id = int(Path(paths[si]).stem.split('_')[-1])
box = pred[:, :4].clone() # xyxy
@ -118,24 +129,23 @@ def test(
# Append Statistics (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu()))
# Compute means
# Compute statistics
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
nt = np.bincount(stats_np[3].astype(np.int64), minlength=nc) # number of targets per class
if len(stats_np):
AP, AP_class, R, P = ap_per_class(*stats_np)
mP, mR, mAP = P.mean(), R.mean(), AP.mean()
p, r, ap, f1, ap_class = ap_per_class(*stats_np)
mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
# Print P, R, mAP
print(('%11s%11s' + '%11.3g' * 3) % (seen, len(dataset), mP, mR, mAP))
# Print results
print(('%15s' + '%15.3g' * 6) % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n')
# Print mAP per class
if len(stats_np):
print('\nmAP Per Class:')
names = load_classes(data_cfg['names'])
for c, a in zip(AP_class, AP):
print('%15s: %-.4f' % (names[c], a))
# Print results per class
if nc > 1 and len(stats_np):
for i, c in enumerate(ap_class):
print(('%15s' + '%15.3g' * 6) % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))
# Save JSON
if save_json and mAP and len(jdict):
if save_json and map and len(jdict):
imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataset.img_files]
with open('results.json', 'w') as file:
json.dump(jdict, file)
@ -152,13 +162,10 @@ def test(
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
mAP = cocoEval.stats[1] # update mAP to pycocotools mAP
map = cocoEval.stats[1] # update mAP to pycocotools mAP
# F1 score = harmonic mean of precision and recall
# F1 = 2 * (mP * mR) / (mP + mR)
# Return mAP
return mP, mR, mAP
# Return results
return mp, mr, map, mf1, loss
if __name__ == '__main__':

View File

@ -121,8 +121,8 @@ def train(
imgs = imgs.to(device)
targets = targets.to(device)
nT = len(targets)
if nT == 0: # if no targets continue
nt = len(targets)
if nt == 0: # if no targets continue
continue
# Plot images with bounding boxes
@ -167,7 +167,7 @@ def train(
s = ('%8s%12s' + '%10.3g' * 7) % (
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, nB - 1),
mloss['xy'], mloss['wh'], mloss['conf'], mloss['cls'],
mloss['total'], nT, time.time() - t)
mloss['total'], nt, time.time() - t)
t = time.time()
print(s)
@ -176,38 +176,42 @@ def train(
dataset.img_size = random.choice(range(10, 20)) * 32
print('multi_scale img_size = %g' % dataset.img_size)
# Update best loss
if mloss['total'] < best_loss:
best_loss = mloss['total']
# Save training results
save = True
if save:
# Save latest checkpoint
chkpt = {'epoch': epoch,
'best_loss': best_loss,
'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()}
torch.save(chkpt, latest)
# Save best checkpoint
if best_loss == mloss['total']:
torch.save(chkpt, best)
# Save backup every 10 epochs (optional)
if epoch > 0 and epoch % 10 == 0:
torch.save(chkpt, weights + 'backup%g.pt' % epoch)
del chkpt
# Calculate mAP
with torch.no_grad():
results = test.test(cfg, data_cfg, batch_size=batch_size, img_size=img_size, model=model)
# Write epoch results
with open('results.txt', 'a') as file:
file.write(s + '%11.3g' * 3 % results + '\n') # append P, R, mAP
file.write(s + '%11.3g' * 5 % results + '\n') # P, R, mAP, F1, test_loss
# Update best loss
test_loss = results[4]
if test_loss < best_loss:
best_loss = results[0]
# Save training results
save = True and not opt.no_save
if save:
# Create checkpoint
chkpt = {'epoch': epoch,
'best_loss': best_loss,
'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()}
# Save latest checkpoint
torch.save(chkpt, latest)
# Save best checkpoint
if best_loss == test_loss:
torch.save(chkpt, best)
# Save backup every 10 epochs (optional)
if epoch > 0 and epoch % 10 == 0:
torch.save(chkpt, weights + 'backup%g.pt' % epoch)
# Delete checkpoint
del chkpt
if __name__ == '__main__':
@ -226,6 +230,7 @@ if __name__ == '__main__':
parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
parser.add_argument('--backend', default='nccl', type=str, help='distributed backend')
parser.add_argument('--no-save', action='store_false', help='transfer learning flag')
opt = parser.parse_args()
print(opt, end='\n\n')

View File

@ -175,7 +175,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
# Plot
# plt.plot(recall_curve, precision_curve)
return np.array(ap), unique_classes.astype('int32'), np.array(r), np.array(p)
# Compute F1 score (harmonic mean of precision and recall)
p, r, ap = np.array(p), np.array(r), np.array(ap)
f1 = 2 * p * r / (p + r + 1e-16)
return p, r, ap, f1, unique_classes.astype('int32')
def compute_ap(recall, precision):
@ -484,12 +488,13 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
fig = plt.figure(figsize=(14, 7))
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Train Loss', 'Precision', 'Recall', 'mAP', 'F1',
'Test Loss']
for f in sorted(glob.glob('results*.txt')):
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11, 12, 13]).T
x = range(start, stop if stop else results.shape[1])
for i in range(8):
plt.subplot(2, 4, i + 1)
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.plot(x, results[i, x], marker='.', label=f)
plt.title(s[i])
if i == 0: