updates
This commit is contained in:
parent
ff7f73b642
commit
0d71fd8228
|
@ -88,6 +88,7 @@ def create_modules(module_defs, img_size):
|
|||
bias = module_list[-1][0].bias.view(len(mask), -1) # 255 to 3x85
|
||||
bias[:, 4] += b[0] # obj
|
||||
bias[:, 5:] += b[1] # cls
|
||||
# bias = torch.load('weights/yolov3-spp.bias.pt')[yolo_index] # list of tensors [3x85, 3x85, 3x85]
|
||||
module_list[-1][0].bias = torch.nn.Parameter(bias.view(-1))
|
||||
# utils.print_model_biases(model)
|
||||
except:
|
||||
|
|
2
train.py
2
train.py
|
@ -332,7 +332,7 @@ def train(cfg,
|
|||
'training_results': file.read(),
|
||||
'model': model.module.state_dict() if type(
|
||||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||
'optimizer': optimizer.state_dict()}
|
||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||
|
||||
# Save last checkpoint
|
||||
torch.save(chkpt, last)
|
||||
|
|
|
@ -17,7 +17,7 @@ from . import torch_utils # , google_utils
|
|||
matplotlib.rc('font', **{'size': 11})
|
||||
|
||||
# Set printoptions
|
||||
torch.set_printoptions(linewidth=1320, precision=5, profile='long')
|
||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||
|
||||
# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
|
||||
|
|
Loading…
Reference in New Issue