move inference augmentation to model.forward()
parent 4da5c6c114
commit c6d4e80335

detect.py
@@ -84,9 +84,13 @@ def detect(save_img=False):
 
         # Inference
         t1 = torch_utils.time_synchronized()
-        pred = model(img)[0].float() if half else model(img)[0]
+        pred = model(img, augment=opt.augment)[0]
         t2 = torch_utils.time_synchronized()
 
+        # to float
+        if half:
+            pred = pred.float()
+
         # Apply NMS
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                                    multi_label=False, classes=opt.classes, agnostic=opt.agnostic_nms)
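
Note on the new "# to float" block: the FP16 cast that used to be folded into the old one-liner now runs as its own step, so it also covers the augmented path. A minimal stand-alone sketch of the dtype handling with a dummy tensor (no model involved; the float32 requirement of the downstream box maths is an assumption, not stated in this diff):

import torch

half = True                        # detect.py sets this elsewhere for FP16 inference
pred = torch.randn(1, 100, 85)     # dummy (batch, boxes, 4 + 1 + classes) predictions
if half:
    pred = pred.half()             # what a half-precision model would return

# to float, mirroring the added block, before NMS and box post-processing
if half:
    pred = pred.float()
assert pred.dtype == torch.float32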
@@ -173,6 +177,7 @@ if __name__ == '__main__':
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
     parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
     opt = parser.parse_args()
     print(opt)
 
models.py (25 changed lines)
@@ -226,13 +226,21 @@ class Darknet(nn.Module):
         self.seen = np.array([0], dtype=np.int64)  # (int64) number of images seen during training
         self.info(verbose)  # print model description
 
-    def forward(self, x, verbose=False):
-        img_size = x.shape[-2:]
+    def forward(self, x, augment=False, verbose=False):
+        img_size = x.shape[-2:]  # height, width
         yolo_out, out = [], []
         if verbose:
             print('0', x.shape)
             str = ''
 
+        # Augment images (inference and test only)
+        if augment:  # https://github.com/ultralytics/yolov3/issues/931
+            nb = x.shape[0]  # batch size
+            x = torch.cat((x,
+                           torch_utils.scale_img(x.flip(3), 0.9),  # flip-lr and scale
+                           torch_utils.scale_img(x, 0.7),  # scale
+                           ), 0)
+
         for i, module in enumerate(self.module_list):
             name = module.__class__.__name__
             if name in ['WeightedFeatureFusion', 'FeatureConcat']:  # sum, concat
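
The augmentation leans on torch_utils.scale_img, which is not part of this diff. A minimal sketch of what it is assumed to do so the three copies can share one batch: resize by the ratio, then pad back to the original spatial size (the interpolation mode and the 0.447 pad value are assumptions, not taken from this commit):

import torch
import torch.nn.functional as F

def scale_img_sketch(img, ratio=1.0):
    # img: (batch, 3, h, w) -> ratio-scaled copy padded back out to (h, w)
    h, w = img.shape[2:]
    new_h, new_w = int(h * ratio), int(w * ratio)
    img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)
    # pad right/bottom so x, the 0.9 copy and the 0.7 copy can be cat'ed along dim 0
    return F.pad(img, [0, w - new_w, 0, h - new_h], value=0.447)

Because the flipped and scaled copies are stacked along the batch dimension, all three views go through the same network weights in a single forward pass.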
@@ -256,9 +264,16 @@ class Darknet(nn.Module):
         elif ONNX_EXPORT:  # export
             x = [torch.cat(x, 0) for x in zip(*yolo_out)]
             return x[0], torch.cat(x[1:3], 1)  # scores, boxes: 3780x80, 3780x4
-        else:  # test
-            io, p = zip(*yolo_out)  # inference output, training output
-            return torch.cat(io, 1), p
+        else:  # inference or test
+            x, p = zip(*yolo_out)  # inference output, training output
+            x = torch.cat(x, 1)  # cat yolo outputs
+            if augment:  # de-augment results
+                x = torch.split(x, nb, dim=0)
+                x[1][..., :4] /= 0.9  # scale
+                x[1][..., 0] = img_size[1] - x[1][..., 0]  # flip lr
+                x[2][..., :4] /= 0.7  # scale
+                x = torch.cat(x, 1)
+            return x, p
 
     def fuse(self):
         # Fuse Conv2d + BatchNorm2d layers throughout model
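
The de-augment branch inverts exactly the two transforms applied at the top of forward(): box coordinates from the scaled copies are divided back up, and only the x-centre of the flip-lr copy is mirrored about the image width. A toy roundtrip check with dummy numbers (not repository code):

import torch

width = 416.0                          # img_size[1] in forward()
x_orig = torch.tensor(100.0)           # x-centre of a box in the original image
x_aug = (width - x_orig) * 0.9         # roughly what the flip-lr + 0.9-scale copy predicts
x_back = width - (x_aug / 0.9)         # de-augment: undo the scale, then mirror back
assert torch.isclose(x_back, x_orig)   # original coordinate recovered

Since nb (the pre-augmentation batch size) is now captured inside the augment branch, the caller no longer has to track it, which is what lets the bookkeeping in test.py below be removed.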
test.py (19 changed lines)
@@ -88,26 +88,11 @@ def test(cfg,
 
         # Disable gradients
         with torch.no_grad():
-            # Augment images
-            if augment:  # https://github.com/ultralytics/yolov3/issues/931
-                imgs = torch.cat((imgs,
-                                  torch_utils.scale_img(imgs.flip(3), 0.9),  # flip-lr and scale
-                                  torch_utils.scale_img(imgs, 0.7),  # scale
-                                  ), 0)
-
             # Run model
             t = torch_utils.time_synchronized()
-            inf_out, train_out = model(imgs)  # inference and training outputs
+            inf_out, train_out = model(imgs, augment=augment)  # inference and training outputs
             t0 += torch_utils.time_synchronized() - t
 
-            # De-augment results
-            if augment:
-                x = torch.split(inf_out, nb, dim=0)
-                x[1][..., :4] /= 0.9  # scale
-                x[1][..., 0] = width - x[1][..., 0]  # flip lr
-                x[2][..., :4] /= 0.7  # scale
-                inf_out = torch.cat(x, 1)
-
             # Compute loss
             if hasattr(model, 'hyp'):  # if model has loss hyperparameters
                 loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls
@@ -250,7 +235,7 @@ if __name__ == '__main__':
     parser.add_argument('--task', default='test', help="'test', 'study', 'benchmark'")
     parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
-    parser.add_argument('--augment', action='store_true', help='augmented testing')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
     opt = parser.parse_args()
     opt.save_json = opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']])
     print(opt)
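
With the augmentation living inside Darknet.forward(), any caller gets it with the single keyword. A minimal usage sketch (cfg path and input size are placeholders; randomly initialised weights, eval mode so forward() returns the inference tuple):

import torch
from models import Darknet          # this repository's model class

model = Darknet('cfg/yolov3.cfg')   # placeholder cfg
model.eval()
img = torch.zeros(1, 3, 416, 416)   # dummy input batch
with torch.no_grad():
    inf_out, train_out = model(img, augment=True)   # roughly 3x the boxes of augment=False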