updates
This commit is contained in:
		
							parent
							
								
									fd949a8eb3
								
							
						
					
					
						commit
						674d0de170
					
				|  | @ -501,8 +501,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | |||
|         # Box (center x, center y, width, height) to (x1, y1, x2, y2) | ||||
|         pred[:, :4] = xywh2xyxy(pred[:, :4]) | ||||
| 
 | ||||
|         # Detections ordered as (x1y1x2y2, conf, cls) | ||||
|         pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) | ||||
|         # Expand | ||||
|         expand = False | ||||
|         if expand: | ||||
|             i, j = (pred[:, 4:] > conf_thres).nonzero().t() | ||||
|             pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1)  # (x1y1x2y2, conf, cls) | ||||
|         else: | ||||
|             pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) | ||||
| 
 | ||||
|         # Get detections sorted by decreasing confidence scores | ||||
|         pred = pred[(-pred[:, 4]).argsort()] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue