This commit is contained in:
Glenn Jocher 2020-01-01 12:44:33 -08:00
parent 935bbfcc2b
commit d92b75aec8
2 changed files with 10 additions and 5 deletions

View File

@ -86,7 +86,7 @@ def detect(save_txt=False, save_img=False):
pred = pred.float()
# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres)
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes)
# Apply Classifier
if classify:
@ -110,9 +110,6 @@ def detect(save_txt=False, save_img=False):
n = (det[:, -1] == c).sum() # detections per class
s += '%g %ss, ' % (n, names[int(c)]) # add to string
# Print time (inference + NMS)
print('%sDone. (%.3fs)' % (s, time.time() - t))
# Write results
for *xyxy, conf, cls in det:
if save_txt: # Write to file
@ -123,6 +120,9 @@ def detect(save_txt=False, save_img=False):
label = '%s %.2f' % (names[int(cls)], conf)
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
# Print time (inference + NMS)
print('%sDone. (%.3fs)' % (s, time.time() - t))
# Stream results
if view_img:
cv2.imshow(p, im0)
@ -167,6 +167,7 @@ if __name__ == '__main__':
parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
opt = parser.parse_args()
print(opt)

View File

@ -498,7 +498,7 @@ def build_targets(model, targets):
return tcls, tbox, indices, av
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'):
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None):
"""
Removes detections with lower object confidence score than 'conf_thres'
Non-Maximum Suppression to further filter detections.
@ -537,6 +537,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
conf, j = pred[:, 5:].max(1)
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
# Filter by class
if classes:
pred = pred[(j.view(-1, 1) == torch.Tensor(classes)).any(1)]
# Apply finite constraint
if not torch.isfinite(pred).all():
pred = pred[torch.isfinite(pred).all(1)]