updates
This commit is contained in:
parent
748ff9b5b9
commit
47400aa066
3
train.py
3
train.py
|
@ -53,7 +53,7 @@ def train(
|
||||||
if checkpoint['optimizer'] is not None:
|
if checkpoint['optimizer'] is not None:
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
best_loss = checkpoint['best_loss']
|
best_loss = checkpoint['best_loss']
|
||||||
del checkpoint # current, saved
|
del checkpoint
|
||||||
|
|
||||||
else: # Initialize model with backbone (optional)
|
else: # Initialize model with backbone (optional)
|
||||||
if cfg.endswith('yolov3.cfg'):
|
if cfg.endswith('yolov3.cfg'):
|
||||||
|
@ -180,6 +180,7 @@ def train(
|
||||||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||||
'optimizer': optimizer.state_dict()}
|
'optimizer': optimizer.state_dict()}
|
||||||
torch.save(checkpoint, latest)
|
torch.save(checkpoint, latest)
|
||||||
|
del checkpoint
|
||||||
|
|
||||||
# Save best checkpoint
|
# Save best checkpoint
|
||||||
if best_loss == mloss['total']:
|
if best_loss == mloss['total']:
|
||||||
|
|
Loading…
Reference in New Issue