updates
This commit is contained in:
parent
01569d15e3
commit
c9328f663f
11
test.py
11
test.py
|
@ -47,7 +47,7 @@ def test(
|
||||||
dataset = LoadImagesAndLabels(test_path, img_size=img_size)
|
dataset = LoadImagesAndLabels(test_path, img_size=img_size)
|
||||||
dataloader = DataLoader(dataset,
|
dataloader = DataLoader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=0,
|
num_workers=4,
|
||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
collate_fn=dataset.collate_fn)
|
collate_fn=dataset.collate_fn)
|
||||||
|
|
||||||
|
@ -67,7 +67,8 @@ def test(
|
||||||
# Per image
|
# Per image
|
||||||
for si, pred in enumerate(output):
|
for si, pred in enumerate(output):
|
||||||
labels = targets[targets[:, 0] == si, 1:]
|
labels = targets[targets[:, 0] == si, 1:]
|
||||||
correct, detected, tcls = [], [], []
|
correct, detected = [], []
|
||||||
|
tcls = torch.Tensor()
|
||||||
seen += 1
|
seen += 1
|
||||||
|
|
||||||
if pred is None:
|
if pred is None:
|
||||||
|
@ -93,8 +94,8 @@ def test(
|
||||||
correct.extend([0] * len(pred))
|
correct.extend([0] * len(pred))
|
||||||
else:
|
else:
|
||||||
# Extract target boxes as (x1, y1, x2, y2)
|
# Extract target boxes as (x1, y1, x2, y2)
|
||||||
tbox = xywh2xyxy(labels[:, 1:5]) * img_size
|
tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes
|
||||||
tcls = labels[:, 0].cpu()
|
tcls = labels[:, 0] # target classes
|
||||||
|
|
||||||
for *pbox, pconf, pcls_conf, pcls in pred:
|
for *pbox, pconf, pcls_conf, pcls in pred:
|
||||||
if pcls not in tcls:
|
if pcls not in tcls:
|
||||||
|
@ -112,7 +113,7 @@ def test(
|
||||||
correct.append(0)
|
correct.append(0)
|
||||||
|
|
||||||
# Append Statistics (correct, conf, pcls, tcls)
|
# Append Statistics (correct, conf, pcls, tcls)
|
||||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
|
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu()))
|
||||||
|
|
||||||
# Compute means
|
# Compute means
|
||||||
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
|
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
|
||||||
|
|
Loading…
Reference in New Issue