From eb81c0b9ae9592da3c801da8ca25a961127484ce Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 4 Mar 2020 01:47:31 -0800 Subject: [PATCH] updates --- test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test.py b/test.py index 6bcf3d07..bfbdb09d 100644 --- a/test.py +++ b/test.py @@ -74,6 +74,7 @@ def test(cfg, imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0 targets = targets.to(device) _, _, height, width = imgs.shape # batch size, channels, height, width + whwh = torch.Tensor([width, height, width, height]).to(device) # Plot images with bounding boxes f = 'test_batch%g.png' % batch_i # filename @@ -126,13 +127,13 @@ def test(cfg, 'score': floatn(d[4], 5)}) # Assign all predictions as incorrect - correct = torch.zeros(len(pred), niou, dtype=torch.bool) + correct = torch.zeros(len(pred), niou, dtype=torch.bool, device=device) if nl: detected = [] # target indices tcls_tensor = labels[:, 0] # target boxes - tbox = xywh2xyxy(labels[:, 1:5]) * torch.Tensor([width, height, width, height]).to(device) + tbox = xywh2xyxy(labels[:, 1:5]) * whwh # Per target class for cls in torch.unique(tcls_tensor): @@ -140,7 +141,7 @@ def test(cfg, pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices # Search for detections - if len(pi): + if pi.shape[0]: # Prediction to target ious ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices @@ -149,12 +150,12 @@ def test(cfg, d = ti[i[j]] # detected target if d not in detected: detected.append(d) - correct[pi[j]] = (ious[j] > iouv).cpu() # iou_thres is 1xn + correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn if len(detected) == nl: # all targets already located in image break # Append statistics (correct, conf, pcls, tcls) - stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) + stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy