updates
This commit is contained in:
parent
51eb173416
commit
6f0086103c
27
test.py
27
test.py
|
@ -44,10 +44,8 @@ def test(
|
||||||
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
|
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
|
||||||
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
||||||
for batch_i, (imgs, targets) in enumerate(dataloader):
|
for batch_i, (imgs, targets) in enumerate(dataloader):
|
||||||
|
output = model(imgs.to(device))
|
||||||
with torch.no_grad():
|
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
|
||||||
output = model(imgs.to(device))
|
|
||||||
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
|
|
||||||
|
|
||||||
# Compute average precision for each sample
|
# Compute average precision for each sample
|
||||||
for sample_i, (labels, detections) in enumerate(zip(targets, output)):
|
for sample_i, (labels, detections) in enumerate(zip(targets, output)):
|
||||||
|
@ -134,13 +132,14 @@ if __name__ == '__main__':
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt, end='\n\n')
|
print(opt, end='\n\n')
|
||||||
|
|
||||||
mAP = test(
|
with torch.no_grad():
|
||||||
opt.cfg,
|
mAP = test(
|
||||||
opt.data_cfg,
|
opt.cfg,
|
||||||
opt.weights,
|
opt.data_cfg,
|
||||||
opt.batch_size,
|
opt.weights,
|
||||||
opt.img_size,
|
opt.batch_size,
|
||||||
opt.iou_thres,
|
opt.img_size,
|
||||||
opt.conf_thres,
|
opt.iou_thres,
|
||||||
opt.nms_thres
|
opt.conf_thres,
|
||||||
)
|
opt.nms_thres
|
||||||
|
)
|
||||||
|
|
3
train.py
3
train.py
|
@ -195,7 +195,8 @@ def train(
|
||||||
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
||||||
|
|
||||||
# Calculate mAP
|
# Calculate mAP
|
||||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
|
with torch.no_grad():
|
||||||
|
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
|
||||||
|
|
||||||
# Write epoch results
|
# Write epoch results
|
||||||
with open('results.txt', 'a') as file:
|
with open('results.txt', 'a') as file:
|
||||||
|
|
Loading…
Reference in New Issue