Update models.py
This commit is contained in:
parent
6fb14fc903
commit
473eb8d0c9
|
@ -265,6 +265,11 @@ class Darknet(nn.Module):
|
|||
return sum(output) if is_training else torch.cat(output, 1)
|
||||
|
||||
|
||||
def get_yolo_layers(model):
|
||||
a = [module_def['type'] == 'yolo' for module_def in model.module_defs]
|
||||
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
|
||||
|
||||
|
||||
def create_grids(self, img_size, nG):
|
||||
self.stride = img_size / nG
|
||||
|
||||
|
|
Loading…
Reference in New Issue