ASFF implementation

This commit is contained in:
Glenn Jocher 2020-03-14 17:04:38 -07:00
parent ea4c26b32d
commit 5ebbb2db28
1 changed files with 21 additions and 1 deletions

View File

@ -191,7 +191,27 @@ class YOLOLayer(nn.Module):
create_grids(self, img_size, (nx, ny)) create_grids(self, img_size, (nx, ny))
def forward(self, p, img_size, out): def forward(self, p, img_size, out):
if ONNX_EXPORT: ASFF = False # https://arxiv.org/abs/1911.09516
if ASFF:
i, n = self.index, self.nl # index in layers, number of layers
p = out[self.layers[i]]
bs, _, ny, nx = p.shape # bs, 255, 13, 13
if (self.nx, self.ny) != (nx, ny):
create_grids(self, img_size, (nx, ny), p.device, p.dtype)
# outputs and weights
# w = F.softmax(p[:, -n:], 1) # normalized weights
w = torch.sigmoid(p[:, -n:]) * (2 / n) # sigmoid weights (faster)
# w = w / w.sum(1).unsqueeze(1) # normalize across layer dimension
# weighted ASFF sum
p = out[self.layers[i]][:, :-n] * w[:, i:i + 1]
for j in range(n):
if j != i:
p += w[:, j:j + 1] * \
F.interpolate(out[self.layers[j]][:, :-n], size=[ny, nx], mode='bilinear', align_corners=False)
elif ONNX_EXPORT:
bs = 1 # batch size bs = 1 # batch size
else: else:
bs, _, ny, nx = p.shape # bs, 255, 13, 13 bs, _, ny, nx = p.shape # bs, 255, 13, 13