CUBLAS bug fix #1139
This commit is contained in:
		
							parent
							
								
									832ceba559
								
							
						
					
					
						commit
						965155ee60
					
				|  | @ -556,13 +556,15 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T | ||||||
|         boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores |         boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores | ||||||
|         if method == 'merge':  # Merge NMS (boxes merged using weighted mean) |         if method == 'merge':  # Merge NMS (boxes merged using weighted mean) | ||||||
|             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) |             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) | ||||||
|             if n < 3E3:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) |             if 1 < n < 3E3:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) | ||||||
|  |                 try: | ||||||
|                     # weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1)  # box weights |                     # weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1)  # box weights | ||||||
|                     # weights /= weights.sum(0)  # normalize |                     # weights /= weights.sum(0)  # normalize | ||||||
|                     # x[:, :4] = torch.mm(weights.T, x[:, :4]) |                     # x[:, :4] = torch.mm(weights.T, x[:, :4]) | ||||||
|                     weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None]  # box weights |                     weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None]  # box weights | ||||||
|                     x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True), x[:, :4]).float()  # merged boxes |                     x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True), x[:, :4]).float()  # merged boxes | ||||||
| 
 |                 except:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139 | ||||||
|  |                     pass | ||||||
|         elif method == 'vision': |         elif method == 'vision': | ||||||
|             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) |             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) | ||||||
|         elif method == 'fast':  # FastNMS from https://github.com/dbolya/yolact |         elif method == 'fast':  # FastNMS from https://github.com/dbolya/yolact | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue