move inference augmentation to model.forward()
This commit is contained in:
parent
4da5c6c114
commit
c6d4e80335
|
@ -84,9 +84,13 @@ def detect(save_img=False):
|
|||
|
||||
# Inference
|
||||
t1 = torch_utils.time_synchronized()
|
||||
pred = model(img)[0].float() if half else model(img)[0]
|
||||
pred = model(img, augment=opt.augment)[0]
|
||||
t2 = torch_utils.time_synchronized()
|
||||
|
||||
# to float
|
||||
if half:
|
||||
pred = pred.float()
|
||||
|
||||
# Apply NMS
|
||||
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
|
||||
multi_label=False, classes=opt.classes, agnostic=opt.agnostic_nms)
|
||||
|
@ -173,6 +177,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
|
||||
parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
|
||||
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
|
||||
parser.add_argument('--augment', action='store_true', help='augmented inference')
|
||||
opt = parser.parse_args()
|
||||
print(opt)
|
||||
|
||||
|
|
25
models.py
25
models.py
|
@ -226,13 +226,21 @@ class Darknet(nn.Module):
|
|||
self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training
|
||||
self.info(verbose) # print model description
|
||||
|
||||
def forward(self, x, verbose=False):
|
||||
img_size = x.shape[-2:]
|
||||
def forward(self, x, augment=False, verbose=False):
|
||||
img_size = x.shape[-2:] # height, width
|
||||
yolo_out, out = [], []
|
||||
if verbose:
|
||||
print('0', x.shape)
|
||||
str = ''
|
||||
|
||||
# Augment images (inference and test only)
|
||||
if augment: # https://github.com/ultralytics/yolov3/issues/931
|
||||
nb = x.shape[0] # batch size
|
||||
x = torch.cat((x,
|
||||
torch_utils.scale_img(x.flip(3), 0.9), # flip-lr and scale
|
||||
torch_utils.scale_img(x, 0.7), # scale
|
||||
), 0)
|
||||
|
||||
for i, module in enumerate(self.module_list):
|
||||
name = module.__class__.__name__
|
||||
if name in ['WeightedFeatureFusion', 'FeatureConcat']: # sum, concat
|
||||
|
@ -256,9 +264,16 @@ class Darknet(nn.Module):
|
|||
elif ONNX_EXPORT: # export
|
||||
x = [torch.cat(x, 0) for x in zip(*yolo_out)]
|
||||
return x[0], torch.cat(x[1:3], 1) # scores, boxes: 3780x80, 3780x4
|
||||
else: # test
|
||||
io, p = zip(*yolo_out) # inference output, training output
|
||||
return torch.cat(io, 1), p
|
||||
else: # inference or test
|
||||
x, p = zip(*yolo_out) # inference output, training output
|
||||
x = torch.cat(x, 1) # cat yolo outputs
|
||||
if augment: # de-augment results
|
||||
x = torch.split(x, nb, dim=0)
|
||||
x[1][..., :4] /= 0.9 # scale
|
||||
x[1][..., 0] = img_size[1] - x[1][..., 0] # flip lr
|
||||
x[2][..., :4] /= 0.7 # scale
|
||||
x = torch.cat(x, 1)
|
||||
return x, p
|
||||
|
||||
def fuse(self):
|
||||
# Fuse Conv2d + BatchNorm2d layers throughout model
|
||||
|
|
19
test.py
19
test.py
|
@ -88,26 +88,11 @@ def test(cfg,
|
|||
|
||||
# Disable gradients
|
||||
with torch.no_grad():
|
||||
# Augment images
|
||||
if augment: # https://github.com/ultralytics/yolov3/issues/931
|
||||
imgs = torch.cat((imgs,
|
||||
torch_utils.scale_img(imgs.flip(3), 0.9), # flip-lr and scale
|
||||
torch_utils.scale_img(imgs, 0.7), # scale
|
||||
), 0)
|
||||
|
||||
# Run model
|
||||
t = torch_utils.time_synchronized()
|
||||
inf_out, train_out = model(imgs) # inference and training outputs
|
||||
inf_out, train_out = model(imgs, augment=augment) # inference and training outputs
|
||||
t0 += torch_utils.time_synchronized() - t
|
||||
|
||||
# De-augment results
|
||||
if augment:
|
||||
x = torch.split(inf_out, nb, dim=0)
|
||||
x[1][..., :4] /= 0.9 # scale
|
||||
x[1][..., 0] = width - x[1][..., 0] # flip lr
|
||||
x[2][..., :4] /= 0.7 # scale
|
||||
inf_out = torch.cat(x, 1)
|
||||
|
||||
# Compute loss
|
||||
if hasattr(model, 'hyp'): # if model has loss hyperparameters
|
||||
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls
|
||||
|
@ -250,7 +235,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--task', default='test', help="'test', 'study', 'benchmark'")
|
||||
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
|
||||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||
parser.add_argument('--augment', action='store_true', help='augmented testing')
|
||||
parser.add_argument('--augment', action='store_true', help='augmented inference')
|
||||
opt = parser.parse_args()
|
||||
opt.save_json = opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']])
|
||||
print(opt)
|
||||
|
|
Loading…
Reference in New Issue