This commit is contained in:
Glenn Jocher 2020-04-27 15:22:36 -07:00
parent e9d41bb566
commit 8521c3cff9
1 changed file with 8 additions and 10 deletions

View File

@@ -249,7 +249,7 @@ def train():
if 'momentum' in x: if 'momentum' in x:
x['momentum'] = np.interp(ni, [0, n_burn], [0.9, hyp['momentum']]) x['momentum'] = np.interp(ni, [0, n_burn], [0.9, hyp['momentum']])
# Multi-Scale training # Multi-Scale
if opt.multi_scale: if opt.multi_scale:
if ni / accumulate % 1 == 0: #  adjust img_size (67% - 150%) every 1 batch if ni / accumulate % 1 == 0: #  adjust img_size (67% - 150%) every 1 batch
img_size = random.randrange(grid_min, grid_max + 1) * gs img_size = random.randrange(grid_min, grid_max + 1) * gs
@@ -258,38 +258,36 @@ def train():
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to 32-multiple) ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to 32-multiple)
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
# Run model # Forward
pred = model(imgs) pred = model(imgs)
# Compute loss # Loss
loss, loss_items = compute_loss(pred, targets, model) loss, loss_items = compute_loss(pred, targets, model)
if not torch.isfinite(loss): if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items) print('WARNING: non-finite loss, ending training ', loss_items)
return results return results
# Scale loss by nominal batch_size of 64 # Backward
loss *= batch_size / 64 loss *= batch_size / 64 # scale loss
# Compute gradient
if mixed_precision: if mixed_precision:
with amp.scale_loss(loss, optimizer) as scaled_loss: with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
else: else:
loss.backward() loss.backward()
# Optimize accumulated gradient # Optimize
if ni % accumulate == 0: if ni % accumulate == 0:
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
ema.update(model) ema.update(model)
# Print batch results # Print
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB) mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size) s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size)
pbar.set_description(s) pbar.set_description(s)
# Plot images with bounding boxes # Plot
if ni < 1: if ni < 1:
f = 'train_batch%g.png' % i # filename f = 'train_batch%g.png' % i # filename
plot_images(imgs=imgs, targets=targets, paths=paths, fname=f) plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)