This commit is contained in:
Glenn Jocher 2018-12-12 17:02:37 +01:00
parent c591936446
commit b5a2747a6a
2 changed files with 9 additions and 12 deletions

View File

@ -153,8 +153,8 @@ def train(
loss = model(imgs.to(device), targets, batch_report=report, var=var)
loss.backward()
# accumulated_batches = 1 # accumulate gradient for 4 batches before stepping optimizer
# if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
optimizer.step()
optimizer.zero_grad()

View File

@ -116,7 +116,7 @@ class load_images_and_labels(): # for training
augment_hsv = True
if self.augment and augment_hsv:
# SV augmentation by 50%
fraction = 0.50
fraction = 0.25
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
S = img_hsv[:, :, 1].astype(np.float32)
V = img_hsv[:, :, 2].astype(np.float32)
@ -153,7 +153,7 @@ class load_images_and_labels(): # for training
# Augment image and labels
if self.augment:
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.8, 1.2))
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.05, 0.05), scale=(0.95, 1.05))
plotFlag = False
if plotFlag:
@ -211,7 +211,7 @@ def resize_square(img, height=416, color=(0, 0, 0)): # resize a rectangular ima
return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color), ratio, dw // 2, dh // 2
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-3, 3),
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
borderValue=(127.5, 127.5, 127.5)):
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
@ -288,8 +288,5 @@ def convert_tif2bmp(p='../xview/val_images_bmp'):
files = sorted(glob.glob('%s/*.tif' % p))
for i, f in enumerate(files):
print('%g/%g' % (i + 1, len(files)))
img = cv2.imread(f)
cv2.imwrite(f.replace('.tif', '.bmp'), img)
cv2.imwrite(f.replace('.tif', '.bmp'), cv2.imread(f))
os.system('rm -rf ' + f)