diff --git a/utils/utils.py b/utils/utils.py index 165a0f4b..7a584026 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -243,7 +243,7 @@ def wh_iou(box1, box2): def compute_loss(p, targets): # predictions, targets - FT = torch.cuda.Tensor if p[0].is_cuda else torch.Tensor + FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]) txy, twh, tcls, indices = targets MSE = nn.MSELoss()