updates
This commit is contained in:
parent
51eb173416
commit
6f0086103c
3
test.py
3
test.py
|
@ -44,8 +44,6 @@ def test(
|
||||||
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
|
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
|
||||||
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
||||||
for batch_i, (imgs, targets) in enumerate(dataloader):
|
for batch_i, (imgs, targets) in enumerate(dataloader):
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
output = model(imgs.to(device))
|
output = model(imgs.to(device))
|
||||||
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
|
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
|
||||||
|
|
||||||
|
@ -134,6 +132,7 @@ if __name__ == '__main__':
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt, end='\n\n')
|
print(opt, end='\n\n')
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
mAP = test(
|
mAP = test(
|
||||||
opt.cfg,
|
opt.cfg,
|
||||||
opt.data_cfg,
|
opt.data_cfg,
|
||||||
|
|
1
train.py
1
train.py
|
@ -195,6 +195,7 @@ def train(
|
||||||
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
||||||
|
|
||||||
# Calculate mAP
|
# Calculate mAP
|
||||||
|
with torch.no_grad():
|
||||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
|
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
|
||||||
|
|
||||||
# Write epoch results
|
# Write epoch results
|
||||||
|
|
Loading…
Reference in New Issue