merge NMS speed/memory improvements
This commit is contained in:
		
							parent
							
								
									992e0d7cb4
								
							
						
					
					
						commit
						f4eecef700
					
				|  | @ -304,14 +304,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False): | ||||||
|     return iou |     return iou | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def box_iou(boxes1, boxes2): | def box_iou(box1, box2): | ||||||
|     # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py |     # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py | ||||||
|     """ |     """ | ||||||
|     Return intersection-over-union (Jaccard index) of boxes. |     Return intersection-over-union (Jaccard index) of boxes. | ||||||
|     Both sets of boxes are expected to be in (x1, y1, x2, y2) format. |     Both sets of boxes are expected to be in (x1, y1, x2, y2) format. | ||||||
|     Arguments: |     Arguments: | ||||||
|         boxes1 (Tensor[N, 4]) |         box1 (Tensor[N, 4]) | ||||||
|         boxes2 (Tensor[M, 4]) |         box2 (Tensor[M, 4]) | ||||||
|     Returns: |     Returns: | ||||||
|         iou (Tensor[N, M]): the NxM matrix containing the pairwise |         iou (Tensor[N, M]): the NxM matrix containing the pairwise | ||||||
|             IoU values for every element in boxes1 and boxes2 |             IoU values for every element in boxes1 and boxes2 | ||||||
|  | @ -321,13 +321,11 @@ def box_iou(boxes1, boxes2): | ||||||
|         # box = 4xn |         # box = 4xn | ||||||
|         return (box[2] - box[0]) * (box[3] - box[1]) |         return (box[2] - box[0]) * (box[3] - box[1]) | ||||||
| 
 | 
 | ||||||
|     area1 = box_area(boxes1.t()) |     area1 = box_area(box1.t()) | ||||||
|     area2 = box_area(boxes2.t()) |     area2 = box_area(box2.t()) | ||||||
| 
 | 
 | ||||||
|     lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2] |     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) | ||||||
|     rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2] |     inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) | ||||||
| 
 |  | ||||||
|     inter = (rb - lt).clamp(0).prod(2)  # [N,M] |  | ||||||
|     return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter) |     return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -509,6 +507,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T | ||||||
|     nc = prediction[0].shape[1] - 5  # number of classes |     nc = prediction[0].shape[1] - 5  # number of classes | ||||||
|     multi_label &= nc > 1  # multiple labels per box |     multi_label &= nc > 1  # multiple labels per box | ||||||
|     output = [None] * len(prediction) |     output = [None] * len(prediction) | ||||||
|  | 
 | ||||||
|     for xi, x in enumerate(prediction):  # image index, image inference |     for xi, x in enumerate(prediction):  # image index, image inference | ||||||
|         # Apply conf constraint |         # Apply conf constraint | ||||||
|         x = x[x[:, 4] > conf_thres] |         x = x[x[:, 4] > conf_thres] | ||||||
|  | @ -556,10 +555,9 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T | ||||||
|         boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores |         boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores | ||||||
|         if method == 'merge':  # Merge NMS (boxes merged using weighted mean) |         if method == 'merge':  # Merge NMS (boxes merged using weighted mean) | ||||||
|             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) |             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) | ||||||
|             if n < 1000:  # update boxes |             if n < 5000:  # update boxes | ||||||
|                 iou = box_iou(boxes, boxes).tril_()  # lower triangular iou matrix |                 weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1)  # box weights | ||||||
|                 weights = (iou > iou_thres) * scores.view(-1, 1) |                 weights /= weights.sum(0)  # normalize | ||||||
|                 weights /= weights.sum(0) |  | ||||||
|                 x[:, :4] = torch.mm(weights.T, x[:, :4])  # merged_boxes(n,4) = weights(n,n) * boxes(n,4) |                 x[:, :4] = torch.mm(weights.T, x[:, :4])  # merged_boxes(n,4) = weights(n,n) * boxes(n,4) | ||||||
|         elif method == 'vision': |         elif method == 'vision': | ||||||
|             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) |             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue