updates
This commit is contained in:
parent
95f3d8e043
commit
947ee02115
4
train.py
4
train.py
|
@ -51,7 +51,7 @@ def train(
|
|||
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||
if resume: # Load previously saved model
|
||||
if transfer: # Transfer learning
|
||||
chkpt = torch.load(weights + 'yolov3.pt', map_location=device)
|
||||
chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
|
||||
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
||||
strict=False)
|
||||
for p in model.parameters():
|
||||
|
@ -227,7 +227,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--img-size', type=int, default=416, help='pixels')
|
||||
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
||||
parser.add_argument('--transfer', action='store_true', help='transfer learning flag')
|
||||
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
|
||||
parser.add_argument('--num-workers', type=int, default=0, help='number of Pytorch DataLoader workers')
|
||||
parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method')
|
||||
parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
|
||||
parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
|
||||
|
|
Loading…
Reference in New Issue