This commit is contained in:
Glenn Jocher 2019-07-08 15:02:20 +02:00
parent 68b50f5cb6
commit 291c3ec9c7
1 changed files with 4 additions and 7 deletions

View File

@ -11,10 +11,7 @@ from models import *
from utils.datasets import *
from utils.utils import *
# 0.0945 0.279 0.114 0.131 25 0.035 0.2 0.1 0.035 79 1.61 3.53 0.29 0.001 -4 0.9 0.0005 320
# 0.149 0.241 0.126 0.156 6.85 1.008 1.421 0.07989 16.94 6.215 10.61 4.272 0.251 0.001 -4 0.9 0.0005 320 giou
# 0.111 0.27 0.132 0.131 3.96 1.276 0.3156 0.1425 21.21 6.224 11.59 8.83 0.376 0.001 -4 0.9 0.0005 320
# 0.114 0.287 0.144 0.132 7.1 1.666 4.046 0.1364 42.6 3.34 12.61 8.338 0.2705 0.001 -4 0.9 0.0005 320 giou + best_anchor False
# 0.109 0.297 0.15 0.126 7.04 1.666 4.062 0.1845 42.6 3.34 12.61 8.338 0.2705 0.001 -4 0.9 0.0005 320 giou + best_anchor False
hyp = {'giou': 1.666, # giou loss gain
'xy': 4.062, # xy loss gain
'wh': 0.1845, # wh loss gain
@ -114,12 +111,11 @@ def train(
# plt.savefig('LR.png', dpi=300)
# Dataset
rectangular_training = False
dataset = LoadImagesAndLabels(train_path,
img_size,
batch_size,
augment=True,
rect=rectangular_training)
rect=opt.rect) # rectangular training
# Initialize distributed training
if torch.cuda.device_count() > 1:
@ -135,7 +131,7 @@ def train(
dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=opt.num_workers,
shuffle=not rectangular_training, # Shuffle=True unless rectangular training is used
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
pin_memory=True,
collate_fn=dataset.collate_fn)
@ -301,6 +297,7 @@ if __name__ == '__main__':
parser.add_argument('--data-cfg', type=str, default='data/coco_64img.data', help='coco.data file path')
parser.add_argument('--single-scale', action='store_true', help='train at fixed size (no multi-scale)')
parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--transfer', action='store_true', help='transfer learning flag')
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')