updates
This commit is contained in:
parent
be01fc357b
commit
eb81c0b9ae
11
test.py
11
test.py
|
@ -74,6 +74,7 @@ def test(cfg,
|
||||||
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
_, _, height, width = imgs.shape # batch size, channels, height, width
|
_, _, height, width = imgs.shape # batch size, channels, height, width
|
||||||
|
whwh = torch.Tensor([width, height, width, height]).to(device)
|
||||||
|
|
||||||
# Plot images with bounding boxes
|
# Plot images with bounding boxes
|
||||||
f = 'test_batch%g.png' % batch_i # filename
|
f = 'test_batch%g.png' % batch_i # filename
|
||||||
|
@ -126,13 +127,13 @@ def test(cfg,
|
||||||
'score': floatn(d[4], 5)})
|
'score': floatn(d[4], 5)})
|
||||||
|
|
||||||
# Assign all predictions as incorrect
|
# Assign all predictions as incorrect
|
||||||
correct = torch.zeros(len(pred), niou, dtype=torch.bool)
|
correct = torch.zeros(len(pred), niou, dtype=torch.bool, device=device)
|
||||||
if nl:
|
if nl:
|
||||||
detected = [] # target indices
|
detected = [] # target indices
|
||||||
tcls_tensor = labels[:, 0]
|
tcls_tensor = labels[:, 0]
|
||||||
|
|
||||||
# target boxes
|
# target boxes
|
||||||
tbox = xywh2xyxy(labels[:, 1:5]) * torch.Tensor([width, height, width, height]).to(device)
|
tbox = xywh2xyxy(labels[:, 1:5]) * whwh
|
||||||
|
|
||||||
# Per target class
|
# Per target class
|
||||||
for cls in torch.unique(tcls_tensor):
|
for cls in torch.unique(tcls_tensor):
|
||||||
|
@ -140,7 +141,7 @@ def test(cfg,
|
||||||
pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices
|
pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices
|
||||||
|
|
||||||
# Search for detections
|
# Search for detections
|
||||||
if len(pi):
|
if pi.shape[0]:
|
||||||
# Prediction to target ious
|
# Prediction to target ious
|
||||||
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
|
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
|
||||||
|
|
||||||
|
@ -149,12 +150,12 @@ def test(cfg,
|
||||||
d = ti[i[j]] # detected target
|
d = ti[i[j]] # detected target
|
||||||
if d not in detected:
|
if d not in detected:
|
||||||
detected.append(d)
|
detected.append(d)
|
||||||
correct[pi[j]] = (ious[j] > iouv).cpu() # iou_thres is 1xn
|
correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
|
||||||
if len(detected) == nl: # all targets already located in image
|
if len(detected) == nl: # all targets already located in image
|
||||||
break
|
break
|
||||||
|
|
||||||
# Append statistics (correct, conf, pcls, tcls)
|
# Append statistics (correct, conf, pcls, tcls)
|
||||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
|
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
|
||||||
|
|
||||||
# Compute statistics
|
# Compute statistics
|
||||||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
||||||
|
|
Loading…
Reference in New Issue