updates
This commit is contained in:
		
							parent
							
								
									674d0de170
								
							
						
					
					
						commit
						aaaaa06156
					
				|  | @ -459,15 +459,15 @@ def build_targets(model, targets): | ||||||
|     return tcls, tbox, indices, av |     return tcls, tbox, indices, av | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=True, method='vision'): | ||||||
|     """ |     """ | ||||||
|     Removes detections with lower object confidence score than 'conf_thres' |     Removes detections with lower object confidence score than 'conf_thres' | ||||||
|     Non-Maximum Suppression to further filter detections. |     Non-Maximum Suppression to further filter detections. | ||||||
|     Returns detections with shape: |     Returns detections with shape: | ||||||
|         (x1, y1, x2, y2, object_conf, conf, class) |         (x1, y1, x2, y2, object_conf, conf, class) | ||||||
|     """ |     """ | ||||||
|     # NMS method https://github.com/ultralytics/yolov3/issues/679 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED' |     # NMS method https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch' | ||||||
|     method = 'MERGE' if conf_thres <= 0.01 else 'VISION'  # MERGE is highest mAP, VISION is fastest |     # method = 'merge' if conf_thres <= 0.01 else 'vision'  # MERGE is highest mAP, VISION is fastest | ||||||
| 
 | 
 | ||||||
|     # Box constraints |     # Box constraints | ||||||
|     min_wh, max_wh = 2, 10000  # (pixels) minimum and maximium box width and height |     min_wh, max_wh = 2, 10000  # (pixels) minimum and maximium box width and height | ||||||
|  | @ -501,19 +501,18 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | ||||||
|         # Box (center x, center y, width, height) to (x1, y1, x2, y2) |         # Box (center x, center y, width, height) to (x1, y1, x2, y2) | ||||||
|         pred[:, :4] = xywh2xyxy(pred[:, :4]) |         pred[:, :4] = xywh2xyxy(pred[:, :4]) | ||||||
| 
 | 
 | ||||||
|         # Expand |         # Multi-class | ||||||
|         expand = False |         if multi_cls: | ||||||
|         if expand: |  | ||||||
|             i, j = (pred[:, 4:] > conf_thres).nonzero().t() |             i, j = (pred[:, 4:] > conf_thres).nonzero().t() | ||||||
|             pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1)  # (x1y1x2y2, conf, cls) |             pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1) | ||||||
|         else: |         else: | ||||||
|             pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) |             pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)  # (xyxy, conf, cls) | ||||||
| 
 | 
 | ||||||
|         # Get detections sorted by decreasing confidence scores |         # Get detections sorted by decreasing confidence scores | ||||||
|         pred = pred[(-pred[:, 4]).argsort()] |         pred = pred[(-pred[:, 4]).argsort()] | ||||||
| 
 | 
 | ||||||
|         # Batched NMS |         # Batched NMS | ||||||
|         if method == 'VISION_BATCHED': |         if method == 'vision_batch': | ||||||
|             i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4], |             i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4], | ||||||
|                                                   scores=pred[:, 4], |                                                   scores=pred[:, 4], | ||||||
|                                                   idxs=pred[:, 6], |                                                   idxs=pred[:, 6], | ||||||
|  | @ -532,11 +531,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | ||||||
|             elif n > 500: |             elif n > 500: | ||||||
|                 dc = dc[:500]  # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117 |                 dc = dc[:500]  # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117 | ||||||
| 
 | 
 | ||||||
|             if method == 'VISION': |             if method == 'vision': | ||||||
|                 i = torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres) |                 i = torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres) | ||||||
|                 det_max.append(dc[i]) |                 det_max.append(dc[i]) | ||||||
| 
 | 
 | ||||||
|             elif method == 'OR':  # default |             elif method == 'or':  # default | ||||||
|                 # METHOD1 |                 # METHOD1 | ||||||
|                 # ind = list(range(len(dc))) |                 # ind = list(range(len(dc))) | ||||||
|                 # while len(ind): |                 # while len(ind): | ||||||
|  | @ -553,14 +552,14 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | ||||||
|                     iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes |                     iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes | ||||||
|                     dc = dc[1:][iou < nms_thres]  # remove ious > threshold |                     dc = dc[1:][iou < nms_thres]  # remove ious > threshold | ||||||
| 
 | 
 | ||||||
|             elif method == 'AND':  # requires overlap, single boxes erased |             elif method == 'and':  # requires overlap, single boxes erased | ||||||
|                 while len(dc) > 1: |                 while len(dc) > 1: | ||||||
|                     iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes |                     iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes | ||||||
|                     if iou.max() > 0.5: |                     if iou.max() > 0.5: | ||||||
|                         det_max.append(dc[:1]) |                         det_max.append(dc[:1]) | ||||||
|                     dc = dc[1:][iou < nms_thres]  # remove ious > threshold |                     dc = dc[1:][iou < nms_thres]  # remove ious > threshold | ||||||
| 
 | 
 | ||||||
|             elif method == 'MERGE':  # weighted mixture box |             elif method == 'merge':  # weighted mixture box | ||||||
|                 while len(dc): |                 while len(dc): | ||||||
|                     if len(dc) == 1: |                     if len(dc) == 1: | ||||||
|                         det_max.append(dc) |                         det_max.append(dc) | ||||||
|  | @ -571,7 +570,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | ||||||
|                     det_max.append(dc[:1]) |                     det_max.append(dc[:1]) | ||||||
|                     dc = dc[i == 0] |                     dc = dc[i == 0] | ||||||
| 
 | 
 | ||||||
|             elif method == 'SOFT':  # soft-NMS https://arxiv.org/abs/1704.04503 |             elif method == 'soft':  # soft-NMS https://arxiv.org/abs/1704.04503 | ||||||
|                 sigma = 0.5  # soft-nms sigma parameter |                 sigma = 0.5  # soft-nms sigma parameter | ||||||
|                 while len(dc): |                 while len(dc): | ||||||
|                     if len(dc) == 1: |                     if len(dc) == 1: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue