diff --git a/detect.py b/detect.py index 74838a49..da919b37 100644 --- a/detect.py +++ b/detect.py @@ -86,7 +86,7 @@ def detect(save_txt=False, save_img=False): pred = pred.float() # Apply NMS - pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres) + pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes) # Apply Classifier if classify: @@ -110,9 +110,6 @@ def detect(save_txt=False, save_img=False): n = (det[:, -1] == c).sum() # detections per class s += '%g %ss, ' % (n, names[int(c)]) # add to string - # Print time (inference + NMS) - print('%sDone. (%.3fs)' % (s, time.time() - t)) - # Write results for *xyxy, conf, cls in det: if save_txt: # Write to file @@ -123,6 +120,9 @@ def detect(save_txt=False, save_img=False): label = '%s %.2f' % (names[int(cls)], conf) plot_one_box(xyxy, im0, label=label, color=colors[int(cls)]) + # Print time (inference + NMS) + print('%sDone. (%.3fs)' % (s, time.time() - t)) + # Stream results if view_img: cv2.imshow(p, im0) @@ -167,6 +167,7 @@ if __name__ == '__main__': parser.add_argument('--half', action='store_true', help='half precision FP16 inference') parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu') parser.add_argument('--view-img', action='store_true', help='display results') + parser.add_argument('--classes', nargs='+', type=int, help='filter by class') opt = parser.parse_args() print(opt) diff --git a/utils/utils.py b/utils/utils.py index c7da014c..684fbf48 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -498,7 +498,7 @@ def build_targets(model, targets): return tcls, tbox, indices, av -def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'): +def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None): """ Removes detections with lower object confidence score than 'conf_thres' Non-Maximum Suppression to further filter detections. @@ -537,6 +537,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru conf, j = pred[:, 5:].max(1) pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) + # Filter by class + if classes: + pred = pred[(j.view(-1, 1) == torch.Tensor(classes)).any(1)] + # Apply finite constraint if not torch.isfinite(pred).all(): pred = pred[torch.isfinite(pred).all(1)]