updates
This commit is contained in:
		
							parent
							
								
									4b720013d1
								
							
						
					
					
						commit
						f743235fac
					
				
							
								
								
									
										8
									
								
								train.py
								
								
								
								
							
							
						
						
									
										8
									
								
								train.py
								
								
								
								
							|  | @ -288,7 +288,7 @@ def train(): | |||
|             else: | ||||
|                 loss.backward() | ||||
| 
 | ||||
|             # Accumulate gradient for x batches before optimizing | ||||
|             # Optimize accumulated gradient | ||||
|             if ni % accumulate == 0: | ||||
|                 optimizer.step() | ||||
|                 optimizer.zero_grad() | ||||
|  | @ -301,6 +301,9 @@ def train(): | |||
| 
 | ||||
|             # end batch ------------------------------------------------------------------------------------------------ | ||||
| 
 | ||||
|         # Update scheduler | ||||
|         scheduler.step() | ||||
| 
 | ||||
|         # Process epoch results | ||||
|         final_epoch = epoch + 1 == epochs | ||||
|         if not opt.notest or final_epoch:  # Calculate mAP | ||||
|  | @ -316,9 +319,6 @@ def train(): | |||
|                                       single_cls=opt.single_cls, | ||||
|                                       dataloader=testloader) | ||||
| 
 | ||||
|         # Update scheduler | ||||
|         scheduler.step() | ||||
| 
 | ||||
|         # Write epoch results | ||||
|         with open(results_file, 'a') as f: | ||||
|             f.write(s + '%10.3g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue