updates
This commit is contained in:
parent
aecf840701
commit
2856af5036
6
train.py
6
train.py
|
@ -64,9 +64,9 @@ def train(
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# # Transfer learning (train only YOLO layers)
|
# Transfer learning (train only YOLO layers)
|
||||||
for i, (name, p) in enumerate(model.named_parameters()):
|
# for i, (name, p) in enumerate(model.named_parameters()):
|
||||||
p.requires_grad = True if (p.shape[0] == 255) else False
|
# p.requires_grad = True if (p.shape[0] == 255) else False
|
||||||
|
|
||||||
# Set scheduler
|
# Set scheduler
|
||||||
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
||||||
|
|
Loading…
Reference in New Issue