updates
This commit is contained in:
parent
1ba9bd746b
commit
cdb4680390
|
@ -195,7 +195,7 @@ class YOLOLayer(nn.Module):
|
||||||
io[..., :4] *= self.stride
|
io[..., :4] *= self.stride
|
||||||
|
|
||||||
if 'default' in self.arc: # seperate obj and cls
|
if 'default' in self.arc: # seperate obj and cls
|
||||||
torch.sigmoid_(io[..., 4])
|
torch.sigmoid_(io[..., 4:])
|
||||||
elif 'BCE' in self.arc: # unified BCE (80 classes)
|
elif 'BCE' in self.arc: # unified BCE (80 classes)
|
||||||
torch.sigmoid_(io[..., 5:])
|
torch.sigmoid_(io[..., 5:])
|
||||||
io[..., 4] = 1
|
io[..., 4] = 1
|
||||||
|
|
|
@ -512,6 +512,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||||
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
||||||
|
|
||||||
method = 'vision_batch'
|
method = 'vision_batch'
|
||||||
|
nc = prediction[0].shape[1] - 5 # number of classes
|
||||||
|
multi_cls = multi_cls and (nc > 1)
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
for image_i, pred in enumerate(prediction):
|
for image_i, pred in enumerate(prediction):
|
||||||
# Apply conf constraint
|
# Apply conf constraint
|
||||||
|
@ -525,7 +527,6 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Compute conf
|
# Compute conf
|
||||||
torch.sigmoid_(pred[..., 5:])
|
|
||||||
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
|
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
|
||||||
|
|
||||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||||
|
|
Loading…
Reference in New Issue