This commit is contained in:
Glenn Jocher 2020-01-17 17:44:22 -08:00
parent 1ba9bd746b
commit cdb4680390
2 changed files with 3 additions and 2 deletions

View File

@ -195,7 +195,7 @@ class YOLOLayer(nn.Module):
io[..., :4] *= self.stride io[..., :4] *= self.stride
if 'default' in self.arc: # seperate obj and cls if 'default' in self.arc: # seperate obj and cls
torch.sigmoid_(io[..., 4]) torch.sigmoid_(io[..., 4:])
elif 'BCE' in self.arc: # unified BCE (80 classes) elif 'BCE' in self.arc: # unified BCE (80 classes)
torch.sigmoid_(io[..., 5:]) torch.sigmoid_(io[..., 5:])
io[..., 4] = 1 io[..., 4] = 1

View File

@ -512,6 +512,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
method = 'vision_batch' method = 'vision_batch'
nc = prediction[0].shape[1] - 5 # number of classes
multi_cls = multi_cls and (nc > 1)
output = [None] * len(prediction) output = [None] * len(prediction)
for image_i, pred in enumerate(prediction): for image_i, pred in enumerate(prediction):
# Apply conf constraint # Apply conf constraint
@ -525,7 +527,6 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
continue continue
# Compute conf # Compute conf
torch.sigmoid_(pred[..., 5:])
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2) # Box (center x, center y, width, height) to (x1, y1, x2, y2)