This commit is contained in:
Glenn Jocher 2019-02-21 15:57:18 +01:00
parent ec308d605e
commit af853f604c
2 changed files with 22 additions and 22 deletions

View File

@@ -295,7 +295,7 @@ def load_darknet_weights(self, weights, cutoff=-1):
if weights_file == 'darknet53.conv.74': if weights_file == 'darknet53.conv.74':
cutoff = 75 cutoff = 75
elif weights_file == 'yolov3-tiny.conv.15': elif weights_file == 'yolov3-tiny.conv.15':
cutoff = 16 cutoff = 15
# Open the weights file # Open the weights file
fp = open(weights, 'rb') fp = open(weights, 'rb')

View File

@@ -15,11 +15,13 @@ def train(
epochs=100, epochs=100,
batch_size=16, batch_size=16,
accumulated_batches=1, accumulated_batches=1,
weights='weights',
multi_scale=False, multi_scale=False,
freeze_backbone=True, freeze_backbone=True,
var=0, var=0,
): ):
weights = 'weights' + os.sep
latest = weights + 'latest.pt'
best = weights + 'best.pt'
device = torch_utils.select_device() device = torch_utils.select_device()
if multi_scale: # pass maximum multi_scale size if multi_scale: # pass maximum multi_scale size
@@ -27,9 +29,6 @@ def train(
else: else:
torch.backends.cudnn.benchmark = True # unsuitable for multiscale torch.backends.cudnn.benchmark = True # unsuitable for multiscale
latest = os.path.join(weights, 'latest.pt')
best = os.path.join(weights, 'best.pt')
# Configure run # Configure run
train_path = parse_data_cfg(data_cfg)['train'] train_path = parse_data_cfg(data_cfg)['train']
@@ -40,6 +39,7 @@ def train(
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, multi_scale=multi_scale, augment=True) dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, multi_scale=multi_scale, augment=True)
lr0 = 0.001 lr0 = 0.001
cutoff = -1 # backbone reaches to cutoff layer
if resume: if resume:
checkpoint = torch.load(latest, map_location='cpu') checkpoint = torch.load(latest, map_location='cpu')
@@ -69,8 +69,13 @@ def train(
start_epoch = 0 start_epoch = 0
best_loss = float('inf') best_loss = float('inf')
# Initialize model with darknet53 weights (optional) # Initialize model with backbone (optional)
load_darknet_weights(model, os.path.join(weights, 'darknet53.conv.74')) if cfg.endswith('yolov3.cfg'):
load_darknet_weights(model, weights + 'darknet53.conv.74')
cutoff = 75
elif cfg.endswith('yolov3-tiny.cfg'):
load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
cutoff = 15
# if torch.cuda.device_count() > 1: # if torch.cuda.device_count() > 1:
# model = nn.DataParallel(model) # model = nn.DataParallel(model)
@@ -102,15 +107,10 @@ def train(
g['lr'] = lr g['lr'] = lr
# Freeze darknet53.conv.74 for first epoch # Freeze darknet53.conv.74 for first epoch
if freeze_backbone: if freeze_backbone and (epoch < 2):
if epoch == 0: for i, (name, p) in enumerate(model.named_parameters()):
for i, (name, p) in enumerate(model.named_parameters()): if int(name.split('.')[1]) < cutoff: # if layer < 75
if int(name.split('.')[1]) < 75: # if layer < 75 p.requires_grad = False if (epoch == 0) else True
p.requires_grad = False
elif epoch == 1:
for i, (name, p) in enumerate(model.named_parameters()):
if int(name.split('.')[1]) < 75: # if layer < 75
p.requires_grad = True
ui = -1 ui = -1
rloss = defaultdict(float) # running loss rloss = defaultdict(float) # running loss
@@ -140,9 +140,11 @@ def train(
rloss[key] = (rloss[key] * ui + val) / (ui + 1) rloss[key] = (rloss[key] * ui + val) / (ui + 1)
s = ('%8s%12s' + '%10.3g' * 7) % ( s = ('%8s%12s' + '%10.3g' * 7) % (
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['xy'], '%g/%g' % (epoch, epochs - 1),
rloss['wh'], rloss['conf'], rloss['cls'], '%g/%g' % (i, len(dataloader) - 1),
rloss['loss'], model.losses['nT'], time.time() - t0) rloss['xy'], rloss['wh'], rloss['conf'],
rloss['cls'], rloss['loss'],
model.losses['nT'], time.time() - t0)
t0 = time.time() t0 = time.time()
print(s) print(s)
@@ -164,7 +166,7 @@ def train(
# Save backup weights every 5 epochs (optional) # Save backup weights every 5 epochs (optional)
# if (epoch > 0) & (epoch % 5 == 0): # if (epoch > 0) & (epoch % 5 == 0):
# os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch))) # os.system('cp ' + latest + ' ' + weights + 'backup{}.pt'.format(epoch)))
# Calculate mAP # Calculate mAP
with torch.no_grad(): with torch.no_grad():
@@ -184,7 +186,6 @@ if __name__ == '__main__':
parser.add_argument('--data-cfg', type=str, default='cfg/coco.data', help='coco.data file path') parser.add_argument('--data-cfg', type=str, default='cfg/coco.data', help='coco.data file path')
parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608') parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608')
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels') parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
parser.add_argument('--weights', type=str, default='weights', help='path to store weights')
parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--var', type=float, default=0, help='test variable') parser.add_argument('--var', type=float, default=0, help='test variable')
opt = parser.parse_args() opt = parser.parse_args()
@@ -200,7 +201,6 @@ if __name__ == '__main__':
epochs=opt.epochs, epochs=opt.epochs,
batch_size=opt.batch_size, batch_size=opt.batch_size,
accumulated_batches=opt.accumulated_batches, accumulated_batches=opt.accumulated_batches,
weights=opt.weights,
multi_scale=opt.multi_scale, multi_scale=opt.multi_scale,
var=opt.var, var=opt.var,
) )