updates
This commit is contained in:
parent
c591936446
commit
b5a2747a6a
4
train.py
4
train.py
|
@ -153,8 +153,8 @@ def train(
|
|||
loss = model(imgs.to(device), targets, batch_report=report, var=var)
|
||||
loss.backward()
|
||||
|
||||
# accumulated_batches = 1 # accumulate gradient for 4 batches before stepping optimizer
|
||||
# if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||
accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing
|
||||
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
|
|
@ -116,7 +116,7 @@ class load_images_and_labels(): # for training
|
|||
augment_hsv = True
|
||||
if self.augment and augment_hsv:
|
||||
# SV augmentation by 50%
|
||||
fraction = 0.50
|
||||
fraction = 0.25
|
||||
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
S = img_hsv[:, :, 1].astype(np.float32)
|
||||
V = img_hsv[:, :, 2].astype(np.float32)
|
||||
|
@ -153,7 +153,7 @@ class load_images_and_labels(): # for training
|
|||
|
||||
# Augment image and labels
|
||||
if self.augment:
|
||||
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.8, 1.2))
|
||||
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.05, 0.05), scale=(0.95, 1.05))
|
||||
|
||||
plotFlag = False
|
||||
if plotFlag:
|
||||
|
@ -211,7 +211,7 @@ def resize_square(img, height=416, color=(0, 0, 0)): # resize a rectangular ima
|
|||
return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color), ratio, dw // 2, dh // 2
|
||||
|
||||
|
||||
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-3, 3),
|
||||
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
|
||||
borderValue=(127.5, 127.5, 127.5)):
|
||||
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
|
||||
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
|
||||
|
@ -288,8 +288,5 @@ def convert_tif2bmp(p='../xview/val_images_bmp'):
|
|||
files = sorted(glob.glob('%s/*.tif' % p))
|
||||
for i, f in enumerate(files):
|
||||
print('%g/%g' % (i + 1, len(files)))
|
||||
|
||||
img = cv2.imread(f)
|
||||
|
||||
cv2.imwrite(f.replace('.tif', '.bmp'), img)
|
||||
cv2.imwrite(f.replace('.tif', '.bmp'), cv2.imread(f))
|
||||
os.system('rm -rf ' + f)
|
||||
|
|
Loading…
Reference in New Issue