updates
This commit is contained in:
parent
1ba9bd746b
commit
cdb4680390
|
@ -195,7 +195,7 @@ class YOLOLayer(nn.Module):
|
|||
io[..., :4] *= self.stride
|
||||
|
||||
if 'default' in self.arc: # seperate obj and cls
|
||||
torch.sigmoid_(io[..., 4])
|
||||
torch.sigmoid_(io[..., 4:])
|
||||
elif 'BCE' in self.arc: # unified BCE (80 classes)
|
||||
torch.sigmoid_(io[..., 5:])
|
||||
io[..., 4] = 1
|
||||
|
|
|
@ -512,6 +512,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
|||
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
||||
|
||||
method = 'vision_batch'
|
||||
nc = prediction[0].shape[1] - 5 # number of classes
|
||||
multi_cls = multi_cls and (nc > 1)
|
||||
output = [None] * len(prediction)
|
||||
for image_i, pred in enumerate(prediction):
|
||||
# Apply conf constraint
|
||||
|
@ -525,7 +527,6 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
|||
continue
|
||||
|
||||
# Compute conf
|
||||
torch.sigmoid_(pred[..., 5:])
|
||||
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
|
||||
|
||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||
|
|
Loading…
Reference in New Issue