This commit is contained in:
Glenn Jocher 2019-02-10 21:10:50 +01:00
parent 51eb173416
commit 6f0086103c
2 changed files with 15 additions and 15 deletions

27
test.py
View File

@ -44,10 +44,8 @@ def test(
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
for batch_i, (imgs, targets) in enumerate(dataloader):
with torch.no_grad():
output = model(imgs.to(device))
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
output = model(imgs.to(device))
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
# Compute average precision for each sample
for sample_i, (labels, detections) in enumerate(zip(targets, output)):
@ -134,13 +132,14 @@ if __name__ == '__main__':
opt = parser.parse_args()
print(opt, end='\n\n')
mAP = test(
opt.cfg,
opt.data_cfg,
opt.weights,
opt.batch_size,
opt.img_size,
opt.iou_thres,
opt.conf_thres,
opt.nms_thres
)
with torch.no_grad():
mAP = test(
opt.cfg,
opt.data_cfg,
opt.weights,
opt.batch_size,
opt.img_size,
opt.iou_thres,
opt.conf_thres,
opt.nms_thres
)

View File

@ -195,7 +195,8 @@ def train(
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
# Calculate mAP
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
with torch.no_grad():
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
# Write epoch results
with open('results.txt', 'a') as file: