cleanup for #1114
This commit is contained in:
		
							parent
							
								
									fb1b5e09b2
								
							
						
					
					
						commit
						0ffbf5534e
					
				
							
								
								
									
										6
									
								
								test.py
								
								
								
								
							
							
						
						
									
										6
									
								
								test.py
								
								
								
								
							|  | @ -165,9 +165,9 @@ def test(cfg, | ||||||
|         # Plot images |         # Plot images | ||||||
|         if batch_i < 1: |         if batch_i < 1: | ||||||
|             f = 'test_batch%g_gt.jpg' % batch_i  # filename |             f = 'test_batch%g_gt.jpg' % batch_i  # filename | ||||||
|             plot_images(images=imgs, targets=targets, paths=paths, names=names, fname=f)  # ground truth |             plot_images(imgs, targets, paths=paths, names=names, fname=f)  # ground truth | ||||||
|             f = 'test_batch%g_pred.jpg' % batch_i  # filename |             f = 'test_batch%g_pred.jpg' % batch_i | ||||||
|             plot_images(images=imgs, targets=output_to_target(output, width, height), paths=paths, names=names, fname=f)  # predictions |             plot_images(imgs, output_to_target(output, width, height), paths=paths, names=names, fname=f)  # predictions | ||||||
| 
 | 
 | ||||||
|     # Compute statistics |     # Compute statistics | ||||||
|     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy |     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy | ||||||
|  |  | ||||||
							
								
								
									
										123
									
								
								utils/utils.py
								
								
								
								
							
							
						
						
									
										123
									
								
								utils/utils.py
								
								
								
								
							|  | @ -836,7 +836,7 @@ def output_to_target(output, width, height): | ||||||
|     [batch_id, class_id, x, y, w, h, conf] |     [batch_id, class_id, x, y, w, h, conf] | ||||||
| 
 | 
 | ||||||
|     """ |     """ | ||||||
|      | 
 | ||||||
|     if isinstance(output, torch.Tensor): |     if isinstance(output, torch.Tensor): | ||||||
|         output = output.cpu().numpy() |         output = output.cpu().numpy() | ||||||
| 
 | 
 | ||||||
|  | @ -846,10 +846,10 @@ def output_to_target(output, width, height): | ||||||
|         if o is not None: |         if o is not None: | ||||||
|             for pred in o: |             for pred in o: | ||||||
|                 box = pred[:4] |                 box = pred[:4] | ||||||
|                 w = (box[2]-box[0])/width |                 w = (box[2] - box[0]) / width | ||||||
|                 h = (box[3]-box[1])/height |                 h = (box[3] - box[1]) / height | ||||||
|                 x = box[0]/width + w/2 |                 x = box[0] / width + w / 2 | ||||||
|                 y = box[1]/height + h/2 |                 y = box[1] / height + h / 2 | ||||||
|                 conf = pred[4] |                 conf = pred[4] | ||||||
|                 cls = int(pred[5]) |                 cls = int(pred[5]) | ||||||
| 
 | 
 | ||||||
|  | @ -893,111 +893,80 @@ def plot_wh_methods():  # from utils.utils import *; plot_wh_methods() | ||||||
|     fig.savefig('comparison.png', dpi=200) |     fig.savefig('comparison.png', dpi=200) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def plot_images(images, targets, paths=None, fname='images.jpg', names=None, class_labels=True, confidence_labels=True, max_size=640, max_subplots=16): | def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16): | ||||||
|  |     tl = 3  # line thickness | ||||||
|  |     tf = max(tl - 1, 1)  # font thickness | ||||||
| 
 | 
 | ||||||
|     if isinstance(images, torch.Tensor): |     if isinstance(images, torch.Tensor): | ||||||
|         images = images.cpu().numpy() |         images = images.cpu().numpy() | ||||||
|      | 
 | ||||||
|     if isinstance(targets, torch.Tensor): |     if isinstance(targets, torch.Tensor): | ||||||
|         targets = targets.cpu().numpy() |         targets = targets.cpu().numpy() | ||||||
|      | 
 | ||||||
|     # un-normalise |     # un-normalise | ||||||
|     if np.max(images[0]) <= 1: |     if np.max(images[0]) <= 1: | ||||||
|         images *= 255 |         images *= 255 | ||||||
|      | 
 | ||||||
|     bs, _, h, w = images.shape  # batch size, _, height, width |     bs, _, h, w = images.shape  # batch size, _, height, width | ||||||
|     bs = min(bs, max_subplots)  # limit plot images |     bs = min(bs, max_subplots)  # limit plot images | ||||||
|     ns = np.ceil(bs ** 0.5)  # number of subplots (square) |     ns = np.ceil(bs ** 0.5)  # number of subplots (square) | ||||||
|      | 
 | ||||||
|     # Check if we should resize |     # Check if we should resize | ||||||
|     should_resize = False |     scale_factor = max_size / max(h, w) | ||||||
|     if w > max_size or h > max_size: |     if scale_factor < 1: | ||||||
|         scale_factor = max_size/max(h, w) |         h = math.ceil(scale_factor * h) | ||||||
|         h = math.ceil(scale_factor*h) |         w = math.ceil(scale_factor * w) | ||||||
|         w = math.ceil(scale_factor*w) | 
 | ||||||
|         should_resize=True |  | ||||||
|          |  | ||||||
|     # Empty array for output |     # Empty array for output | ||||||
|     mosaic_width = int(ns*w) |     mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) | ||||||
|     mosaic_height = int(ns*h) | 
 | ||||||
|     mosaic = 255*np.ones((mosaic_height, mosaic_width, 3), dtype=np.uint8) |  | ||||||
|      |  | ||||||
|     # Fix class - colour map |     # Fix class - colour map | ||||||
|     prop_cycle = plt.rcParams['axes.prop_cycle'] |     prop_cycle = plt.rcParams['axes.prop_cycle'] | ||||||
|     # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb |     # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb | ||||||
|     hex2rgb = lambda h : tuple(int(h[1+i:1+i+2], 16) for i in (0, 2, 4))     |     hex2rgb = lambda h: tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) | ||||||
|     color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']] |     color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']] | ||||||
| 
 | 
 | ||||||
|     for i, image in enumerate(images): |     for i, img in enumerate(images): | ||||||
|          |         if i == max_subplots:  # if last batch has fewer images than we expect | ||||||
|         # e.g. if the last batch has fewer images than we expect |  | ||||||
|         if i == max_subplots: |  | ||||||
|             break |             break | ||||||
|              | 
 | ||||||
