diff --git a/models.py b/models.py index 1d9dae0e..e9675f89 100755 --- a/models.py +++ b/models.py @@ -191,7 +191,27 @@ class YOLOLayer(nn.Module): create_grids(self, img_size, (nx, ny)) def forward(self, p, img_size, out): - if ONNX_EXPORT: + ASFF = False # https://arxiv.org/abs/1911.09516 + if ASFF: + i, n = self.index, self.nl # index in layers, number of layers + p = out[self.layers[i]] + bs, _, ny, nx = p.shape # bs, 255, 13, 13 + if (self.nx, self.ny) != (nx, ny): + create_grids(self, img_size, (nx, ny), p.device, p.dtype) + + # outputs and weights + # w = F.softmax(p[:, -n:], 1) # normalized weights + w = torch.sigmoid(p[:, -n:]) * (2 / n) # sigmoid weights (faster) + # w = w / w.sum(1).unsqueeze(1) # normalize across layer dimension + + # weighted ASFF sum + p = out[self.layers[i]][:, :-n] * w[:, i:i + 1] + for j in range(n): + if j != i: + p += w[:, j:j + 1] * \ + F.interpolate(out[self.layers[j]][:, :-n], size=[ny, nx], mode='bilinear', align_corners=False) + + elif ONNX_EXPORT: bs = 1 # batch size else: bs, _, ny, nx = p.shape # bs, 255, 13, 13