From da9ec7d12fc767747384c05745abdccff3710245 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 8 Jul 2019 18:00:19 +0200
Subject: [PATCH] updates

---
 train.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/train.py b/train.py
index c9aadb70..b3514cc6 100644
--- a/train.py
+++ b/train.py
@@ -61,12 +61,13 @@ def train(
     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
     best_fitness = 0.0
-    nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters'])  # yolo layer size (i.e. 255)
     if opt.resume or opt.transfer:  # Load previously saved model
         if opt.transfer:  # Transfer learning
+            nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters'])  # yolo layer size (i.e. 255)
             chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
             model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
                                   strict=False)
+
             for p in model.parameters():
                 p.requires_grad = True if p.shape[0] == nf else False
 
@@ -74,10 +75,14 @@ def train(
 
             chkpt = torch.load(latest, map_location=device)  # load checkpoint
             model.load_state_dict(chkpt['model'])
-            start_epoch = chkpt['epoch'] + 1
             if chkpt['optimizer'] is not None:
                 optimizer.load_state_dict(chkpt['optimizer'])
                 best_fitness = chkpt['best_fitness']
+
+            with open('results.txt', 'w') as file:
+                file.write(chkpt['training_results'])  # write results.txt
+
+            start_epoch = chkpt['epoch'] + 1
             del chkpt
 
     else:  # Initialize model with backbone (optional)
@@ -246,12 +251,14 @@ def train(
         # Save training results
         save = (not opt.nosave) or (epoch == epochs - 1)
         if save:
-            # Create checkpoint
-            chkpt = {'epoch': epoch,
-                     'best_fitness': best_fitness,
-                     'model': model.module.state_dict() if type(
-                         model) is nn.parallel.DistributedDataParallel else model.state_dict(),
-                     'optimizer': optimizer.state_dict()}
+            with open('results.txt', 'r') as file:
+                # Create checkpoint
+                chkpt = {'epoch': epoch,
+                         'best_fitness': best_fitness,
+                         'training_results': file.read(),
+                         'model': model.module.state_dict() if type(
+                             model) is nn.parallel.DistributedDataParallel else model.state_dict(),
+                         'optimizer': optimizer.state_dict()}
 
             # Save latest checkpoint
             torch.save(chkpt, latest)
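
The first hunk moves the nf computation into the transfer-learning branch (it is only needed there) and then freezes every parameter whose leading dimension does not equal the YOLO output size, so only the detection heads keep training. A minimal runnable sketch of that freezing rule, using a stand-in two-layer model rather than the real Darknet network (the layer shapes here are assumptions for illustration):

    import torch
    import torch.nn as nn

    nf = 255  # yolo layer size, as in train.py
    # Stand-in model: a backbone conv to freeze, plus a head whose
    # parameters all have leading dimension nf, mirroring the patch's rule.
    model = nn.Sequential(
        nn.Conv2d(3, 32, 3, padding=1),  # backbone layer -> frozen
        nn.Conv2d(32, nf, 1),            # detection head -> stays trainable
    )

    # Same rule as the patch: requires_grad only where p.shape[0] == nf.
    # (Conv weights have shape [out_channels, ...] and biases [out_channels],
    # so both head tensors match nf while backbone tensors do not.)
    for p in model.parameters():
        p.requires_grad = p.shape[0] == nf

    # Hand only the still-trainable parameters to the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=0.01)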
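The second and third hunks make checkpoints self-contained for resuming: on save, the accumulated results.txt text is embedded in the checkpoint under a 'training_results' key; on resume, results.txt is rewritten from the checkpoint before start_epoch is set (start_epoch now moves after the optimizer restore, so a checkpoint without optimizer state still resumes at the right epoch). A sketch of that round-trip, with a placeholder model, optimizer, path, and log line standing in for the real train.py state:

    import torch
    import torch.nn as nn

    # Placeholders for the real Darknet model and SGD optimizer in train.py.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    epoch, best_fitness, latest = 0, 0.0, 'latest.pt'

    with open('results.txt', 'w') as file:  # pretend one epoch was logged
        file.write('epoch 0 metrics\n')

    # Save path (third hunk): embed results.txt in the checkpoint.
    with open('results.txt', 'r') as file:
        chkpt = {'epoch': epoch,
                 'best_fitness': best_fitness,
                 'training_results': file.read(),
                 'model': model.state_dict(),
                 'optimizer': optimizer.state_dict()}
    torch.save(chkpt, latest)

    # Resume path (second hunk): restore results.txt from the checkpoint,
    # then continue training at the following epoch.
    chkpt = torch.load(latest, map_location='cpu')
    with open('results.txt', 'w') as file:
        file.write(chkpt['training_results'])
    start_epoch = chkpt['epoch'] + 1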