updates
This commit is contained in:
parent
935bbfcc2b
commit
d92b75aec8
|
@ -86,7 +86,7 @@ def detect(save_txt=False, save_img=False):
|
|||
pred = pred.float()
|
||||
|
||||
# Apply NMS
|
||||
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres)
|
||||
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes)
|
||||
|
||||
# Apply Classifier
|
||||
if classify:
|
||||
|
@ -110,9 +110,6 @@ def detect(save_txt=False, save_img=False):
|
|||
n = (det[:, -1] == c).sum() # detections per class
|
||||
s += '%g %ss, ' % (n, names[int(c)]) # add to string
|
||||
|
||||
# Print time (inference + NMS)
|
||||
print('%sDone. (%.3fs)' % (s, time.time() - t))
|
||||
|
||||
# Write results
|
||||
for *xyxy, conf, cls in det:
|
||||
if save_txt: # Write to file
|
||||
|
@ -123,6 +120,9 @@ def detect(save_txt=False, save_img=False):
|
|||
label = '%s %.2f' % (names[int(cls)], conf)
|
||||
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
|
||||
|
||||
# Print time (inference + NMS)
|
||||
print('%sDone. (%.3fs)' % (s, time.time() - t))
|
||||
|
||||
# Stream results
|
||||
if view_img:
|
||||
cv2.imshow(p, im0)
|
||||
|
@ -167,6 +167,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
|
||||
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
|
||||
parser.add_argument('--view-img', action='store_true', help='display results')
|
||||
parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
|
||||
opt = parser.parse_args()
|
||||
print(opt)
|
||||
|
||||
|
|
|
@ -498,7 +498,7 @@ def build_targets(model, targets):
|
|||
return tcls, tbox, indices, av
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'):
|
||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None):
|
||||
"""
|
||||
Removes detections with lower object confidence score than 'conf_thres'
|
||||
Non-Maximum Suppression to further filter detections.
|
||||
|
@ -537,6 +537,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
|||
conf, j = pred[:, 5:].max(1)
|
||||
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
|
||||
# Filter by class
|
||||
if classes:
|
||||
pred = pred[(j.view(-1, 1) == torch.Tensor(classes)).any(1)]
|
||||
|
||||
# Apply finite constraint
|
||||
if not torch.isfinite(pred).all():
|
||||
pred = pred[torch.isfinite(pred).all(1)]
|
||||
|
|
Loading…
Reference in New Issue