parameterize augment scales
This commit is contained in:
parent
b9b14bef59
commit
d79c3bd076
43
models.py
43
models.py
|
@ -231,6 +231,36 @@ class Darknet(nn.Module):
|
||||||
self.info(verbose) # print model description
|
self.info(verbose) # print model description
|
||||||
|
|
||||||
def forward(self, x, augment=False, verbose=False):
|
def forward(self, x, augment=False, verbose=False):
|
||||||
|
|
||||||
|
if not augment:
|
||||||
|
return self.forward_once(x)
|
||||||
|
else: # Augment images (inference and test only) https://github.com/ultralytics/yolov3/issues/931
|
||||||
|
img_size = x.shape[-2:] # height, width
|
||||||
|
s = [0.83, 1.33] # scales
|
||||||
|
y = []
|
||||||
|
for i, xi in enumerate((x,
|
||||||
|
torch_utils.scale_img(x.flip(3), s[0], same_shape=False), # flip-lr and scale
|
||||||
|
torch_utils.scale_img(x, s[1], same_shape=False), # scale
|
||||||
|
)):
|
||||||
|
cv2.imwrite('img%g.jpg' % i, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])
|
||||||
|
y.append(self.forward_once(xi)[0])
|
||||||
|
|
||||||
|
y[1][..., :4] /= s[0] # scale
|
||||||
|
y[1][..., 0] = img_size[1] - y[1][..., 0] # flip lr
|
||||||
|
y[2][..., :4] /= s[1] # scale
|
||||||
|
|
||||||
|
# for i, yi in enumerate(y): # coco small, medium, large = < 32**2 < 96**2 <
|
||||||
|
# area = yi[..., 2:4].prod(2)[:, :, None]
|
||||||
|
# if i == 1:
|
||||||
|
# yi = yi * (area < 96. ** 2).float()
|
||||||
|
# elif i == 2:
|
||||||
|
# yi = yi * (area > 32. ** 2).float()
|
||||||
|
# y[i] = yi
|
||||||
|
|
||||||
|
y = torch.cat(y, 1)
|
||||||
|
return y, None
|
||||||
|
|
||||||
|
def forward_once(self, x, augment=False, verbose=False):
|
||||||
img_size = x.shape[-2:] # height, width
|
img_size = x.shape[-2:] # height, width
|
||||||
yolo_out, out = [], []
|
yolo_out, out = [], []
|
||||||
if verbose:
|
if verbose:
|
||||||
|
@ -240,9 +270,10 @@ class Darknet(nn.Module):
|
||||||
# Augment images (inference and test only)
|
# Augment images (inference and test only)
|
||||||
if augment: # https://github.com/ultralytics/yolov3/issues/931
|
if augment: # https://github.com/ultralytics/yolov3/issues/931
|
||||||
nb = x.shape[0] # batch size
|
nb = x.shape[0] # batch size
|
||||||
|
s = [0.83, 0.67] # scales
|
||||||
x = torch.cat((x,
|
x = torch.cat((x,
|
||||||
torch_utils.scale_img(x.flip(3), 0.83), # flip-lr and scale
|
torch_utils.scale_img(x.flip(3), s[0]), # flip-lr and scale
|
||||||
torch_utils.scale_img(x, 0.67), # scale
|
torch_utils.scale_img(x, s[1]), # scale
|
||||||
), 0)
|
), 0)
|
||||||
|
|
||||||
for i, module in enumerate(self.module_list):
|
for i, module in enumerate(self.module_list):
|
||||||
|
@ -250,8 +281,8 @@ class Darknet(nn.Module):
|
||||||
if name in ['WeightedFeatureFusion', 'FeatureConcat']: # sum, concat
|
if name in ['WeightedFeatureFusion', 'FeatureConcat']: # sum, concat
|
||||||
if verbose:
|
if verbose:
|
||||||
l = [i - 1] + module.layers # layers
|
l = [i - 1] + module.layers # layers
|
||||||
s = [list(x.shape)] + [list(out[i].shape) for i in module.layers] # shapes
|
sh = [list(x.shape)] + [list(out[i].shape) for i in module.layers] # shapes
|
||||||
str = ' >> ' + ' + '.join(['layer %g %s' % x for x in zip(l, s)])
|
str = ' >> ' + ' + '.join(['layer %g %s' % x for x in zip(l, sh)])
|
||||||
x = module(x, out) # WeightedFeatureFusion(), FeatureConcat()
|
x = module(x, out) # WeightedFeatureFusion(), FeatureConcat()
|
||||||
elif name == 'YOLOLayer':
|
elif name == 'YOLOLayer':
|
||||||
yolo_out.append(module(x, img_size, out))
|
yolo_out.append(module(x, img_size, out))
|
||||||
|
@ -273,9 +304,9 @@ class Darknet(nn.Module):
|
||||||
x = torch.cat(x, 1) # cat yolo outputs
|
x = torch.cat(x, 1) # cat yolo outputs
|
||||||
if augment: # de-augment results
|
if augment: # de-augment results
|
||||||
x = torch.split(x, nb, dim=0)
|
x = torch.split(x, nb, dim=0)
|
||||||
x[1][..., :4] /= 0.83 # scale
|
x[1][..., :4] /= s[0] # scale
|
||||||
x[1][..., 0] = img_size[1] - x[1][..., 0] # flip lr
|
x[1][..., 0] = img_size[1] - x[1][..., 0] # flip lr
|
||||||
x[2][..., :4] /= 0.67 # scale
|
x[2][..., :4] /= s[1] # scale
|
||||||
x = torch.cat(x, 1)
|
x = torch.cat(x, 1)
|
||||||
return x, p
|
return x, p
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue