From 0d71fd822846477b19cc8e728452e0f19f9bd995 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 23 Aug 2019 12:57:26 +0200 Subject: [PATCH] updates --- models.py | 1 + train.py | 2 +- utils/utils.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index 249afb0c..0099e0a4 100755 --- a/models.py +++ b/models.py @@ -88,6 +88,7 @@ def create_modules(module_defs, img_size): bias = module_list[-1][0].bias.view(len(mask), -1) # 255 to 3x85 bias[:, 4] += b[0] # obj bias[:, 5:] += b[1] # cls + # bias = torch.load('weights/yolov3-spp.bias.pt')[yolo_index] # list of tensors [3x85, 3x85, 3x85] module_list[-1][0].bias = torch.nn.Parameter(bias.view(-1)) # utils.print_model_biases(model) except: diff --git a/train.py b/train.py index 58958764..e19d6cbc 100644 --- a/train.py +++ b/train.py @@ -332,7 +332,7 @@ def train(cfg, 'training_results': file.read(), 'model': model.module.state_dict() if type( model) is nn.parallel.DistributedDataParallel else model.state_dict(), - 'optimizer': optimizer.state_dict()} + 'optimizer': None if final_epoch else optimizer.state_dict()} # Save last checkpoint torch.save(chkpt, last) diff --git a/utils/utils.py b/utils/utils.py index ec4c9dbf..f7bd210e 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -17,7 +17,7 @@ from . import torch_utils # , google_utils matplotlib.rc('font', **{'size': 11}) # Set printoptions -torch.set_printoptions(linewidth=1320, precision=5, profile='long') +torch.set_printoptions(linewidth=320, precision=5, profile='long') np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 # Prevent OpenCV from multithreading (to use PyTorch DataLoader)