updates
This commit is contained in:
parent
0f3018124f
commit
12e605165e
5
train.py
5
train.py
|
@ -50,10 +50,9 @@ def train(
|
||||||
# model = nn.DataParallel(model)
|
# model = nn.DataParallel(model)
|
||||||
model.to(device).train()
|
model.to(device).train()
|
||||||
|
|
||||||
# # Transfer learning (train only YOLO layers)
|
# Transfer learning (train only YOLO layers)
|
||||||
# for i, (name, p) in enumerate(model.named_parameters()):
|
# for i, (name, p) in enumerate(model.named_parameters()):
|
||||||
# if p.shape[0] != 650: # not YOLO layer
|
# p.requires_grad = True if (p.shape[0] == 255) else False
|
||||||
# p.requires_grad = False
|
|
||||||
|
|
||||||
# Set optimizer
|
# Set optimizer
|
||||||
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)
|
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)
|
||||||
|
|
Loading…
Reference in New Issue