diff --git a/models.py b/models.py index 0df43951..fd404375 100755 --- a/models.py +++ b/models.py @@ -101,7 +101,7 @@ class YOLOLayer(nn.Module): self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1)) self.weights = class_weights() - self.loss_means = torch.zeros(6) + self.loss_means = torch.ones(6) def forward(self, p, targets=None, batch_report=False): FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor @@ -113,7 +113,7 @@ class YOLOLayer(nn.Module): if p.is_cuda and not self.grid_x.is_cuda: self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda() - self.weights = self.weights.cuda() + self.weights, self.loss_means = self.weights.cuda(), self.loss_means.cuda() # p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh) p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction @@ -172,7 +172,7 @@ class YOLOLayer(nn.Module): lh = k * MSELoss(h[mask], th[mask]) # lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float()) - lconf = (k * 5) * BCEWithLogitsLoss(pred_conf, mask.float()) + lconf = (k * 10) * BCEWithLogitsLoss(pred_conf, mask.float()) lcls = (k / 10) * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1)) # lcls = (k * 10) * BCEWithLogitsLoss(pred_cls[mask], tcls.float()) @@ -185,11 +185,11 @@ class YOLOLayer(nn.Module): # Sum loss components balance_losses_flag = False if balance_losses_flag: - loss_vec = torch.FloatTensor([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data]) - self.loss_means = self.loss_means * 0.99 + loss_vec * 0.01 k = 1 / self.loss_means.clone() - k /= k.sum() - loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) * loss_vec.sum() + loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) / k.mean() + + self.loss_means = self.loss_means * 0.99 + \ + FT([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data]) * 0.01 else: loss = lx + ly + lw + lh + lconf + lcls