diff --git a/test.py b/test.py
index 4ed93b17..40f9faf9 100644
--- a/test.py
+++ b/test.py
@@ -47,7 +47,7 @@ def test(
     dataset = LoadImagesAndLabels(test_path, img_size=img_size)
     dataloader = DataLoader(dataset,
                             batch_size=batch_size,
-                            num_workers=0,
+                            num_workers=4,
                             pin_memory=False,
                             collate_fn=dataset.collate_fn)
 
@@ -67,7 +67,8 @@ def test(
         # Per image
         for si, pred in enumerate(output):
             labels = targets[targets[:, 0] == si, 1:]
-            correct, detected, tcls = [], [], []
+            correct, detected = [], []
+            tcls = torch.Tensor()
             seen += 1
 
             if pred is None:
@@ -93,8 +94,8 @@ def test(
                 correct.extend([0] * len(pred))
             else:
                 # Extract target boxes as (x1, y1, x2, y2)
-                tbox = xywh2xyxy(labels[:, 1:5]) * img_size
-                tcls = labels[:, 0].cpu()
+                tbox = xywh2xyxy(labels[:, 1:5]) * img_size  # target boxes
+                tcls = labels[:, 0]  # target classes
 
                 for *pbox, pconf, pcls_conf, pcls in pred:
                     if pcls not in tcls:
@@ -112,7 +113,7 @@ def test(
                         correct.append(0)
 
             # Append Statistics (correct, conf, pcls, tcls)
-            stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
+            stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu()))
 
     # Compute means
     stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
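For context on the `tcls` change (not part of the diff itself, just a hedged illustration): initializing `tcls` as an empty `torch.Tensor()` instead of a Python list keeps the `pcls not in tcls` membership test valid in both branches, and deferring `.cpu()` to the `stats.append(...)` call transfers the tensor once per image rather than inside the per-prediction loop. A minimal sketch of that behavior:

```python
import torch

# Membership test relied on by `if pcls not in tcls:` in the diff.
# torch.Tensor.__contains__ performs an element-wise equality check,
# so an empty tensor simply reports "not in" for every prediction.
tcls = torch.Tensor()            # image with no labels
pcls = torch.tensor(3.0)         # predicted class
print(pcls in tcls)              # False -> prediction counted as incorrect

tcls = torch.tensor([1.0, 3.0])  # target classes for an image with labels
print(pcls in tcls)              # True

# Calling tcls.cpu() only when appending to stats (as the diff does)
# means a single device-to-host copy per image instead of one per target.
```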