updates
This commit is contained in:
		
							parent
							
								
									f788a57009
								
							
						
					
					
						commit
						2ef92f5651
					
				|  | @ -369,12 +369,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): | ||||||
|         if prediction.is_cuda: |         if prediction.is_cuda: | ||||||
|             unique_labels = unique_labels.cuda(prediction.device) |             unique_labels = unique_labels.cuda(prediction.device) | ||||||
| 
 | 
 | ||||||
|         nms_style = 'MERGE'  # 'OR' (default), 'AND', 'MERGE' (experimental) |         nms_style = 'OR'  # 'OR' (default), 'AND', 'MERGE' (experimental) | ||||||
|         for c in unique_labels: |         for c in unique_labels: | ||||||
|             # Get the detections with class c |             # Get the detections with class c | ||||||
|             dc = detections[detections[:, -1] == c] |             dc = detections[detections[:, -1] == c] | ||||||
|             # Sort the detections by maximum object confidence |             # Sort the detections by maximum object confidence | ||||||
|             _, conf_sort_index = torch.sort(dc[:, 4], descending=True) |             _, conf_sort_index = torch.sort(dc[:, 4] * dc[:, 5], descending=True) | ||||||
|             dc = dc[conf_sort_index] |             dc = dc[conf_sort_index] | ||||||
| 
 | 
 | ||||||
|             # Non-maximum suppression |             # Non-maximum suppression | ||||||
|  | @ -411,6 +411,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): | ||||||
|                 #  4964       5000      0.632      0.597      0.588  # normal |                 #  4964       5000      0.632      0.597      0.588  # normal | ||||||
|                 #  4964       5000      0.632      0.597      0.588  # squared |                 #  4964       5000      0.632      0.597      0.588  # squared | ||||||
|                 #  4964       5000      0.631      0.597      0.588  # sqrt |                 #  4964       5000      0.631      0.597      0.588  # sqrt | ||||||
|  |                 # normal best_v1_0.pt | ||||||
| 
 | 
 | ||||||
|             if len(det_max) > 0: |             if len(det_max) > 0: | ||||||
|                 det_max = torch.cat(det_max) |                 det_max = torch.cat(det_max) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue