This commit is contained in:
Glenn Jocher 2019-04-16 13:03:24 +02:00
parent a8fb235647
commit 100f443722
1 changed files with 2 additions and 2 deletions

View File

@ -246,19 +246,19 @@ def compute_loss(p, targets): # predictions, targets
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0]) lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
txy, twh, tcls, indices = targets txy, twh, tcls, indices = targets
bs = p[0].shape[0] # batch size
MSE = nn.MSELoss() MSE = nn.MSELoss()
CE = nn.CrossEntropyLoss() CE = nn.CrossEntropyLoss()
BCE = nn.BCEWithLogitsLoss() BCE = nn.BCEWithLogitsLoss()
# Compute losses # Compute losses
# bs = p[0].shape[0] # batch size
# gp = [x.numel() for x in tconf] # grid points # gp = [x.numel() for x in tconf] # grid points
for i, pi0 in enumerate(p): # layer i predictions, i for i, pi0 in enumerate(p): # layer i predictions, i
b, a, gj, gi = indices[i] # image, anchor, gridx, gridy b, a, gj, gi = indices[i] # image, anchor, gridx, gridy
tconf = torch.zeros_like(pi0[..., 0]) # conf tconf = torch.zeros_like(pi0[..., 0]) # conf
# Compute losses # Compute losses
k = 8.4875 * bs k = 135.8
if len(b): # number of targets if len(b): # number of targets
pi = pi0[b, a, gj, gi] # predictions closest to anchors pi = pi0[b, a, gj, gi] # predictions closest to anchors
tconf[b, a, gj, gi] = 1 # conf tconf[b, a, gj, gi] = 1 # conf