This commit is contained in:
Glenn Jocher 2020-01-10 09:30:05 -08:00
parent 0219eb094e
commit 793f6389dc
2 changed files with 6 additions and 3 deletions

View File

@ -86,7 +86,7 @@ def detect(save_img=False):
pred = pred.float() pred = pred.float()
# Apply NMS # Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes) pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
# Apply Classifier # Apply Classifier
if classify: if classify:
@ -169,6 +169,7 @@ if __name__ == '__main__':
parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='display results')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class') parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
opt = parser.parse_args() opt = parser.parse_args()
print(opt) print(opt)

View File

@ -499,7 +499,7 @@ def build_targets(model, targets):
return tcls, tbox, indices, av return tcls, tbox, indices, av
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None): def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, classes=None, agnostic=False):
""" """
Removes detections with lower object confidence score than 'conf_thres' Removes detections with lower object confidence score than 'conf_thres'
Non-Maximum Suppression to further filter detections. Non-Maximum Suppression to further filter detections.
@ -511,6 +511,7 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
# Box constraints # Box constraints
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
method = 'vision_batch'
output = [None] * len(prediction) output = [None] * len(prediction)
for image_i, pred in enumerate(prediction): for image_i, pred in enumerate(prediction):
# Apply conf constraint # Apply conf constraint
@ -548,7 +549,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
# Batched NMS # Batched NMS
if method == 'vision_batch': if method == 'vision_batch':
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], iou_thres)] c = j * 0 if agnostic else j # class-agnostic NMS
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], c, iou_thres)]
continue continue
# Sort by confidence # Sort by confidence