Update train.py
This commit is contained in:
parent
c7192f64c9
commit
06d264198c
5
train.py
5
train.py
|
@ -61,11 +61,6 @@ def train(
|
||||||
elif cfg.endswith('yolov3-tiny.cfg'):
|
elif cfg.endswith('yolov3-tiny.cfg'):
|
||||||
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
||||||
|
|
||||||
#initialize for distributed training
|
|
||||||
if torch.cuda.device_count() > 1:
|
|
||||||
dist.init_process_group(backend=opt.backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
|
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
|
||||||
|
|
||||||
# Transfer learning (train only YOLO layers)
|
# Transfer learning (train only YOLO layers)
|
||||||
# for i, (name, p) in enumerate(model.named_parameters()):
|
# for i, (name, p) in enumerate(model.named_parameters()):
|
||||||
# p.requires_grad = True if (p.shape[0] == 255) else False
|
# p.requires_grad = True if (p.shape[0] == 255) else False
|
||||||
|
|
Loading…
Reference in New Issue