This commit is contained in:
Glenn Jocher 2018-11-22 17:13:47 +01:00
parent 57f2b3f6d7
commit 120af70798
1 changed files with 7 additions and 7 deletions

View File

@ -101,7 +101,7 @@ class YOLOLayer(nn.Module):
self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1))
self.weights = class_weights()
self.loss_means = torch.zeros(6)
self.loss_means = torch.ones(6)
def forward(self, p, targets=None, batch_report=False):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
@ -113,7 +113,7 @@ class YOLOLayer(nn.Module):
if p.is_cuda and not self.grid_x.is_cuda:
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
self.weights = self.weights.cuda()
self.weights, self.loss_means = self.weights.cuda(), self.loss_means.cuda()
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
@ -172,7 +172,7 @@ class YOLOLayer(nn.Module):
lh = k * MSELoss(h[mask], th[mask])
# lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float())
lconf = (k * 5) * BCEWithLogitsLoss(pred_conf, mask.float())
lconf = (k * 10) * BCEWithLogitsLoss(pred_conf, mask.float())
lcls = (k / 10) * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
# lcls = (k * 10) * BCEWithLogitsLoss(pred_cls[mask], tcls.float())
@ -185,11 +185,11 @@ class YOLOLayer(nn.Module):
# Sum loss components
balance_losses_flag = False
if balance_losses_flag:
loss_vec = torch.FloatTensor([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data])
self.loss_means = self.loss_means * 0.99 + loss_vec * 0.01
k = 1 / self.loss_means.clone()
k /= k.sum()
loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) * loss_vec.sum()
loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) / k.mean()
self.loss_means = self.loss_means * 0.99 + \
FT([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data]) * 0.01
else:
loss = lx + ly + lw + lh + lconf + lcls