updates
This commit is contained in:
parent
a8fb235647
commit
100f443722
|
@ -246,19 +246,19 @@ def compute_loss(p, targets): # predictions, targets
|
||||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||||
lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
|
lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
|
||||||
txy, twh, tcls, indices = targets
|
txy, twh, tcls, indices = targets
|
||||||
bs = p[0].shape[0] # batch size
|
|
||||||
MSE = nn.MSELoss()
|
MSE = nn.MSELoss()
|
||||||
CE = nn.CrossEntropyLoss()
|
CE = nn.CrossEntropyLoss()
|
||||||
BCE = nn.BCEWithLogitsLoss()
|
BCE = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
|
# bs = p[0].shape[0] # batch size
|
||||||
# gp = [x.numel() for x in tconf] # grid points
|
# gp = [x.numel() for x in tconf] # grid points
|
||||||
for i, pi0 in enumerate(p): # layer i predictions, i
|
for i, pi0 in enumerate(p): # layer i predictions, i
|
||||||
b, a, gj, gi = indices[i] # image, anchor, gridx, gridy
|
b, a, gj, gi = indices[i] # image, anchor, gridx, gridy
|
||||||
tconf = torch.zeros_like(pi0[..., 0]) # conf
|
tconf = torch.zeros_like(pi0[..., 0]) # conf
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
k = 8.4875 * bs
|
k = 135.8
|
||||||
if len(b): # number of targets
|
if len(b): # number of targets
|
||||||
pi = pi0[b, a, gj, gi] # predictions closest to anchors
|
pi = pi0[b, a, gj, gi] # predictions closest to anchors
|
||||||
tconf[b, a, gj, gi] = 1 # conf
|
tconf[b, a, gj, gi] = 1 # conf
|
||||||
|
|
Loading…
Reference in New Issue