move inference augmentation to model.forward()
parent 4da5c6c114
commit c6d4e80335
detect.py

@@ -84,9 +84,13 @@ def detect(save_img=False):
         # Inference
         t1 = torch_utils.time_synchronized()
-        pred = model(img)[0].float() if half else model(img)[0]
+        pred = model(img, augment=opt.augment)[0]
         t2 = torch_utils.time_synchronized()
 
+        # to float
+        if half:
+            pred = pred.float()
+
         # Apply NMS
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                                    multi_label=False, classes=opt.classes, agnostic=opt.agnostic_nms)
 
@@ -173,6 +177,7 @@ if __name__ == '__main__':
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
     parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
     opt = parser.parse_args()
     print(opt)
 
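Note: after this change detect.py only forwards opt.augment; the flip/scale work happens inside the model. A minimal sketch of exercising the new forward() signature directly, assuming it is run from the repository root (the cfg path, dummy input, and randomly initialised weights are illustrative assumptions, not part of this commit):

import torch
from models import Darknet  # assumes the repository root is on the path

model = Darknet('cfg/yolov3-spp.cfg')  # cfg choice is an assumption
model.eval()                           # augmentation targets inference/test only

img = torch.zeros(1, 3, 416, 416)      # dummy pre-letterboxed batch
with torch.no_grad():
    pred = model(img, augment=True)[0]  # detections are already de-augmented
print(pred.shape)                       # (1, num_predictions, 5 + num_classes)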
models.py (25 changed lines)

@@ -226,13 +226,21 @@ class Darknet(nn.Module):
|
||||||
self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training
|
self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training
|
||||||
self.info(verbose) # print model description
|
self.info(verbose) # print model description
|
||||||
|
|
||||||
def forward(self, x, verbose=False):
|
def forward(self, x, augment=False, verbose=False):
|
||||||
img_size = x.shape[-2:]
|
img_size = x.shape[-2:] # height, width
|
||||||
yolo_out, out = [], []
|
yolo_out, out = [], []
|
||||||
if verbose:
|
if verbose:
|
||||||
print('0', x.shape)
|
print('0', x.shape)
|
||||||
str = ''
|
str = ''
|
||||||
|
|
||||||
|
# Augment images (inference and test only)
|
||||||
|
if augment: # https://github.com/ultralytics/yolov3/issues/931
|
||||||
|
nb = x.shape[0] # batch size
|
||||||
|
x = torch.cat((x,
|
||||||
|
torch_utils.scale_img(x.flip(3), 0.9), # flip-lr and scale
|
||||||
|
torch_utils.scale_img(x, 0.7), # scale
|
||||||
|
), 0)
|
||||||
|
|
||||||
for i, module in enumerate(self.module_list):
|
for i, module in enumerate(self.module_list):
|
||||||
name = module.__class__.__name__
|
name = module.__class__.__name__
|
||||||
if name in ['WeightedFeatureFusion', 'FeatureConcat']: # sum, concat
|
if name in ['WeightedFeatureFusion', 'FeatureConcat']: # sum, concat
|
||||||
@@ -256,9 +264,16 @@ class Darknet(nn.Module):
         elif ONNX_EXPORT:  # export
             x = [torch.cat(x, 0) for x in zip(*yolo_out)]
             return x[0], torch.cat(x[1:3], 1)  # scores, boxes: 3780x80, 3780x4
-        else:  # test
-            io, p = zip(*yolo_out)  # inference output, training output
-            return torch.cat(io, 1), p
+        else:  # inference or test
+            x, p = zip(*yolo_out)  # inference output, training output
+            x = torch.cat(x, 1)  # cat yolo outputs
+            if augment:  # de-augment results
+                x = torch.split(x, nb, dim=0)
+                x[1][..., :4] /= 0.9  # scale
+                x[1][..., 0] = img_size[1] - x[1][..., 0]  # flip lr
+                x[2][..., :4] /= 0.7  # scale
+                x = torch.cat(x, 1)
+            return x, p
 
     def fuse(self):
         # Fuse Conv2d + BatchNorm2d layers throughout model
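For orientation, the two hunks above (a) build an augmented batch of 3 * nb images (original, flipped-and-0.9-scaled, 0.7-scaled) and (b) undo those transforms on the predicted boxes before merging them. Below is a self-contained sketch of that round trip; scale_img_sketch, the tensor shapes, and the (x, y, w, h, obj, classes) layout are assumptions made for illustration and are not the repository's torch_utils.scale_img:

import torch
import torch.nn.functional as F


def scale_img_sketch(img, ratio):
    # Assumed behaviour: resize a (bs, 3, h, w) batch by `ratio`, then zero-pad back
    # to h x w so the result can be concatenated with the original batch along dim 0.
    h, w = img.shape[2:]
    nh, nw = int(h * ratio), int(w * ratio)
    img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=False)
    return F.pad(img, [0, w - nw, 0, h - nh])  # pad right and bottom with zeros


nb, h, w = 2, 416, 416                    # assumed batch size and input resolution
x = torch.rand(nb, 3, h, w)

# Augment: nb images -> 3 * nb images, mirroring the forward() hunk
x = torch.cat((x,
               scale_img_sketch(x.flip(3), 0.9),  # flip-lr and scale
               scale_img_sketch(x, 0.7),          # scale
               ), 0)

# Stand-in for the concatenated YOLO output (batch, detections, xywh + obj + 80 classes)
pred = torch.rand(3 * nb, 100, 85)

# De-augment: undo the scaling on copies 1 and 2, mirror x-centres of the flipped copy
pred = list(torch.split(pred, nb, dim=0))
pred[1][..., :4] /= 0.9                   # undo 0.9 scaling
pred[1][..., 0] = w - pred[1][..., 0]     # undo the left-right flip
pred[2][..., :4] /= 0.7                   # undo 0.7 scaling
pred = torch.cat(pred, 1)                 # merge along the detection dimension
print(pred.shape)                         # torch.Size([2, 300, 85])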
test.py (19 changed lines)

@@ -88,26 +88,11 @@ def test(cfg,
 
         # Disable gradients
         with torch.no_grad():
-            # Augment images
-            if augment:  # https://github.com/ultralytics/yolov3/issues/931
-                imgs = torch.cat((imgs,
-                                  torch_utils.scale_img(imgs.flip(3), 0.9),  # flip-lr and scale
-                                  torch_utils.scale_img(imgs, 0.7),  # scale
-                                  ), 0)
-
             # Run model
             t = torch_utils.time_synchronized()
-            inf_out, train_out = model(imgs)  # inference and training outputs
+            inf_out, train_out = model(imgs, augment=augment)  # inference and training outputs
             t0 += torch_utils.time_synchronized() - t
 
-            # De-augment results
-            if augment:
-                x = torch.split(inf_out, nb, dim=0)
-                x[1][..., :4] /= 0.9  # scale
-                x[1][..., 0] = width - x[1][..., 0]  # flip lr
-                x[2][..., :4] /= 0.7  # scale
-                inf_out = torch.cat(x, 1)
-
             # Compute loss
             if hasattr(model, 'hyp'):  # if model has loss hyperparameters
                 loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls
@@ -250,7 +235,7 @@ if __name__ == '__main__':
     parser.add_argument('--task', default='test', help="'test', 'study', 'benchmark'")
     parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
-    parser.add_argument('--augment', action='store_true', help='augmented testing')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
     opt = parser.parse_args()
     opt.save_json = opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']])
     print(opt)
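The net effect on test.py is that plain and augmented evaluation share a single call, with no nb/width bookkeeping at the call site. A minimal sketch of the resulting calling pattern (run_batch is a made-up name for illustration):

import torch


def run_batch(model, imgs, augment=False):
    # imgs: a (bs, 3, h, w) tensor already on the model's device
    model.eval()
    with torch.no_grad():
        # De-augmentation now happens inside forward(), so the caller just toggles the flag
        inf_out, train_out = model(imgs, augment=augment)
    return inf_out, train_out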