updates
This commit is contained in:
parent
3e633783d8
commit
9c661e2d53
8
test.py
8
test.py
|
@ -69,7 +69,7 @@ def test(cfg,
|
|||
coco91class = coco80_to_coco91_class()
|
||||
s = ('%20s' + '%10s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@0.5', 'F1')
|
||||
p, r, f1, mp, mr, map, mf1, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
|
||||
loss = torch.zeros(3)
|
||||
loss = torch.zeros(3, device=device)
|
||||
jdict, stats, ap, ap_class = [], [], [], []
|
||||
for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
|
||||
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
||||
|
@ -91,7 +91,7 @@ def test(cfg,
|
|||
|
||||
# Compute loss
|
||||
if hasattr(model, 'hyp'): # if model has loss hyperparameters
|
||||
loss += compute_loss(train_out, targets, model)[1][:3].cpu() # GIoU, obj, cls
|
||||
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls
|
||||
|
||||
# Run NMS
|
||||
t = torch_utils.time_synchronized()
|
||||
|
@ -132,7 +132,7 @@ def test(cfg,
|
|||
'score': floatn(d[4], 5)})
|
||||
|
||||
# Assign all predictions as incorrect
|
||||
correct = torch.zeros(len(pred), niou, dtype=torch.bool, device=device)
|
||||
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
|
||||
if nl:
|
||||
detected = [] # target indices
|
||||
tcls_tensor = labels[:, 0]
|
||||
|
@ -214,7 +214,7 @@ def test(cfg,
|
|||
maps = np.zeros(nc) + map
|
||||
for i, c in enumerate(ap_class):
|
||||
maps[c] = ap[i]
|
||||
return (mp, mr, map, mf1, *(loss / len(dataloader)).tolist()), maps
|
||||
return (mp, mr, map, mf1, *(loss.cpu() / len(dataloader)).tolist()), maps
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue