updates
This commit is contained in:
		
							parent
							
								
									813024116b
								
							
						
					
					
						commit
						b459587cb0
					
				
							
								
								
									
										11
									
								
								train.py
								
								
								
								
							
							
						
						
									
										11
									
								
								train.py
								
								
								
								
							|  | @ -167,7 +167,7 @@ def train(cfg, | ||||||
|     maps = np.zeros(nc)  # mAP per class |     maps = np.zeros(nc)  # mAP per class | ||||||
|     results = (0, 0, 0, 0, 0)  # P, R, mAP, F1, test_loss |     results = (0, 0, 0, 0, 0)  # P, R, mAP, F1, test_loss | ||||||
|     n_burnin = min(round(nb / 5 + 1), 1000)  # burn-in batches |     n_burnin = min(round(nb / 5 + 1), 1000)  # burn-in batches | ||||||
|     torch.cuda.empty_cache() |     t0 = time.time() | ||||||
|     for epoch in range(start_epoch, epochs): |     for epoch in range(start_epoch, epochs): | ||||||
|         model.train() |         model.train() | ||||||
|         print(('\n%8s' + '%10s' * 8) % |         print(('\n%8s' + '%10s' * 8) % | ||||||
|  | @ -235,14 +235,10 @@ def train(cfg, | ||||||
| 
 | 
 | ||||||
|             # Print batch results |             # Print batch results | ||||||
|             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses |             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses | ||||||
|             mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0 |             mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0  # (GB) | ||||||
|             s = ('%8s' + '%10.3g' * 8) % ('%g/%g' % (epoch, epochs - 1), *mloss, len(targets), img_size, mem) |             s = ('%8s' + '%10.3g' * 8) % ('%g/%g' % (epoch, epochs - 1), *mloss, len(targets), img_size, mem) | ||||||
|             pbar.set_description(s)  # print(s) |             pbar.set_description(s)  # print(s) | ||||||
| 
 | 
 | ||||||
|         # Report time |  | ||||||
|         # dt = (time.time() - t0) / 3600 |  | ||||||
|         # print('%g epochs completed in %.3f hours.' % (epoch - start_epoch + 1, dt)) |  | ||||||
| 
 |  | ||||||
|         # Calculate mAP (always test final epoch, skip first 5 if opt.nosave) |         # Calculate mAP (always test final epoch, skip first 5 if opt.nosave) | ||||||
|         if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1: |         if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1: | ||||||
|             with torch.no_grad(): |             with torch.no_grad(): | ||||||
|  | @ -286,6 +282,9 @@ def train(cfg, | ||||||
|             # Delete checkpoint |             # Delete checkpoint | ||||||
|             del chkpt |             del chkpt | ||||||
| 
 | 
 | ||||||
|  |     # Report time | ||||||
|  |     print('%g epochs completed in %.3f hours.' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) | ||||||
|  | 
 | ||||||
|     return results |     return results | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue