commit eb81c0b9ae (parent be01fc357b)
Author: Glenn Jocher
Date:   2020-03-04 01:47:31 -08:00

1 changed file with 6 additions and 5 deletions

test.py

@@ -74,6 +74,7 @@ def test(cfg,
 imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
 targets = targets.to(device)
 _, _, height, width = imgs.shape  # batch size, channels, height, width
+whwh = torch.Tensor([width, height, width, height]).to(device)

 # Plot images with bounding boxes
 f = 'test_batch%g.png' % batch_i  # filename
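The new whwh tensor precomputes the [width, height, width, height] scale once per batch, so the per-image loop below can map normalized xywh labels to pixel xyxy boxes without rebuilding the tensor each time. A minimal sketch of that scaling, assuming a 416x416 test size and a stand-in xywh2xyxy written to the usual center/size convention (the repo has its own helper):

import torch

def xywh2xyxy(x):
    # [x_center, y_center, w, h] -> [x1, y1, x2, y2]
    y = torch.zeros_like(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # x1
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # y1
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # x2
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # y2
    return y

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
height, width = 416, 416  # hypothetical test resolution
whwh = torch.Tensor([width, height, width, height]).to(device)

labels_xywh = torch.tensor([[0.5, 0.5, 0.5, 0.5]], device=device)  # one normalized label
tbox = xywh2xyxy(labels_xywh) * whwh  # tensor([[104., 104., 312., 312.]])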
@@ -126,13 +127,13 @@ def test(cfg,
 'score': floatn(d[4], 5)})

 # Assign all predictions as incorrect
-correct = torch.zeros(len(pred), niou, dtype=torch.bool)
+correct = torch.zeros(len(pred), niou, dtype=torch.bool, device=device)
 if nl:
     detected = []  # target indices
     tcls_tensor = labels[:, 0]

     # target boxes
-    tbox = xywh2xyxy(labels[:, 1:5]) * torch.Tensor([width, height, width, height]).to(device)
+    tbox = xywh2xyxy(labels[:, 1:5]) * whwh

     # Per target class
     for cls in torch.unique(tcls_tensor):
@@ -140,7 +141,7 @@ def test(cfg,
 pi = (cls == pred[:, 5]).nonzero().view(-1)  # target indices

 # Search for detections
-if len(pi):
+if pi.shape[0]:
     # Prediction to target ious
     ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices
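box_iou(pred[pi, :4], tbox[ti]) is a pairwise NxM IoU matrix, and .max(1) picks the best target for each prediction. A sketch of that pattern using torchvision.ops.box_iou as a stand-in for the repo's own helper (same xyxy-box, NxM contract):

import torch
from torchvision.ops import box_iou  # stand-in for the repo's box_iou

pred_xyxy = torch.tensor([[100., 100., 300., 300.],
                          [  0.,   0.,  50.,  50.]])  # two candidate detections
tbox = torch.tensor([[104., 104., 312., 312.]])       # one scaled target box

ious, i = box_iou(pred_xyxy, tbox).max(1)  # best IoU per prediction and its target index
# ious ~ tensor([0.86, 0.00]), i = tensor([0, 0])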
@@ -149,12 +150,12 @@ def test(cfg,
 d = ti[i[j]]  # detected target
 if d not in detected:
     detected.append(d)
-    correct[pi[j]] = (ious[j] > iouv).cpu()  # iou_thres is 1xn
+    correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
     if len(detected) == nl:  # all targets already located in image
         break

 # Append statistics (correct, conf, pcls, tcls)
-stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
+stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

 # Compute statistics
 stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
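With correct kept on the GPU through the matching loop, a single .cpu() at append time replaces one transfer per matched prediction; the numpy concatenation then works on CPU tensors as before. A toy sketch of that accumulation (shapes and counts invented for illustration):

import numpy as np
import torch

# per-image tuples of (correct, conf, pred_cls, target_cls)
stats = [
    (torch.zeros(3, 10, dtype=torch.bool), torch.rand(3), torch.zeros(3), [0.0, 1.0]),
    (torch.zeros(2, 10, dtype=torch.bool), torch.rand(2), torch.ones(2), [1.0]),
]

# zip(*stats) regroups the tuples by field; each field is concatenated into
# one numpy array for the downstream precision/recall/mAP computation
stats_np = [np.concatenate(x, 0) for x in zip(*stats)]
# resulting shapes: (5, 10), (5,), (5,), (3,)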