diff --git a/models.py b/models.py index cf59daef..48815452 100755 --- a/models.py +++ b/models.py @@ -152,7 +152,6 @@ class YOLOLayer(nn.Module): self.no = nc + 5 # number of outputs self.nx = 0 # initialize number of x gridpoints self.ny = 0 # initialize number of y gridpoints - self.oi = [0, 1, 2, 3] + list(range(5, self.no)) # output indices self.arc = arc if ONNX_EXPORT: # grids must be computed in __init__ @@ -210,7 +209,7 @@ class YOLOLayer(nn.Module): io[..., :4] *= self.stride if 'default' in self.arc: # seperate obj and cls - torch.sigmoid_(io[..., 4:]) + torch.sigmoid_(io[..., 4]) elif 'BCE' in self.arc: # unified BCE (80 classes) torch.sigmoid_(io[..., 5:]) io[..., 4] = 1 @@ -221,11 +220,8 @@ class YOLOLayer(nn.Module): if self.nc == 1: io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235 - # compute conf - io[..., 5:] *= io[..., 4:5] # conf = obj_conf * cls_conf - # reshape from [1, 3, 13, 13, 85] to [1, 507, 84], remove obj_conf - return io[..., self.oi].view(bs, -1, self.no - 1), p + return io.view(bs, -1, self.no), p class Darknet(nn.Module): diff --git a/utils/utils.py b/utils/utils.py index 8d37e62b..c75d3077 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -492,9 +492,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru output = [None] * len(prediction) for image_i, pred in enumerate(prediction): # Remove rows - pred = pred[(pred[:, 4:] > conf_thres).any(1)] # retain above threshold + pred = pred[pred[:, 4] > conf_thres] # retain above threshold - # Select only suitable predictions + # compute conf + torch.sigmoid_(pred[..., 5:]) + pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf + + # Apply width-height constraint i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1) pred = pred[i] @@ -507,10 +511,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru # Multi-class if multi_cls or conf_thres < 0.01: - i, j = (pred[:, 4:] > conf_thres).nonzero().t() - pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1) + i, j = (pred[:, 5:] > conf_thres).nonzero().t() + pred = torch.cat((pred[i, :4], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) else: # best class only - conf, j = pred[:, 4:].max(1) + conf, j = pred[:, 5:].max(1) pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls) # Get detections sorted by decreasing confidence scores