Update models.py
This commit is contained in:
parent
6fb14fc903
commit
473eb8d0c9
|
@ -265,6 +265,11 @@ class Darknet(nn.Module):
|
||||||
return sum(output) if is_training else torch.cat(output, 1)
|
return sum(output) if is_training else torch.cat(output, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_yolo_layers(model):
|
||||||
|
a = [module_def['type'] == 'yolo' for module_def in model.module_defs]
|
||||||
|
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
|
||||||
|
|
||||||
|
|
||||||
def create_grids(self, img_size, nG):
|
def create_grids(self, img_size, nG):
|
||||||
self.stride = img_size / nG
|
self.stride = img_size / nG
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue