simplify train.py
This commit is contained in:
parent
68de92f1a1
commit
29fbcb059f
73
train.py
73
train.py
|
@ -90,7 +90,7 @@ def main(opt):
|
||||||
modelinfo(model)
|
modelinfo(model)
|
||||||
t0, t1 = time.time(), time.time()
|
t0, t1 = time.time(), time.time()
|
||||||
print('%10s' * 16 % (
|
print('%10s' * 16 % (
|
||||||
'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R', 'nGT', 'TP', 'FP', 'FN', 'time'))
|
'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R', 'nTargets', 'TP', 'FP', 'FN', 'time'))
|
||||||
for epoch in range(opt.epochs):
|
for epoch in range(opt.epochs):
|
||||||
epoch += start_epoch
|
epoch += start_epoch
|
||||||
|
|
||||||
|
@ -115,56 +115,49 @@ def main(opt):
|
||||||
metrics = torch.zeros(4, num_classes)
|
metrics = torch.zeros(4, num_classes)
|
||||||
for i, (imgs, targets) in enumerate(dataloader):
|
for i, (imgs, targets) in enumerate(dataloader):
|
||||||
|
|
||||||
n = opt.batch_size # number of pictures at a time
|
if sum([len(x) for x in targets]) < 1: # if no targets continue
|
||||||
for j in range(int(len(imgs) / n)):
|
continue
|
||||||
targets_j = targets[j * n:j * n + n]
|
|
||||||
nGT = sum([len(x) for x in targets_j])
|
|
||||||
if nGT < 1:
|
|
||||||
continue
|
|
||||||
|
|
||||||
loss = model(imgs[j * n:j * n + n].to(device), targets_j, requestPrecision=True, epoch=epoch)
|
loss = model(imgs.to(device), targets, requestPrecision=True, epoch=epoch)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
ui += 1
|
ui += 1
|
||||||
metrics += model.losses['metrics']
|
metrics += model.losses['metrics']
|
||||||
for key, val in model.losses.items():
|
for key, val in model.losses.items():
|
||||||
rloss[key] = (rloss[key] * ui + val) / (ui + 1)
|
rloss[key] = (rloss[key] * ui + val) / (ui + 1)
|
||||||
|
|
||||||
# Precision
|
# Precision
|
||||||
precision = metrics[0] / (metrics[0] + metrics[1] + 1e-16)
|
precision = metrics[0] / (metrics[0] + metrics[1] + 1e-16)
|
||||||
k = (metrics[0] + metrics[1]) > 0
|
k = (metrics[0] + metrics[1]) > 0
|
||||||
if k.sum() > 0:
|
if k.sum() > 0:
|
||||||
mean_precision = precision[k].mean()
|
mean_precision = precision[k].mean()
|
||||||
else:
|
else:
|
||||||
mean_precision = 0
|
mean_precision = 0
|
||||||
|
|
||||||
# Recall
|
# Recall
|
||||||
recall = metrics[0] / (metrics[0] + metrics[2] + 1e-16)
|
recall = metrics[0] / (metrics[0] + metrics[2] + 1e-16)
|
||||||
k = (metrics[0] + metrics[2]) > 0
|
k = (metrics[0] + metrics[2]) > 0
|
||||||
if k.sum() > 0:
|
if k.sum() > 0:
|
||||||
mean_recall = recall[k].mean()
|
mean_recall = recall[k].mean()
|
||||||
else:
|
else:
|
||||||
mean_recall = 0
|
mean_recall = 0
|
||||||
|
|
||||||
s = ('%10s%10s' + '%10.3g' * 14) % (
|
s = ('%10s%10s' + '%10.3g' * 14) % (
|
||||||
'%g/%g' % (epoch, opt.epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
|
'%g/%g' % (epoch, opt.epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
|
||||||
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
|
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
|
||||||
rloss['loss'], mean_precision, mean_recall, model.losses['nGT'], model.losses['TP'],
|
rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'],
|
||||||
model.losses['FP'], model.losses['FN'], time.time() - t1)
|
model.losses['FP'], model.losses['FN'], time.time() - t1)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
print(s)
|
print(s)
|
||||||
|
|
||||||
# if i == 1:
|
|
||||||
# return
|
|
||||||
|
|
||||||
# Write epoch results
|
# Write epoch results
|
||||||
with open('results.txt', 'a') as file:
|
with open('results.txt', 'a') as file:
|
||||||
file.write(s + '\n')
|
file.write(s + '\n')
|
||||||
|
|
||||||
# Update best loss
|
# Update best loss
|
||||||
loss_per_target = rloss['loss'] / rloss['nGT']
|
loss_per_target = rloss['loss'] / rloss['nT']
|
||||||
if loss_per_target < best_loss:
|
if loss_per_target < best_loss:
|
||||||
best_loss = loss_per_target
|
best_loss = loss_per_target
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue