This commit is contained in:
Glenn Jocher 2019-04-02 13:56:54 +02:00
parent 01569d15e3
commit c9328f663f
1 changed files with 6 additions and 5 deletions

11
test.py
View File

@ -47,7 +47,7 @@ def test(
dataset = LoadImagesAndLabels(test_path, img_size=img_size)
dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=0,
num_workers=4,
pin_memory=False,
collate_fn=dataset.collate_fn)
@ -67,7 +67,8 @@ def test(
# Per image
for si, pred in enumerate(output):
labels = targets[targets[:, 0] == si, 1:]
correct, detected, tcls = [], [], []
correct, detected = [], []
tcls = torch.Tensor()
seen += 1
if pred is None:
@ -93,8 +94,8 @@ def test(
correct.extend([0] * len(pred))
else:
# Extract target boxes as (x1, y1, x2, y2)
tbox = xywh2xyxy(labels[:, 1:5]) * img_size
tcls = labels[:, 0].cpu()
tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes
tcls = labels[:, 0] # target classes
for *pbox, pconf, pcls_conf, pcls in pred:
if pcls not in tcls:
@ -112,7 +113,7 @@ def test(
correct.append(0)
# Append Statistics (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu()))
# Compute means
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]