updates
This commit is contained in:
parent
19ccb41eaf
commit
6e5da1ce27
6
train.py
6
train.py
|
@ -7,7 +7,7 @@ from utils.utils import *
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-epochs', type=int, default=100, help='number of epochs')
|
parser.add_argument('-epochs', type=int, default=100, help='number of epochs')
|
||||||
parser.add_argument('-batch_size', type=int, default=16, help='size of each image batch')
|
parser.add_argument('-batch_size', type=int, default=4, help='size of each image batch')
|
||||||
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='data config file path')
|
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='data config file path')
|
||||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
||||||
parser.add_argument('-img_size', type=int, default=32 * 19, help='size of each image dimension')
|
parser.add_argument('-img_size', type=int, default=32 * 19, help='size of each image dimension')
|
||||||
|
@ -128,8 +128,8 @@ def main(opt):
|
||||||
loss = model(imgs.to(device), targets, requestPrecision=True)
|
loss = model(imgs.to(device), targets, requestPrecision=True)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# accumulated_batches = 4 # accumulate gradient for 4 batches before stepping optimizer
|
accumulated_batches = 4 # accumulate gradient for 4 batches before stepping optimizer
|
||||||
# if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
|
|
@ -436,11 +436,11 @@ def plot_results():
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
plt.figure(figsize=(16, 8))
|
plt.figure(figsize=(16, 8))
|
||||||
s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall']
|
s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall']
|
||||||
for f in ('results.txt',
|
for f in ('results_orig.txt','results.txt',
|
||||||
):
|
):
|
||||||
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 9, 10]).T
|
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 9, 10]).T
|
||||||
for i in range(9):
|
for i in range(9):
|
||||||
plt.subplot(2, 5, i + 1)
|
plt.subplot(2, 5, i + 1)
|
||||||
plt.plot(results[i, :], marker='.', label=f)
|
plt.plot(results[i, :250], marker='.', label=f)
|
||||||
plt.title(s[i])
|
plt.title(s[i])
|
||||||
plt.legend()
|
plt.legend()
|
||||||
|
|
Loading…
Reference in New Issue