diff --git a/models.py b/models.py index 09de14d5..bfc5ac9f 100755 --- a/models.py +++ b/models.py @@ -155,14 +155,17 @@ class YOLOLayer(nn.Module): # io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method io[..., :4] *= self.stride - arc = 'normal' # (normal, uCE, uBCE) architecture types + arc = 'normal' # (normal, uCE, uBCE, uBCEs) detection architectures if arc == 'normal': - io[..., 4:] = torch.sigmoid(io[..., 4:]) - elif arc == 'uCE': - io[..., 4:] = F.softmax(io[..., 4:], dim=4) # unified detection CE + torch.sigmoid_(io[..., 4:]) + elif arc == 'uCE': # unified CE (1 background + 80 classes) + io[..., 4:] = F.softmax(io[..., 4:], dim=4) io[..., 4] = 1 - elif arc == 'uBCE': - io[..., 4] = io[..., 5:].max(4)[0] # unified detection BCE + elif arc == 'uBCE': # unified BCE (1 background + 80 classes) + torch.sigmoid_(io[..., 4:]) + io[..., 4] = 1 - io[..., 4] + elif arc == 'uBCEs': # unified BCE simplified (80 classes) + torch.sigmoid_(io[..., 4:]) if self.nc == 1: io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235 diff --git a/utils/utils.py b/utils/utils.py index 76b53ebc..4c7e1fec 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -321,12 +321,12 @@ def compute_loss(p, targets, model): # predictions, targets, model # Define criteria BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']])) BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']])) - CE = nn.CrossEntropyLoss(weight=model.class_weights) + # CE = nn.CrossEntropyLoss(weight=model.class_weights) # Compute losses bs = p[0].shape[0] # batch size k = bs / 64 # loss gain - arc = 'normal' # (normal, uCE, uBCE) architecture types + arc = 'normal' # (normal, uCE, uBCE, uBCEs) detection architectures for i, pi0 in enumerate(p): # layer i predictions, i b, a, gj, gi = indices[i] # image, anchor, gridy, gridx tobj = torch.zeros_like(pi0[..., 0]) # target obj @@ -342,33 +342,42 @@ def compute_loss(p, targets, model): # predictions, targets, model pxy = torch.sigmoid(pi[..., 0:2]) # pxy = pxy * s - (s - 1) / 2, s = 1.5 (scale_xy) pbox = torch.cat((pxy, torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation - lbox += (k * h['giou']) * (1.0 - giou).mean() # giou loss + lbox += (1.0 - giou).mean() # giou loss if arc == 'normal' and model.nc > 1: # cls loss (only if multiple classes) - tclsm = torch.zeros_like(pi[..., 5:]) - tclsm[range(nb), tcls[i]] = 1.0 - lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE - # lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE + t = torch.zeros_like(pi[..., 5:]) # targets + t[range(nb), tcls[i]] = 1.0 + lcls += BCEcls(pi[..., 5:], t) # BCE + # lcls += CE(pi[..., 5:], tcls[i]) # CE # Append targets to text file # with open('targets.txt', 'a') as file: # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] if arc == 'normal': - lobj += (k * h['obj']) * BCEobj(pi0[..., 4], tobj) # obj loss + lobj += BCEobj(pi0[..., 4], tobj) # obj loss - elif arc == 'uCE': # suggest h['cls']=5. - udm_ce = torch.zeros_like(pi0[..., 0]).long() # unified detection matrix for CE + elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20 + t = torch.zeros_like(pi0[..., 0], dtype=torch.long) # targets if nb: - udm_ce[b, a, gj, gi] = tcls[i] + 1 - lcls += (k * h['cls']) * CE(pi0[..., 4:].view(-1, model.nc + 1), udm_ce.view(-1)) # unified CE + t[b, a, gj, gi] = tcls[i] + 1 + lcls += CE(pi0[..., 4:].view(-1, model.nc + 1), t.view(-1)) - elif arc == 'uBCE': - udm = torch.zeros_like(pi0[..., 5:]) # unified detection matrix for BCE + elif arc == 'uBCE': # unified BCE (1 background + 80 classes), hyps 200-30 + t = torch.zeros_like(pi0[..., 5:]) # targets if nb: - udm[b, a, gj, gi, tcls[i]] = 1.0 - lcls += (k * h['cls']) * BCEcls(pi0[..., 5:], udm) # unified BCE (hyps 200-30) + t[b, a, gj, gi, tcls[i]] = 1.0 + lcls += BCEcls(pi0[..., 5:], t) + elif arc == 'uBCEs': # unified BCE simplified (80 classes) + t = torch.zeros_like(pi0[..., 5:]) # targets + if nb: + t[b, a, gj, gi, tcls[i]] = 1.0 + lcls += BCEcls(pi0[..., 5:], t) + + lbox *= k * h['giou'] + lobj *= k * h['obj'] + lcls *= k * h['cls'] loss = lbox + lobj + lcls return loss, torch.cat((lbox, ft([0]), lobj, lcls, loss)).detach()