|         block_x = int(w * (i // ns)) |         block_x = int(w * (i // ns)) | ||||||
|         block_y = int(h * (i % ns)) |         block_y = int(h * (i % ns)) | ||||||
|          | 
 | ||||||
|         image = image.transpose(1,2,0) |         img = img.transpose(1, 2, 0) | ||||||
|          |         if scale_factor < 1: | ||||||
|         if should_resize: |             img = cv2.resize(img, (w, h)) | ||||||
|             image = cv2.resize(image, (w, h)) | 
 | ||||||
|          |         mosaic[block_y:block_y + h, block_x:block_x + w, :] = img | ||||||
|         mosaic[block_y:block_y+h, block_x:block_x+w,:] = image |  | ||||||
|          |  | ||||||
|         if targets is not None: |         if targets is not None: | ||||||
|             image_targets = targets[targets[:, 0] == i] |             image_targets = targets[targets[:, 0] == i] | ||||||
|             boxes = xywh2xyxy(image_targets[:,2:6]).T |             boxes = xywh2xyxy(image_targets[:, 2:6]).T | ||||||
|             classes = image_targets[:,1].astype('int') |             classes = image_targets[:, 1].astype('int') | ||||||
|              |             gt = image_targets.shape[1] == 6  # ground truth if no conf column | ||||||
|             # Check if we have object confidences (gt vs pred) |             conf = None if gt else image_targets[:, 6]  # check for confidence presence (gt vs pred) | ||||||
|             confidences = None | 
 | ||||||
|             if image_targets.shape[1] > 6: |  | ||||||
|                 confidences = image_targets[:,6] |  | ||||||
|                          |  | ||||||
|             boxes[[0, 2]] *= w |             boxes[[0, 2]] *= w | ||||||
|             boxes[[0, 2]] += block_x |             boxes[[0, 2]] += block_x | ||||||
|          |  | ||||||
|             boxes[[1, 3]] *= h |             boxes[[1, 3]] *= h | ||||||
|             boxes[[1, 3]] += block_y |             boxes[[1, 3]] += block_y | ||||||
|              |  | ||||||
|             for j, box in enumerate(boxes.T): |             for j, box in enumerate(boxes.T): | ||||||
|                 color = color_lut[int(classes[j]) % len(color_lut)] |                 cls = int(classes[j]) | ||||||
|                 box = box.astype(int) |                 color = color_lut[cls % len(color_lut)] | ||||||
|                 cv2.rectangle(mosaic, (box[0], box[1]), (box[2], box[3]), color, thickness=2) |                 cls = names[cls] if names else cls | ||||||
|                  |                 if gt or conf[j] > 0.3:  # 0.3 conf thresh | ||||||
|                 # Draw class label |                     label = '%s' % cls if gt else '%s %.1f' % (cls, conf[j]) | ||||||
|                 if class_labels and max_size > 250: |                     plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl) | ||||||
|                     label = str(classes[j]) if names is None else names[classes[j]] |  | ||||||
|                     if confidences is not None and confidence_labels: |  | ||||||
|                         label += " {:1.2f}".format(confidences[j]) |  | ||||||
|                      |  | ||||||
|                     font_scale = 0.4/10 * min(20, h * 0.05) |  | ||||||
|                     font_thickness = 2 if max(w, h) > 320 else 1 |  | ||||||
| 
 | 
 | ||||||
|                     label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness) |  | ||||||
|                     cv2.rectangle(mosaic, (box[0], box[1]), (box[0]+label_size[0], box[1]-label_size[1]), color, thickness=-1) |  | ||||||
|                     cv2.putText(mosaic, label, (box[0], box[1]), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=font_thickness, color=(255,255,255)) |  | ||||||
|          |  | ||||||
|         # Draw image filename labels |         # Draw image filename labels | ||||||
|         if paths is not None: |         if paths is not None: | ||||||
|             # Trim to 40 chars |             label = os.path.basename(paths[i])[:40]  # trim to 40 char | ||||||
|             label = os.path.basename(paths[i])[:40] |             t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] | ||||||
|  |             cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf, | ||||||
|  |                         lineType=cv2.LINE_AA) | ||||||
| 
 | 
 | ||||||
|             # Empirical calculation to fit label |  | ||||||
|             # 0.4 is at most (13, 10) px per char at thickness = 1 |  | ||||||
|             # Fit label to 20px high, or shrink if it would be too big |  | ||||||
|             max_font_scale = (w/len(label))*(0.4/8) |  | ||||||
|             font_scale = min(0.4 * 20/8.5, max_font_scale) |  | ||||||
|             font_thickness = 1 |  | ||||||
|                  |  | ||||||
|             label_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, font_scale, font_thickness) |  | ||||||
|              |  | ||||||
|             cv2.rectangle(mosaic, (block_x+5, block_y+label_size[1]+baseline+5), (block_x+label_size[0]+5, block_y), 0, thickness=-1) |  | ||||||
|             cv2.putText(mosaic, label, (block_x+5, block_y+label_size[1]+5), cv2.FONT_HERSHEY_DUPLEX, font_scale, (255,255,255), font_thickness) |  | ||||||
|              |  | ||||||
|         # Image border |         # Image border | ||||||
|         cv2.rectangle(mosaic, (block_x, block_y), (block_x+w, block_y+h), (255,255,255), thickness=3) |         cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3) | ||||||
|          | 
 | ||||||
|     if fname is not None: |     if fname is not None: | ||||||
|         cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) |         cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue