This commit is contained in:
Glenn Jocher 2019-08-23 12:57:26 +02:00
parent ff7f73b642
commit 0d71fd8228
3 changed files with 3 additions and 2 deletions

View File

@ -88,6 +88,7 @@ def create_modules(module_defs, img_size):
bias = module_list[-1][0].bias.view(len(mask), -1) # 255 to 3x85 bias = module_list[-1][0].bias.view(len(mask), -1) # 255 to 3x85
bias[:, 4] += b[0] # obj bias[:, 4] += b[0] # obj
bias[:, 5:] += b[1] # cls bias[:, 5:] += b[1] # cls
# bias = torch.load('weights/yolov3-spp.bias.pt')[yolo_index] # list of tensors [3x85, 3x85, 3x85]
module_list[-1][0].bias = torch.nn.Parameter(bias.view(-1)) module_list[-1][0].bias = torch.nn.Parameter(bias.view(-1))
# utils.print_model_biases(model) # utils.print_model_biases(model)
except: except:

View File

@ -332,7 +332,7 @@ def train(cfg,
'training_results': file.read(), 'training_results': file.read(),
'model': model.module.state_dict() if type( 'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(), model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()} 'optimizer': None if final_epoch else optimizer.state_dict()}
# Save last checkpoint # Save last checkpoint
torch.save(chkpt, last) torch.save(chkpt, last)

View File

@ -17,7 +17,7 @@ from . import torch_utils # , google_utils
matplotlib.rc('font', **{'size': 11}) matplotlib.rc('font', **{'size': 11})
# Set printoptions # Set printoptions
torch.set_printoptions(linewidth=1320, precision=5, profile='long') torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
# Prevent OpenCV from multithreading (to use PyTorch DataLoader) # Prevent OpenCV from multithreading (to use PyTorch DataLoader)