diff --git a/models.py b/models.py index ded1ae71..a2db26e9 100755 --- a/models.py +++ b/models.py @@ -265,6 +265,11 @@ class Darknet(nn.Module): return sum(output) if is_training else torch.cat(output, 1) +def get_yolo_layers(model): + a = [module_def['type'] == 'yolo' for module_def in model.module_defs] + return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3 + + def create_grids(self, img_size, nG): self.stride = img_size / nG