Merge NMS update
This commit is contained in:
		
							parent
							
								
									eac07f9da3
								
							
						
					
					
						commit
						171b4129b5
					
				| 
						 | 
				
			
			@ -495,78 +495,75 @@ def build_targets(model, targets):
 | 
			
		|||
 | 
			
		||||
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
 | 
			
		||||
    """
 | 
			
		||||
    Removes detections with lower object confidence score than 'conf_thres'
 | 
			
		||||
    Non-Maximum Suppression to further filter detections.
 | 
			
		||||
    Performs Non-Maximum Suppression on inference results
 | 
			
		||||
    Returns detections with shape:
 | 
			
		||||
        (x1, y1, x2, y2, object_conf, conf, class)
 | 
			
		||||
        nx6 (x1, y1, x2, y2, conf, cls)
 | 
			
		||||
    """
 | 
			
		||||
    # NMS methods https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch'
 | 
			
		||||
 | 
			
		||||
    # Box constraints
 | 
			
		||||
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
 | 
			
		||||
 | 
			
		||||
    method = 'vision'
 | 
			
		||||
    method = 'merge'
 | 
			
		||||
    nc = prediction[0].shape[1] - 5  # number of classes
 | 
			
		||||
    multi_label &= nc > 1  # multiple labels per box
 | 
			
		||||
    output = [None] * len(prediction)
 | 
			
		||||
    for image_i, pred in enumerate(prediction):
 | 
			
		||||
    for xi, x in enumerate(prediction):  # image index, image inference
 | 
			
		||||
        # Apply conf constraint
 | 
			
		||||
        pred = pred[pred[:, 4] > conf_thres]
 | 
			
		||||
        x = x[x[:, 4] > conf_thres]
 | 
			
		||||
 | 
			
		||||
        # Apply width-height constraint
 | 
			
		||||
        pred = pred[((pred[:, 2:4] > min_wh) & (pred[:, 2:4] < max_wh)).all(1)]
 | 
			
		||||
        x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)]
 | 
			
		||||
 | 
			
		||||
        # If none remain process next image
 | 
			
		||||
        if not pred.shape[0]:
 | 
			
		||||
        if not x.shape[0]:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # Compute conf
 | 
			
		||||
        pred[..., 5:] *= pred[..., 4:5]  # conf = obj_conf * cls_conf
 | 
			
		||||
        x[..., 5:] *= x[..., 4:5]  # conf = obj_conf * cls_conf
 | 
			
		||||
 | 
			
		||||
        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
 | 
			
		||||
        box = xywh2xyxy(pred[:, :4])
 | 
			
		||||
        box = xywh2xyxy(x[:, :4])
 | 
			
		||||
 | 
			
		||||
        # Detections matrix nx6 (xyxy, conf, cls)
 | 
			
		||||
        if multi_label:
 | 
			
		||||
            i, j = (pred[:, 5:] > conf_thres).nonzero().t()
 | 
			
		||||
            pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
 | 
			
		||||
            i, j = (x[:, 5:] > conf_thres).nonzero().t()
 | 
			
		||||
            x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
 | 
			
		||||
        else:  # best class only
 | 
			
		||||
            conf, j = pred[:, 5:].max(1)
 | 
			
		||||
            pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
 | 
			
		||||
            conf, j = x[:, 5:].max(1)
 | 
			
		||||
            x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
 | 
			
		||||
 | 
			
		||||
        # Filter by class
 | 
			
		||||
        if classes:
 | 
			
		||||
            pred = pred[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]
 | 
			
		||||
            x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]
 | 
			
		||||
 | 
			
		||||
        # Apply finite constraint
 | 
			
		||||
        if not torch.isfinite(pred).all():
 | 
			
		||||
            pred = pred[torch.isfinite(pred).all(1)]
 | 
			
		||||
        if not torch.isfinite(x).all():
 | 
			
		||||
            x = x[torch.isfinite(x).all(1)]
 | 
			
		||||
 | 
			
		||||
        # If none remain process next image
 | 
			
		||||
        if not pred.shape[0]:
 | 
			
		||||
        if not x.shape[0]:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # Sort by confidence
 | 
			
		||||
        # if method == 'fast_batch':
 | 
			
		||||
        #    pred = pred[pred[:, 4].argsort(descending=True)]
 | 
			
		||||
        #    x = x[x[:, 4].argsort(descending=True)]
 | 
			
		||||
 | 
			
		||||
        # Batched NMS
 | 
			
		||||
        c = pred[:, 5] * 0 if agnostic else pred[:, 5]  # classes
 | 
			
		||||
        boxes, scores = pred[:, :4].clone(), pred[:, 4]
 | 
			
		||||
        boxes += c.view(-1, 1) * max_wh  # offset boxes by class
 | 
			
		||||
        if method == 'vision':
 | 
			
		||||
            i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
 | 
			
		||||
        elif method == 'merge':  # Merge NMS (boxes merged using weighted mean)
 | 
			
		||||
        c = x[:, 5] * 0 if agnostic else x[:, 5]  # classes
 | 
			
		||||
        boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores
 | 
			
		||||
        if method == 'merge':  # Merge NMS (boxes merged using weighted mean)
 | 
			
		||||
            i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
 | 
			
		||||
            iou = box_iou(boxes, boxes[i]).tril_()  # lower triangular iou matrix
 | 
			
		||||
            weights = (iou > iou_thres) * scores.view(-1, 1)
 | 
			
		||||
            weights /= weights.sum(0)
 | 
			
		||||
            pred[i, :4] = torch.matmul(weights.T, pred[:, :4])  # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
 | 
			
		||||
            x[i, :4] = torch.mm(weights.T, x[:, :4])  # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
 | 
			
		||||
        elif method == 'vision':
 | 
			
		||||
            i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
 | 
			
		||||
        elif method == 'fast':  # FastNMS from https://github.com/dbolya/yolact
 | 
			
		||||
            iou = box_iou(boxes, boxes).triu_(diagonal=1)  # upper triangular iou matrix
 | 
			
		||||
            i = iou.max(0)[0] < iou_thres
 | 
			
		||||
 | 
			
		||||
        output[image_i] = pred[i]
 | 
			
		||||
        output[xi] = x[i]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_yolo_layers(model):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue