This commit is contained in:
Glenn Jocher 2020-03-04 12:17:37 -08:00
parent 3e633783d8
commit 9c661e2d53
1 changed files with 4 additions and 4 deletions

View File

@ -69,7 +69,7 @@ def test(cfg,
coco91class = coco80_to_coco91_class()
s = ('%20s' + '%10s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@0.5', 'F1')
p, r, f1, mp, mr, map, mf1, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
loss = torch.zeros(3)
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class = [], [], [], []
for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
@ -91,7 +91,7 @@ def test(cfg,
# Compute loss
if hasattr(model, 'hyp'): # if model has loss hyperparameters
loss += compute_loss(train_out, targets, model)[1][:3].cpu() # GIoU, obj, cls
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls
# Run NMS
t = torch_utils.time_synchronized()
@ -132,7 +132,7 @@ def test(cfg,
'score': floatn(d[4], 5)})
# Assign all predictions as incorrect
correct = torch.zeros(len(pred), niou, dtype=torch.bool, device=device)
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
if nl:
detected = [] # target indices
tcls_tensor = labels[:, 0]
@ -214,7 +214,7 @@ def test(cfg,
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
maps[c] = ap[i]
return (mp, mr, map, mf1, *(loss / len(dataloader)).tolist()), maps
return (mp, mr, map, mf1, *(loss.cpu() / len(dataloader)).tolist()), maps
if __name__ == '__main__':