Non-output layer freeze in train.py (#1333)

Freeze all layers except the YOLOLayer modules and the conv layers immediately preceding them.
This commit is contained in:
Oulbacha Reda 2020-06-22 16:15:40 -04:00 committed by GitHub
parent ca7794ed05
commit a97f350461
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 11 additions and 0 deletions

View File

@@ -144,6 +144,16 @@ def train(hyp):
# possible weights are '*.weights', 'yolov3-tiny.conv.15', 'darknet53.conv.74' etc. # possible weights are '*.weights', 'yolov3-tiny.conv.15', 'darknet53.conv.74' etc.
load_darknet_weights(model, weights) load_darknet_weights(model, weights)
if opt.freeze_layers:
output_layer_indices = [idx - 1 for idx, module in enumerate(model.module_list) \
if isinstance(module, YOLOLayer)]
freeze_layer_indices = [x for x in range(len(model.module_list)) if\
(x not in output_layer_indices) and \
(x - 1 not in output_layer_indices)]
for idx in freeze_layer_indices:
for parameter in model.module_list[idx].parameters():
parameter.requires_grad_(False)
# Mixed precision training https://github.com/NVIDIA/apex # Mixed precision training https://github.com/NVIDIA/apex
if mixed_precision: if mixed_precision:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0) model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
@@ -394,6 +404,7 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)') parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--adam', action='store_true', help='use adam optimizer') parser.add_argument('--adam', action='store_true', help='use adam optimizer')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--freeze-layers', action='store_true', help='Freeze non-output layers')
opt = parser.parse_args() opt = parser.parse_args()
opt.weights = last if opt.resume else opt.weights opt.weights = last if opt.resume else opt.weights
check_git_status() check_git_status()