This commit is contained in:
Glenn Jocher 2019-02-16 14:47:16 +01:00
parent c828f5459f
commit 9086caf0bb
1 changed files with 1 additions and 1 deletions

View File

@ -146,7 +146,7 @@ class YOLOLayer(nn.Module):
def forward(self, p, targets=None, var=None): def forward(self, p, targets=None, var=None):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = 1 if ONNX_EXPORT else p.shape[0] # batch size bs = p.shape[0] # batch size
nG = self.nG # number of grid points nG = self.nG # number of grid points
if p.is_cuda and not self.weights.is_cuda: if p.is_cuda and not self.weights.is_cuda: