Merge NMS update
This commit is contained in:
		
							parent
							
								
									eac07f9da3
								
							
						
					
					
						commit
						171b4129b5
					
				|  | @ -495,78 +495,75 @@ def build_targets(model, targets): | |||
| 
 | ||||
| def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False): | ||||
|     """ | ||||
|     Removes detections with lower object confidence score than 'conf_thres' | ||||
|     Non-Maximum Suppression to further filter detections. | ||||
|     Performs  Non-Maximum Suppression on inference results | ||||
|     Returns detections with shape: | ||||
|         (x1, y1, x2, y2, object_conf, conf, class) | ||||
|         nx6 (x1, y1, x2, y2, conf, cls) | ||||
|     """ | ||||
|     # NMS methods https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch' | ||||
| 
 | ||||
|     # Box constraints | ||||
|     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height | ||||
| 
 | ||||
|     method = 'vision' | ||||
|     method = 'merge' | ||||
|     nc = prediction[0].shape[1] - 5  # number of classes | ||||
|     multi_label &= nc > 1  # multiple labels per box | ||||
|     output = [None] * len(prediction) | ||||
|     for image_i, pred in enumerate(prediction): | ||||
|     for xi, x in enumerate(prediction):  # image index, image inference | ||||
|         # Apply conf constraint | ||||
|         pred = pred[pred[:, 4] > conf_thres] | ||||
|         x = x[x[:, 4] > conf_thres] | ||||
| 
 | ||||
|         # Apply width-height constraint | ||||
|         pred = pred[((pred[:, 2:4] > min_wh) & (pred[:, 2:4] < max_wh)).all(1)] | ||||
|         x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] | ||||
| 
 | ||||
|         # If none remain process next image | ||||
|         if not pred.shape[0]: | ||||
|         if not x.shape[0]: | ||||
|             continue | ||||
| 
 | ||||
|         # Compute conf | ||||
|         pred[..., 5:] *= pred[..., 4:5]  # conf = obj_conf * cls_conf | ||||
|         x[..., 5:] *= x[..., 4:5]  # conf = obj_conf * cls_conf | ||||
| 
 | ||||
|         # Box (center x, center y, width, height) to (x1, y1, x2, y2) | ||||
|         box = xywh2xyxy(pred[:, :4]) | ||||
|         box = xywh2xyxy(x[:, :4]) | ||||
| 
 | ||||
|         # Detections matrix nx6 (xyxy, conf, cls) | ||||
|         if multi_label: | ||||
|             i, j = (pred[:, 5:] > conf_thres).nonzero().t() | ||||
|             pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) | ||||
|             i, j = (x[:, 5:] > conf_thres).nonzero().t() | ||||
|             x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) | ||||
|         else:  # best class only | ||||
|             conf, j = pred[:, 5:].max(1) | ||||
|             pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) | ||||
|             conf, j = x[:, 5:].max(1) | ||||
|             x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) | ||||
| 
 | ||||
|         # Filter by class | ||||
|         if classes: | ||||
|             pred = pred[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)] | ||||
|             x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)] | ||||
| 
 | ||||
|         # Apply finite constraint | ||||
|         if not torch.isfinite(pred).all(): | ||||
|             pred = pred[torch.isfinite(pred).all(1)] | ||||
|         if not torch.isfinite(x).all(): | ||||
|             x = x[torch.isfinite(x).all(1)] | ||||
| 
 | ||||
|         # If none remain process next image | ||||
|         if not pred.shape[0]: | ||||
|         if not x.shape[0]: | ||||
|             continue | ||||
| 
 | ||||
|         # Sort by confidence | ||||
|         # if method == 'fast_batch': | ||||
|         #    pred = pred[pred[:, 4].argsort(descending=True)] | ||||
|         #    x = x[x[:, 4].argsort(descending=True)] | ||||
| 
 | ||||
|         # Batched NMS | ||||
|         c = pred[:, 5] * 0 if agnostic else pred[:, 5]  # classes | ||||
|         boxes, scores = pred[:, :4].clone(), pred[:, 4] | ||||
|         boxes += c.view(-1, 1) * max_wh  # offset boxes by class | ||||
|         if method == 'vision': | ||||
|             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) | ||||
|         elif method == 'merge':  # Merge NMS (boxes merged using weighted mean) | ||||
|         c = x[:, 5] * 0 if agnostic else x[:, 5]  # classes | ||||
|         boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores | ||||
|         if method == 'merge':  # Merge NMS (boxes merged using weighted mean) | ||||
|             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) | ||||
|             iou = box_iou(boxes, boxes[i]).tril_()  # lower triangular iou matrix | ||||
|             weights = (iou > iou_thres) * scores.view(-1, 1) | ||||
|             weights /= weights.sum(0) | ||||
|             pred[i, :4] = torch.matmul(weights.T, pred[:, :4])  # merged_boxes(n,4) = weights(n,n) * boxes(n,4) | ||||
|             x[i, :4] = torch.mm(weights.T, x[:, :4])  # merged_boxes(n,4) = weights(n,n) * boxes(n,4) | ||||
|         elif method == 'vision': | ||||
|             i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) | ||||
|         elif method == 'fast':  # FastNMS from https://github.com/dbolya/yolact | ||||
|             iou = box_iou(boxes, boxes).triu_(diagonal=1)  # upper triangular iou matrix | ||||
|             i = iou.max(0)[0] < iou_thres | ||||
| 
 | ||||
|         output[image_i] = pred[i] | ||||
|         output[xi] = x[i] | ||||
| 
 | ||||
| 
 | ||||
| def get_yolo_layers(model): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue