From 07d2f0ad03d1c6a0dcff1737c320e983c1d5a01f Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 15 Mar 2020 18:39:54 -0700
Subject: [PATCH] test/inference time augmentation

---
 test.py              | 15 ++++++++++++++-
 utils/torch_utils.py | 11 +++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/test.py b/test.py
index b7d33d37..d9439f6b 100644
--- a/test.py
+++ b/test.py
@@ -73,7 +73,7 @@ def test(cfg,
     for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
         imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
         targets = targets.to(device)
-        _, _, height, width = imgs.shape  # batch size, channels, height, width
+        nb, _, height, width = imgs.shape  # batch size, channels, height, width
         whwh = torch.Tensor([width, height, width, height]).to(device)
 
         # Plot images with bounding boxes
@@ -83,11 +83,24 @@ def test(cfg,
 
         # Disable gradients
         with torch.no_grad():
+            aug = False  # augment https://github.com/ultralytics/yolov3/issues/931
+            if aug:
+                imgs = torch.cat((imgs,
+                                  imgs.flip(3),  # flip-lr
+                                  torch_utils.scale_img(imgs, 0.7),  # scale
+                                  ), 0)
+
             # Run model
             t = torch_utils.time_synchronized()
             inf_out, train_out = model(imgs)  # inference and training outputs
             t0 += torch_utils.time_synchronized() - t
 
+            if aug:
+                x = torch.split(inf_out, nb, dim=0)  # de-augment: split 3*nb predictions back into groups
+                x[1][..., 0] = width - x[1][..., 0]  # flip lr: mirror box centre x
+                x[2][..., :4] /= 0.7  # scale: grow xywh back to original coordinates
+                inf_out = torch.cat(x, 1)
+
             # Compute loss
             if hasattr(model, 'hyp'):  # if model has loss hyperparameters
                 loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 36ab9ead..187d5142 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -5,6 +5,7 @@ from copy import deepcopy
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 def init_seeds(seed=0):
@@ -105,6 +106,16 @@ def load_classifier(name='resnet101', n=2):
     return model
 
 
+def scale_img(img, r=1.0):  # img(16,3,256,416), r=ratio
+    # scales a batch of pytorch images while retaining same input shape (cropped or grey-padded)
+    h, w = img.shape[2:]
+    s = (int(h * r), int(w * r))  # new size
+    p = h - s[0], w - s[1]  # pad/crop pixels
+    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
+    return F.pad(img, [0, p[1], 0, p[0]], value=0.5) if r <= 1.0 else img[:, :, :p[0], :p[1]]  # pad if r <= 1 (no-op at r=1), crop if r > 1
+    # cv2.imwrite('scaled.jpg', np.array(img[0].permute((1, 2, 0)) * 255.0))
+
+
 class ModelEMA:
     """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
     Keep a moving average of everything in the model state_dict (parameters and buffers).
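
Note: a minimal standalone sketch of the de-augmentation step in the test.py hunk above,
runnable on its own. The merged forward pass returns 3*nb predictions, which are split back
into the (original, flip-lr, 0.7-scaled) groups, mapped back to original image coordinates,
and concatenated along the prediction dimension so NMS sees all candidates. The shapes used
here (nb=2 images, 100 boxes, 85 channels in xywh+obj+cls order) and the random tensor are
illustrative assumptions, not values taken from this patch.

    import torch

    nb, width = 2, 416                     # assumed batch size and network input width
    inf_out = torch.rand(3 * nb, 100, 85)  # stand-in for model(imgs) on the 3x batch

    x = torch.split(inf_out, nb, dim=0)    # (original, flip-lr, scaled) groups
    x[1][..., 0] = width - x[1][..., 0]    # un-flip: mirror the box centre x
    x[2][..., :4] /= 0.7                   # un-scale: grow xywh back by 1/0.7
    merged = torch.cat(x, 1)               # shape (nb, 300, 85): 3x candidates per image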
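
And a hedged usage sketch for the new scale_img helper: it resizes the image content but
keeps the input tensor shape, grey-padding the bottom/right when shrinking and cropping
when enlarging. The batch shape follows the comment in the function; importing
utils.torch_utils assumes the repo root is on PYTHONPATH.

    import torch
    from utils import torch_utils          # assumes the yolov3 repo root is on PYTHONPATH

    imgs = torch.rand(16, 3, 256, 416)     # img(16,3,256,416), as in the function comment
    out = torch_utils.scale_img(imgs, 0.7) # content resized to 179x291, then padded
                                           # bottom/right with grey (value 0.5)
    assert out.shape == imgs.shape         # input shape is retained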