diff --git a/detect.py b/detect.py
index f2501970..25c71600 100644
--- a/detect.py
+++ b/detect.py
@@ -84,9 +84,13 @@ def detect(save_img=False):
 
         # Inference
         t1 = torch_utils.time_synchronized()
-        pred = model(img)[0].float() if half else model(img)[0]
+        pred = model(img, augment=opt.augment)[0]
         t2 = torch_utils.time_synchronized()
 
+        # to float
+        if half:
+            pred = pred.float()
+
         # Apply NMS
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                                    multi_label=False, classes=opt.classes, agnostic=opt.agnostic_nms)
@@ -173,6 +177,7 @@ if __name__ == '__main__':
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
     parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
     opt = parser.parse_args()
     print(opt)
 
diff --git a/models.py b/models.py
index 557fe187..c4a29aee 100755
--- a/models.py
+++ b/models.py
@@ -226,13 +226,21 @@ class Darknet(nn.Module):
         self.seen = np.array([0], dtype=np.int64)  # (int64) number of images seen during training
         self.info(verbose)  # print model description
 
-    def forward(self, x, verbose=False):
-        img_size = x.shape[-2:]
+    def forward(self, x, augment=False, verbose=False):
+        img_size = x.shape[-2:]  # height, width
         yolo_out, out = [], []
         if verbose:
             print('0', x.shape)
             str = ''
 
+        # Augment images (inference and test only)
+        if augment:  # https://github.com/ultralytics/yolov3/issues/931
+            nb = x.shape[0]  # batch size
+            x = torch.cat((x,
+                           torch_utils.scale_img(x.flip(3), 0.9),  # flip-lr and scale
+                           torch_utils.scale_img(x, 0.7),  # scale
+                           ), 0)
+
         for i, module in enumerate(self.module_list):
             name = module.__class__.__name__
             if name in ['WeightedFeatureFusion', 'FeatureConcat']:  # sum, concat
@@ -256,9 +264,16 @@ class Darknet(nn.Module):
         elif ONNX_EXPORT:  # export
             x = [torch.cat(x, 0) for x in zip(*yolo_out)]
             return x[0], torch.cat(x[1:3], 1)  # scores, boxes: 3780x80, 3780x4
-        else:  # test
-            io, p = zip(*yolo_out)  # inference output, training output
-            return torch.cat(io, 1), p
+        else:  # inference or test
+            x, p = zip(*yolo_out)  # inference output, training output
+            x = torch.cat(x, 1)  # cat yolo outputs
+            if augment:  # de-augment results
+                x = torch.split(x, nb, dim=0)
+                x[1][..., :4] /= 0.9  # scale
+                x[1][..., 0] = img_size[1] - x[1][..., 0]  # flip lr
+                x[2][..., :4] /= 0.7  # scale
+                x = torch.cat(x, 1)
+            return x, p
 
     def fuse(self):
         # Fuse Conv2d + BatchNorm2d layers throughout model
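The test-time augmentation above runs three views of each image through a single forward pass: the original, a left-right flip at 0.9 scale, and a 0.7-scale copy. The outputs are then split back into nb-sized groups, each transform is undone in box space, and the three prediction sets are concatenated along dim 1 so NMS merges them. torch_utils.scale_img is not part of this patch; a minimal sketch of its assumed behavior (resize by the ratio, keep content anchored top-left, pad back to the input shape so the views stack along the batch dimension) follows:

    import torch.nn.functional as F

    def scale_img(img, ratio=1.0):
        # Sketch only -- assumed behavior; the real torch_utils.scale_img may
        # differ (padding value, grid-size alignment). Resizes a (bs, 3, h, w)
        # batch by `ratio` and zero-pads right/bottom back to (h, w) so the
        # result can be concatenated with the un-scaled batch along dim 0.
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # scaled size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)
        return F.pad(img, [0, w - s[1], 0, h - s[0]])  # pad right and bottom

Because the content stays anchored at the top-left, de-augmenting a scaled view only needs xywh /= ratio; the flipped view is then mirrored with x = img_size[1] - x. For example, an x-center of 100 predicted in the flip-lr 0.9-scale view of a 416-wide image maps back to 416 - 100 / 0.9 ≈ 304.9.
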
diff --git a/test.py b/test.py
index e41f998f..9e03c051 100644
--- a/test.py
+++ b/test.py
@@ -88,26 +88,11 @@ def test(cfg,
 
         # Disable gradients
         with torch.no_grad():
-            # Augment images
-            if augment:  # https://github.com/ultralytics/yolov3/issues/931
-                imgs = torch.cat((imgs,
-                                  torch_utils.scale_img(imgs.flip(3), 0.9),  # flip-lr and scale
-                                  torch_utils.scale_img(imgs, 0.7),  # scale
-                                  ), 0)
-
             # Run model
             t = torch_utils.time_synchronized()
-            inf_out, train_out = model(imgs)  # inference and training outputs
+            inf_out, train_out = model(imgs, augment=augment)  # inference and training outputs
             t0 += torch_utils.time_synchronized() - t
 
-            # De-augment results
-            if augment:
-                x = torch.split(inf_out, nb, dim=0)
-                x[1][..., :4] /= 0.9  # scale
-                x[1][..., 0] = width - x[1][..., 0]  # flip lr
-                x[2][..., :4] /= 0.7  # scale
-                inf_out = torch.cat(x, 1)
-
             # Compute loss
             if hasattr(model, 'hyp'):  # if model has loss hyperparameters
                 loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls
@@ -250,7 +235,7 @@ if __name__ == '__main__':
     parser.add_argument('--task', default='test', help="'test', 'study', 'benchmark'")
    parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
-    parser.add_argument('--augment', action='store_true', help='augmented testing')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
     opt = parser.parse_args()
     opt.save_json = opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']])
     print(opt)
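
With --augment wired into both entry points, TTA is opt-in from the CLI (the cfg/weights/data paths below are placeholders for whatever model is being evaluated):

    python3 detect.py --cfg cfg/yolov3-spp.cfg --weights yolov3-spp.weights --augment
    python3 test.py --cfg cfg/yolov3-spp.cfg --data data/coco2017.data --weights yolov3-spp.weights --augment

or directly in code via pred = model(img, augment=True)[0]. Since all three views are padded back to full resolution, expect roughly 3x the per-image inference cost.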