updates
This commit is contained in:
		
							parent
							
								
									fd949a8eb3
								
							
						
					
					
						commit
						674d0de170
					
				|  | @ -501,8 +501,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | ||||||
|         # Box (center x, center y, width, height) to (x1, y1, x2, y2) |         # Box (center x, center y, width, height) to (x1, y1, x2, y2) | ||||||
|         pred[:, :4] = xywh2xyxy(pred[:, :4]) |         pred[:, :4] = xywh2xyxy(pred[:, :4]) | ||||||
| 
 | 
 | ||||||
|         # Detections ordered as (x1y1x2y2, conf, cls) |         # Expand | ||||||
|         pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) |         expand = False | ||||||
|  |         if expand: | ||||||
|  |             i, j = (pred[:, 4:] > conf_thres).nonzero().t() | ||||||
|  |             pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1)  # (x1y1x2y2, conf, cls) | ||||||
|  |         else: | ||||||
|  |             pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) | ||||||
| 
 | 
 | ||||||
|         # Get detections sorted by decreasing confidence scores |         # Get detections sorted by decreasing confidence scores | ||||||
|         pred = pred[(-pred[:, 4]).argsort()] |         pred = pred[(-pred[:, 4]).argsort()] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue