updates
This commit is contained in:
parent
57f2b3f6d7
commit
120af70798
14
models.py
14
models.py
|
@ -101,7 +101,7 @@ class YOLOLayer(nn.Module):
|
|||
self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1))
|
||||
self.weights = class_weights()
|
||||
|
||||
self.loss_means = torch.zeros(6)
|
||||
self.loss_means = torch.ones(6)
|
||||
|
||||
def forward(self, p, targets=None, batch_report=False):
|
||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||
|
@ -113,7 +113,7 @@ class YOLOLayer(nn.Module):
|
|||
if p.is_cuda and not self.grid_x.is_cuda:
|
||||
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
||||
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
|
||||
self.weights = self.weights.cuda()
|
||||
self.weights, self.loss_means = self.weights.cuda(), self.loss_means.cuda()
|
||||
|
||||
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
|
||||
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||
|
@ -172,7 +172,7 @@ class YOLOLayer(nn.Module):
|
|||
lh = k * MSELoss(h[mask], th[mask])
|
||||
|
||||
# lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float())
|
||||
lconf = (k * 5) * BCEWithLogitsLoss(pred_conf, mask.float())
|
||||
lconf = (k * 10) * BCEWithLogitsLoss(pred_conf, mask.float())
|
||||
|
||||
lcls = (k / 10) * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
|
||||
# lcls = (k * 10) * BCEWithLogitsLoss(pred_cls[mask], tcls.float())
|
||||
|
@ -185,11 +185,11 @@ class YOLOLayer(nn.Module):
|
|||
# Sum loss components
|
||||
balance_losses_flag = False
|
||||
if balance_losses_flag:
|
||||
loss_vec = torch.FloatTensor([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data])
|
||||
self.loss_means = self.loss_means * 0.99 + loss_vec * 0.01
|
||||
k = 1 / self.loss_means.clone()
|
||||
k /= k.sum()
|
||||
loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) * loss_vec.sum()
|
||||
loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) / k.mean()
|
||||
|
||||
self.loss_means = self.loss_means * 0.99 + \
|
||||
FT([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data]) * 0.01
|
||||
else:
|
||||
loss = lx + ly + lw + lh + lconf + lcls
|
||||
|
||||
|
|
Loading…
Reference in New Issue