updates
This commit is contained in:
parent
c591936446
commit
b5a2747a6a
10
train.py
10
train.py
|
@ -111,7 +111,7 @@ def train(
|
||||||
epoch += start_epoch
|
epoch += start_epoch
|
||||||
|
|
||||||
print(('%8s%12s' + '%10s' * 14) % ('Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R',
|
print(('%8s%12s' + '%10s' * 14) % ('Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R',
|
||||||
'nTargets', 'TP', 'FP', 'FN', 'time'))
|
'nTargets', 'TP', 'FP', 'FN', 'time'))
|
||||||
|
|
||||||
# Update scheduler (automatic)
|
# Update scheduler (automatic)
|
||||||
# scheduler.step()
|
# scheduler.step()
|
||||||
|
@ -153,10 +153,10 @@ def train(
|
||||||
loss = model(imgs.to(device), targets, batch_report=report, var=var)
|
loss = model(imgs.to(device), targets, batch_report=report, var=var)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# accumulated_batches = 1 # accumulate gradient for 4 batches before stepping optimizer
|
accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing
|
||||||
# if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Running epoch-means of tracked metrics
|
# Running epoch-means of tracked metrics
|
||||||
ui += 1
|
ui += 1
|
||||||
|
|
|
@ -116,7 +116,7 @@ class load_images_and_labels(): # for training
|
||||||
augment_hsv = True
|
augment_hsv = True
|
||||||
if self.augment and augment_hsv:
|
if self.augment and augment_hsv:
|
||||||
# SV augmentation by 50%
|
# SV augmentation by 50%
|
||||||
fraction = 0.50
|
fraction = 0.25
|
||||||
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||||
S = img_hsv[:, :, 1].astype(np.float32)
|
S = img_hsv[:, :, 1].astype(np.float32)
|
||||||
V = img_hsv[:, :, 2].astype(np.float32)
|
V = img_hsv[:, :, 2].astype(np.float32)
|
||||||
|
@ -153,7 +153,7 @@ class load_images_and_labels(): # for training
|
||||||
|
|
||||||
# Augment image and labels
|
# Augment image and labels
|
||||||
if self.augment:
|
if self.augment:
|
||||||
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.8, 1.2))
|
img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.05, 0.05), scale=(0.95, 1.05))
|
||||||
|
|
||||||
plotFlag = False
|
plotFlag = False
|
||||||
if plotFlag:
|
if plotFlag:
|
||||||
|
@ -211,7 +211,7 @@ def resize_square(img, height=416, color=(0, 0, 0)): # resize a rectangular ima
|
||||||
return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color), ratio, dw // 2, dh // 2
|
return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color), ratio, dw // 2, dh // 2
|
||||||
|
|
||||||
|
|
||||||
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-3, 3),
|
def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
|
||||||
borderValue=(127.5, 127.5, 127.5)):
|
borderValue=(127.5, 127.5, 127.5)):
|
||||||
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
|
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
|
||||||
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
|
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
|
||||||
|
@ -288,8 +288,5 @@ def convert_tif2bmp(p='../xview/val_images_bmp'):
|
||||||
files = sorted(glob.glob('%s/*.tif' % p))
|
files = sorted(glob.glob('%s/*.tif' % p))
|
||||||
for i, f in enumerate(files):
|
for i, f in enumerate(files):
|
||||||
print('%g/%g' % (i + 1, len(files)))
|
print('%g/%g' % (i + 1, len(files)))
|
||||||
|
cv2.imwrite(f.replace('.tif', '.bmp'), cv2.imread(f))
|
||||||
img = cv2.imread(f)
|
|
||||||
|
|
||||||
cv2.imwrite(f.replace('.tif', '.bmp'), img)
|
|
||||||
os.system('rm -rf ' + f)
|
os.system('rm -rf ' + f)
|
||||||
|
|
Loading…
Reference in New Issue