cleanup
This commit is contained in:
		
							parent
							
								
									e9d41bb566
								
							
						
					
					
						commit
						8521c3cff9
					
				
							
								
								
									
										18
									
								
								train.py
								
								
								
								
							
							
						
						
									
										18
									
								
								train.py
								
								
								
								
							|  | @ -249,7 +249,7 @@ def train(): | ||||||
|                     if 'momentum' in x: |                     if 'momentum' in x: | ||||||
|                         x['momentum'] = np.interp(ni, [0, n_burn], [0.9, hyp['momentum']]) |                         x['momentum'] = np.interp(ni, [0, n_burn], [0.9, hyp['momentum']]) | ||||||
| 
 | 
 | ||||||
|             # Multi-Scale training |             # Multi-Scale | ||||||
|             if opt.multi_scale: |             if opt.multi_scale: | ||||||
|                 if ni / accumulate % 1 == 0:  #  adjust img_size (67% - 150%) every 1 batch |                 if ni / accumulate % 1 == 0:  #  adjust img_size (67% - 150%) every 1 batch | ||||||
|                     img_size = random.randrange(grid_min, grid_max + 1) * gs |                     img_size = random.randrange(grid_min, grid_max + 1) * gs | ||||||
|  | @ -258,38 +258,36 @@ def train(): | ||||||
|                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to 32-multiple) |                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to 32-multiple) | ||||||
|                     imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) |                     imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) | ||||||
| 
 | 
 | ||||||
|             # Run model |             # Forward | ||||||
|             pred = model(imgs) |             pred = model(imgs) | ||||||
| 
 | 
 | ||||||
|             # Compute loss |             # Loss | ||||||
|             loss, loss_items = compute_loss(pred, targets, model) |             loss, loss_items = compute_loss(pred, targets, model) | ||||||
|             if not torch.isfinite(loss): |             if not torch.isfinite(loss): | ||||||
|                 print('WARNING: non-finite loss, ending training ', loss_items) |                 print('WARNING: non-finite loss, ending training ', loss_items) | ||||||
|                 return results |                 return results | ||||||
| 
 | 
 | ||||||
|             # Scale loss by nominal batch_size of 64 |             # Backward | ||||||
|             loss *= batch_size / 64 |             loss *= batch_size / 64  # scale loss | ||||||
| 
 |  | ||||||
|             # Compute gradient |  | ||||||
|             if mixed_precision: |             if mixed_precision: | ||||||
|                 with amp.scale_loss(loss, optimizer) as scaled_loss: |                 with amp.scale_loss(loss, optimizer) as scaled_loss: | ||||||
|                     scaled_loss.backward() |                     scaled_loss.backward() | ||||||
|             else: |             else: | ||||||
|                 loss.backward() |                 loss.backward() | ||||||
| 
 | 
 | ||||||
|             # Optimize accumulated gradient |             # Optimize | ||||||
|             if ni % accumulate == 0: |             if ni % accumulate == 0: | ||||||
|                 optimizer.step() |                 optimizer.step() | ||||||
|                 optimizer.zero_grad() |                 optimizer.zero_grad() | ||||||
|                 ema.update(model) |                 ema.update(model) | ||||||
| 
 | 
 | ||||||
|             # Print batch results |             # Print | ||||||
|             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses |             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses | ||||||
|             mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB) |             mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB) | ||||||
|             s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size) |             s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size) | ||||||
|             pbar.set_description(s) |             pbar.set_description(s) | ||||||
| 
 | 
 | ||||||
|             # Plot images with bounding boxes |             # Plot | ||||||
|             if ni < 1: |             if ni < 1: | ||||||
|                 f = 'train_batch%g.png' % i  # filename |                 f = 'train_batch%g.png' % i  # filename | ||||||
|                 plot_images(imgs=imgs, targets=targets, paths=paths, fname=f) |                 plot_images(imgs=imgs, targets=targets, paths=paths, fname=f) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue