updates
This commit is contained in:
parent
01569d15e3
commit
c9328f663f
11
test.py
11
test.py
|
@ -47,7 +47,7 @@ def test(
|
|||
dataset = LoadImagesAndLabels(test_path, img_size=img_size)
|
||||
dataloader = DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=0,
|
||||
num_workers=4,
|
||||
pin_memory=False,
|
||||
collate_fn=dataset.collate_fn)
|
||||
|
||||
|
@ -67,7 +67,8 @@ def test(
|
|||
# Per image
|
||||
for si, pred in enumerate(output):
|
||||
labels = targets[targets[:, 0] == si, 1:]
|
||||
correct, detected, tcls = [], [], []
|
||||
correct, detected = [], []
|
||||
tcls = torch.Tensor()
|
||||
seen += 1
|
||||
|
||||
if pred is None:
|
||||
|
@ -93,8 +94,8 @@ def test(
|
|||
correct.extend([0] * len(pred))
|
||||
else:
|
||||
# Extract target boxes as (x1, y1, x2, y2)
|
||||
tbox = xywh2xyxy(labels[:, 1:5]) * img_size
|
||||
tcls = labels[:, 0].cpu()
|
||||
tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes
|
||||
tcls = labels[:, 0] # target classes
|
||||
|
||||
for *pbox, pconf, pcls_conf, pcls in pred:
|
||||
if pcls not in tcls:
|
||||
|
@ -112,7 +113,7 @@ def test(
|
|||
correct.append(0)
|
||||
|
||||
# Append Statistics (correct, conf, pcls, tcls)
|
||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
|
||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu()))
|
||||
|
||||
# Compute means
|
||||
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
|
||||
|
|
Loading…
Reference in New Issue