Fixed NMS bug causing big CPU usage. Note that using 'cross_class_nms' still takes a huge amount of time and should be fixed somehow.
This commit is contained in:
		
							parent
							
								
									a46e500f9e
								
							
						
					
					
						commit
						d41f85702d
					
				| 
						 | 
				
			
			@ -285,8 +285,6 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
 | 
			
		||||
    prediction = prediction.cpu()
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    Removes detections with lower object confidence score than 'conf_thres' and performs
 | 
			
		||||
    Non-Maximum Suppression to further filter detections.
 | 
			
		||||
| 
						 | 
				
			
			@ -302,15 +300,17 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
 | 
			
		|||
        # cross-class NMS
 | 
			
		||||
        cross_class_nms = False
 | 
			
		||||
        if cross_class_nms:
 | 
			
		||||
            thresh = 0.85
 | 
			
		||||
            # thresh = 0.85
 | 
			
		||||
            thresh = nms_thres
 | 
			
		||||
            a = pred.clone()
 | 
			
		||||
            a = a[np.argsort(-a[:, 4])]  # sort best to worst
 | 
			
		||||
            _, indices = torch.sort(-a[:, 4], 0) # sort best to worst
 | 
			
		||||
            a = a[indices]
 | 
			
		||||
            radius = 30  # area to search for cross-class ious
 | 
			
		||||
            for i in range(len(a)):
 | 
			
		||||
                if i >= len(a) - 1:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                close = (np.abs(a[i, 0] - a[i + 1:, 0]) < radius) & (np.abs(a[i, 1] - a[i + 1:, 1]) < radius)
 | 
			
		||||
                close = (torch.abs(a[i, 0] - a[i + 1:, 0]) < radius) & (torch.abs(a[i, 1] - a[i + 1:, 1]) < radius)
 | 
			
		||||
                close = close.nonzero()
 | 
			
		||||
 | 
			
		||||
                if len(close) > 0:
 | 
			
		||||
| 
						 | 
				
			
			@ -324,10 +324,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
 | 
			
		|||
                        a = a[mask]
 | 
			
		||||
            pred = a
 | 
			
		||||
 | 
			
		||||
        x, y, w, h = pred[:, 0].numpy(), pred[:, 1].numpy(), pred[:, 2].numpy(), pred[:, 3].numpy()
 | 
			
		||||
        x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
 | 
			
		||||
        a = w * h  # area
 | 
			
		||||
        ar = w / (h + 1e-16)  # aspect ratio
 | 
			
		||||
        log_w, log_h, log_a, log_ar = np.log(w), np.log(h), np.log(a), np.log(ar)
 | 
			
		||||
 | 
			
		||||
        log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
 | 
			
		||||
 | 
			
		||||
        # n = len(w)
 | 
			
		||||
        # shape_likelihood = np.zeros((n, 60), dtype=np.float32)
 | 
			
		||||
| 
						 | 
				
			
			@ -338,8 +339,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
 | 
			
		|||
 | 
			
		||||
        class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
 | 
			
		||||
 | 
			
		||||
        v = ((pred[:, 4] > conf_thres) & (class_prob > .3)).numpy()
 | 
			
		||||
        v = v.nonzero()
 | 
			
		||||
        v = ((pred[:, 4] > conf_thres) & (class_prob > .3))
 | 
			
		||||
        v = v.nonzero().squeeze()
 | 
			
		||||
        if len(v.shape) == 0:
 | 
			
		||||
            v = v.unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
        pred = pred[v]
 | 
			
		||||
        class_prob = class_prob[v]
 | 
			
		||||
| 
						 | 
				
			
			@ -363,7 +366,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
 | 
			
		|||
        # Iterate through all predicted classes
 | 
			
		||||
        unique_labels = detections[:, -1].cpu().unique()
 | 
			
		||||
        if prediction.is_cuda:
 | 
			
		||||
            unique_labels = unique_labels.cuda()
 | 
			
		||||
            unique_labels = unique_labels.cuda(prediction.device)
 | 
			
		||||
 | 
			
		||||
        nms_style = 'OR'  # 'AND' or 'OR' (classical)
 | 
			
		||||
        for c in unique_labels:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue