This commit is contained in:
Glenn Jocher 2019-04-02 13:56:54 +02:00
parent 01569d15e3
commit c9328f663f
1 changed files with 6 additions and 5 deletions

11
test.py
View File

@ -47,7 +47,7 @@ def test(
dataset = LoadImagesAndLabels(test_path, img_size=img_size) dataset = LoadImagesAndLabels(test_path, img_size=img_size)
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=0, num_workers=4,
pin_memory=False, pin_memory=False,
collate_fn=dataset.collate_fn) collate_fn=dataset.collate_fn)
@ -67,7 +67,8 @@ def test(
# Per image # Per image
for si, pred in enumerate(output): for si, pred in enumerate(output):
labels = targets[targets[:, 0] == si, 1:] labels = targets[targets[:, 0] == si, 1:]
correct, detected, tcls = [], [], [] correct, detected = [], []
tcls = torch.Tensor()
seen += 1 seen += 1
if pred is None: if pred is None:
@ -93,8 +94,8 @@ def test(
correct.extend([0] * len(pred)) correct.extend([0] * len(pred))
else: else:
# Extract target boxes as (x1, y1, x2, y2) # Extract target boxes as (x1, y1, x2, y2)
tbox = xywh2xyxy(labels[:, 1:5]) * img_size tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes
tcls = labels[:, 0].cpu() tcls = labels[:, 0] # target classes
for *pbox, pconf, pcls_conf, pcls in pred: for *pbox, pconf, pcls_conf, pcls in pred:
if pcls not in tcls: if pcls not in tcls:
@ -112,7 +113,7 @@ def test(
correct.append(0) correct.append(0)
# Append Statistics (correct, conf, pcls, tcls) # Append Statistics (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls)) stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu()))
# Compute means # Compute means
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))] stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]