updates
This commit is contained in:
		
							parent
							
								
									707d6ea965
								
							
						
					
					
						commit
						cb63ce30ec
					
				
							
								
								
									
										27
									
								
								test.py
								
								
								
								
							
							
						
						
									
										27
									
								
								test.py
								
								
								
								
							|  | @ -37,18 +37,15 @@ def test( | ||||||
|     model.to(device).eval() |     model.to(device).eval() | ||||||
| 
 | 
 | ||||||
|     # Get dataloader |     # Get dataloader | ||||||
|     # dataloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path), batch_size=batch_size)  # pytorch |     # dataloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path), batch_size=batch_size) | ||||||
|     dataloader = LoadImagesAndLabels(test_path, batch_size=batch_size, img_size=img_size) |     dataloader = LoadImagesAndLabels(test_path, batch_size=batch_size, img_size=img_size) | ||||||
| 
 | 
 | ||||||
|     # Create JSON |  | ||||||
|     jdict = [] |  | ||||||
|     float3 = lambda x: float(format(x, '.3f'))  # print json to 3 decimals |  | ||||||
|     # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... |  | ||||||
| 
 |  | ||||||
|     mean_mAP, mean_R, mean_P, seen = 0.0, 0.0, 0.0, 0 |     mean_mAP, mean_R, mean_P, seen = 0.0, 0.0, 0.0, 0 | ||||||
|     print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP')) |     print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP')) | ||||||
|     outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], [] |     outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class, jdict = \ | ||||||
|  |         [], [], [], [], [], [], [], [], [] | ||||||
|     AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC) |     AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC) | ||||||
|  |     coco91class = coco80_to_coco91_class() | ||||||
|     for batch_i, (imgs, targets, paths, shapes) in enumerate(dataloader): |     for batch_i, (imgs, targets, paths, shapes) in enumerate(dataloader): | ||||||
|         output = model(imgs.to(device)) |         output = model(imgs.to(device)) | ||||||
|         output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres) |         output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres) | ||||||
|  | @ -67,18 +64,18 @@ def test( | ||||||
|             detections = detections.cpu().numpy() |             detections = detections.cpu().numpy() | ||||||
|             detections = detections[np.argsort(-detections[:, 4])] |             detections = detections[np.argsort(-detections[:, 4])] | ||||||
| 
 | 
 | ||||||
|             # Save JSON |  | ||||||
|             if save_json: |             if save_json: | ||||||
|                 # rescale box to original image size, top left origin |                 # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... | ||||||
|                 box = torch.from_numpy(detections[:, :4]).clone()  # x1y1x2y2 |                 box = torch.from_numpy(detections[:, :4]).clone()  # xyxy | ||||||
|                 scale_coords(img_size, box, shapes[si]) |                 scale_coords(img_size, box, shapes[si])  # to original shape | ||||||
|                 box = xyxy2xywh(box) |                 box = xyxy2xywh(box)  # xywh | ||||||
|                 box[:, :2] -= box[:, 2:] / 2  # origin center to corner |                 box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner | ||||||
| 
 | 
 | ||||||
|  |                 # add to json dictionary | ||||||
|                 for di, d in enumerate(detections): |                 for di, d in enumerate(detections): | ||||||
|                     jdict.append({  # add to json dictionary |                     jdict.append({ | ||||||
|                         'image_id': int(Path(paths[si]).stem.split('_')[-1]), |                         'image_id': int(Path(paths[si]).stem.split('_')[-1]), | ||||||
|                         'category_id': darknet2coco_class(int(d[6])), |                         'category_id': coco91class(int(d[6])), | ||||||
|                         'bbox': [float3(x) for x in box[di]], |                         'bbox': [float3(x) for x in box[di]], | ||||||
|                         'score': float3(d[4] * d[5]) |                         'score': float3(d[4] * d[5]) | ||||||
|                     }) |                     }) | ||||||
|  |  | ||||||
|  | @ -12,6 +12,10 @@ torch.set_printoptions(linewidth=1320, precision=5, profile='long') | ||||||
| np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5 | np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def float3(x):  # format floats to 3 decimals | ||||||
|  |     return float(format(x, '.3f')) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def init_seeds(seed=0): | def init_seeds(seed=0): | ||||||
|     random.seed(seed) |     random.seed(seed) | ||||||
|     np.random.seed(seed) |     np.random.seed(seed) | ||||||
|  | @ -49,12 +53,12 @@ def coco_class_weights():  # frequency of each class in coco train2014 | ||||||
|     return weights |     return weights | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def darknet2coco_class(c):  # returns the coco class for each darknet class | def coco80_to_coco91_class():  # returns the coco class for each darknet class | ||||||
|     # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ |     # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ | ||||||
|     a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n') |     a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n') | ||||||
|     b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n') |     b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n') | ||||||
|     x = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco |     x = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco | ||||||
|     return x[c] |     return x | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def plot_one_box(x, img, color=None, label=None, line_thickness=None):  # Plots one bounding box on image img | def plot_one_box(x, img, color=None, label=None, line_thickness=None):  # Plots one bounding box on image img | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue