updates
This commit is contained in:
parent
c828f5459f
commit
9086caf0bb
|
@ -146,7 +146,7 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
def forward(self, p, targets=None, var=None):
|
def forward(self, p, targets=None, var=None):
|
||||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||||
bs = 1 if ONNX_EXPORT else p.shape[0] # batch size
|
bs = p.shape[0] # batch size
|
||||||
nG = self.nG # number of grid points
|
nG = self.nG # number of grid points
|
||||||
|
|
||||||
if p.is_cuda and not self.weights.is_cuda:
|
if p.is_cuda and not self.weights.is_cuda:
|
||||||
|
|
Loading…
Reference in New Issue