updates
This commit is contained in:
		
							parent
							
								
									06e88fec08
								
							
						
					
					
						commit
						db26b08f5b
					
				|  | @ -491,13 +491,9 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru | |||
| 
 | ||||
|     output = [None] * len(prediction) | ||||
|     for image_i, pred in enumerate(prediction): | ||||
|         # Retain > conf | ||||
|         # Apply conf constraint | ||||
|         pred = pred[pred[:, 4] > conf_thres] | ||||
| 
 | ||||
|         # Compute conf | ||||
|         torch.sigmoid_(pred[..., 5:]) | ||||
|         pred[..., 5:] *= pred[..., 4:5]  # conf = obj_conf * cls_conf | ||||
| 
 | ||||
|         # Apply width-height constraint | ||||
|         pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)] | ||||
| 
 | ||||
|  | @ -505,6 +501,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru | |||
|         if len(pred) == 0: | ||||
|             continue | ||||
| 
 | ||||
|         # Compute conf | ||||
|         torch.sigmoid_(pred[..., 5:]) | ||||
|         pred[..., 5:] *= pred[..., 4:5]  # conf = obj_conf * cls_conf | ||||
| 
 | ||||
|         # Box (center x, center y, width, height) to (x1, y1, x2, y2) | ||||
|         box = xywh2xyxy(pred[:, :4]) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue