diff --git a/test.py b/test.py index c9516953..fe1a4128 100644 --- a/test.py +++ b/test.py @@ -124,16 +124,17 @@ def test( break # Continue if predicted class not among image classes - if pcls.item() not in tcls: + m = (pcls == tcls_tensor).nonzero().view(-1) # matches + if not any(m): continue # Best iou, index between pred and targets - iou, bi = bbox_iou(pbox, tbox).max(0) + iou, bi = bbox_iou(pbox, tbox[m]).max(0) # If iou > threshold and class is correct mark as correct - if iou > iou_thres and bi not in detected: # and pcls == tcls[bi]: + if iou > iou_thres and m[bi] not in detected: # and pcls == tcls[bi]: correct[i] = 1 - detected.append(bi) + detected.append(m[bi]) # Append statistics (correct, conf, pcls, tcls) stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))