updates
This commit is contained in:
parent
935bbfcc2b
commit
d92b75aec8
|
@ -86,7 +86,7 @@ def detect(save_txt=False, save_img=False):
|
||||||
pred = pred.float()
|
pred = pred.float()
|
||||||
|
|
||||||
# Apply NMS
|
# Apply NMS
|
||||||
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres)
|
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes)
|
||||||
|
|
||||||
# Apply Classifier
|
# Apply Classifier
|
||||||
if classify:
|
if classify:
|
||||||
|
@ -110,9 +110,6 @@ def detect(save_txt=False, save_img=False):
|
||||||
n = (det[:, -1] == c).sum() # detections per class
|
n = (det[:, -1] == c).sum() # detections per class
|
||||||
s += '%g %ss, ' % (n, names[int(c)]) # add to string
|
s += '%g %ss, ' % (n, names[int(c)]) # add to string
|
||||||
|
|
||||||
# Print time (inference + NMS)
|
|
||||||
print('%sDone. (%.3fs)' % (s, time.time() - t))
|
|
||||||
|
|
||||||
# Write results
|
# Write results
|
||||||
for *xyxy, conf, cls in det:
|
for *xyxy, conf, cls in det:
|
||||||
if save_txt: # Write to file
|
if save_txt: # Write to file
|
||||||
|
@ -123,6 +120,9 @@ def detect(save_txt=False, save_img=False):
|
||||||
label = '%s %.2f' % (names[int(cls)], conf)
|
label = '%s %.2f' % (names[int(cls)], conf)
|
||||||
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
|
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
|
||||||
|
|
||||||
|
# Print time (inference + NMS)
|
||||||
|
print('%sDone. (%.3fs)' % (s, time.time() - t))
|
||||||
|
|
||||||
# Stream results
|
# Stream results
|
||||||
if view_img:
|
if view_img:
|
||||||
cv2.imshow(p, im0)
|
cv2.imshow(p, im0)
|
||||||
|
@ -167,6 +167,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
|
parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
|
||||||
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
|
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
|
||||||
parser.add_argument('--view-img', action='store_true', help='display results')
|
parser.add_argument('--view-img', action='store_true', help='display results')
|
||||||
|
parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
||||||
|
|
|
@ -498,7 +498,7 @@ def build_targets(model, targets):
|
||||||
return tcls, tbox, indices, av
|
return tcls, tbox, indices, av
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'):
|
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None):
|
||||||
"""
|
"""
|
||||||
Removes detections with lower object confidence score than 'conf_thres'
|
Removes detections with lower object confidence score than 'conf_thres'
|
||||||
Non-Maximum Suppression to further filter detections.
|
Non-Maximum Suppression to further filter detections.
|
||||||
|
@ -537,6 +537,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||||
conf, j = pred[:, 5:].max(1)
|
conf, j = pred[:, 5:].max(1)
|
||||||
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||||
|
|
||||||
|
# Filter by class
|
||||||
|
if classes:
|
||||||
|
pred = pred[(j.view(-1, 1) == torch.Tensor(classes)).any(1)]
|
||||||
|
|
||||||
# Apply finite constraint
|
# Apply finite constraint
|
||||||
if not torch.isfinite(pred).all():
|
if not torch.isfinite(pred).all():
|
||||||
pred = pred[torch.isfinite(pred).all(1)]
|
pred = pred[torch.isfinite(pred).all(1)]
|
||||||
|
|
Loading…
Reference in New Issue