updates
This commit is contained in:
parent
e67cee4a0c
commit
3006c33c29
15
train.py
15
train.py
|
@ -167,6 +167,10 @@ def train(
|
||||||
from apex import amp
|
from apex import amp
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||||
|
|
||||||
|
# Remove old results
|
||||||
|
for f in glob.glob('train_batch*.jpg') + glob.glob('test_batch*.jpg') + 'results.txt':
|
||||||
|
os.remove(f)
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
|
@ -175,8 +179,6 @@ def train(
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
||||||
n_burnin = min(round(nb / 5 + 1), 1000) # burn-in batches
|
n_burnin = min(round(nb / 5 + 1), 1000) # burn-in batches
|
||||||
for f in glob.glob('train_batch*.jpg') + glob.glob('test_batch*.jpg'):
|
|
||||||
os.remove(f)
|
|
||||||
t, t0 = time.time(), time.time()
|
t, t0 = time.time(), time.time()
|
||||||
for epoch in range(start_epoch, epochs):
|
for epoch in range(start_epoch, epochs):
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -185,7 +187,7 @@ def train(
|
||||||
# Update scheduler
|
# Update scheduler
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# Freeze backbone at epoch 0, unfreeze at epoch 1
|
# Freeze backbone at epoch 0, unfreeze at epoch 1 (optional)
|
||||||
if freeze_backbone and epoch < 2:
|
if freeze_backbone and epoch < 2:
|
||||||
for name, p in model.named_parameters():
|
for name, p in model.named_parameters():
|
||||||
if int(name.split('.')[1]) < cutoff: # if layer < 75
|
if int(name.split('.')[1]) < cutoff: # if layer < 75
|
||||||
|
@ -200,7 +202,6 @@ def train(
|
||||||
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
||||||
imgs = imgs.to(device)
|
imgs = imgs.to(device)
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
nt = len(targets)
|
|
||||||
|
|
||||||
# Plot images with bounding boxes
|
# Plot images with bounding boxes
|
||||||
if epoch == 0 and i == 0:
|
if epoch == 0 and i == 0:
|
||||||
|
@ -233,13 +234,11 @@ def train(
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Update running mean of tracked metrics
|
|
||||||
mloss = (mloss * i + loss_items) / (i + 1)
|
|
||||||
|
|
||||||
# Print batch results
|
# Print batch results
|
||||||
|
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
||||||
s = ('%8s%12s' + '%10.3g' * 7) % (
|
s = ('%8s%12s' + '%10.3g' * 7) % (
|
||||||
'%g/%g' % (epoch, epochs - 1),
|
'%g/%g' % (epoch, epochs - 1),
|
||||||
'%g/%g' % (i, nb - 1), *mloss, nt, time.time() - t)
|
'%g/%g' % (i, nb - 1), *mloss, len(targets), time.time() - t)
|
||||||
t = time.time()
|
t = time.time()
|
||||||
print(s)
|
print(s)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue