diff --git a/train.py b/train.py index 15f412a9..6c0ead81 100644 --- a/train.py +++ b/train.py @@ -7,7 +7,7 @@ from utils.utils import * parser = argparse.ArgumentParser() parser.add_argument('-epochs', type=int, default=100, help='number of epochs') -parser.add_argument('-batch_size', type=int, default=16, help='size of each image batch') +parser.add_argument('-batch_size', type=int, default=4, help='size of each image batch') parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='data config file path') parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') parser.add_argument('-img_size', type=int, default=32 * 19, help='size of each image dimension') @@ -128,10 +128,10 @@ def main(opt): loss = model(imgs.to(device), targets, requestPrecision=True) loss.backward() - # accumulated_batches = 4 # accumulate gradient for 4 batches before stepping optimizer - # if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1): - optimizer.step() - optimizer.zero_grad() + accumulated_batches = 4 # accumulate gradient for 4 batches before stepping optimizer + if ((i+1) % accumulated_batches == 0) or (i == len(dataloader) - 1): + optimizer.step() + optimizer.zero_grad() # Compute running epoch-means of tracked metrics ui += 1 diff --git a/utils/utils.py b/utils/utils.py index 0f26542f..277b6a70 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -436,11 +436,11 @@ def plot_results(): import matplotlib.pyplot as plt plt.figure(figsize=(16, 8)) s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall'] - for f in ('results.txt', + for f in ('results_orig.txt','results.txt', ): results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 9, 10]).T for i in range(9): plt.subplot(2, 5, i + 1) - plt.plot(results[i, :], marker='.', label=f) + plt.plot(results[i, :250], marker='.', label=f) plt.title(s[i]) plt.legend()