updates
This commit is contained in:
		
							parent
							
								
									e1425b7288
								
							
						
					
					
						commit
						4fb7fbf4bc
					
				
							
								
								
									
										18
									
								
								train.py
								
								
								
								
							
							
						
						
									
										18
									
								
								train.py
								
								
								
								
							|  | @ -73,7 +73,7 @@ def train(cfg, | ||||||
|           img_size=416, |           img_size=416, | ||||||
|           epochs=100,  # 500200 batches at bs 16, 117263 images = 273 epochs |           epochs=100,  # 500200 batches at bs 16, 117263 images = 273 epochs | ||||||
|           batch_size=16, |           batch_size=16, | ||||||
|           accumulate=4):  # effective bs = batch_size * accumulate = 8 * 8 = 64 |           accumulate=4):  # effective bs = batch_size * accumulate = 16 * 4 = 64 | ||||||
|     # Initialize |     # Initialize | ||||||
|     init_seeds() |     init_seeds() | ||||||
|     weights = 'weights' + os.sep |     weights = 'weights' + os.sep | ||||||
|  | @ -365,11 +365,7 @@ if __name__ == '__main__': | ||||||
|     opt = parser.parse_args() |     opt = parser.parse_args() | ||||||
|     print(opt) |     print(opt) | ||||||
| 
 | 
 | ||||||
|     if opt.evolve: |     if not opt.evolve:  # Train normally | ||||||
|         opt.notest = True  # only test final epoch |  | ||||||
|         opt.nosave = True  # only save final checkpoint |  | ||||||
| 
 |  | ||||||
|     # Train |  | ||||||
|         results = train(opt.cfg, |         results = train(opt.cfg, | ||||||
|                         opt.data, |                         opt.data, | ||||||
|                         img_size=opt.img_size, |                         img_size=opt.img_size, | ||||||
|  | @ -377,10 +373,14 @@ if __name__ == '__main__': | ||||||
|                         batch_size=opt.batch_size, |                         batch_size=opt.batch_size, | ||||||
|                         accumulate=opt.accumulate) |                         accumulate=opt.accumulate) | ||||||
| 
 | 
 | ||||||
|     # Evolve hyperparameters (optional) |     else:  # Evolve hyperparameters (optional) | ||||||
|     if opt.evolve: |         opt.notest = True  # only test final epoch | ||||||
|         print_mutation(hyp, results)  # Write mutation results |         opt.nosave = True  # only save final checkpoint | ||||||
|  |         if opt.bucket: | ||||||
|  |             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists | ||||||
|  | 
 | ||||||
|         for _ in range(1000):  # generations to evolve |         for _ in range(1000):  # generations to evolve | ||||||
|  |             if os._exists('evolve.txt'):  # if evolve.txt exists: select best hyps and mutate | ||||||
|                 # Get best hyperparameters |                 # Get best hyperparameters | ||||||
|                 x = np.loadtxt('evolve.txt', ndmin=2) |                 x = np.loadtxt('evolve.txt', ndmin=2) | ||||||
|                 x = x[fitness(x).argmax()]  # select best fitness hyps |                 x = x[fitness(x).argmax()]  # select best fitness hyps | ||||||
|  |  | ||||||
|  | @ -37,14 +37,12 @@ def fuse_conv_and_bn(conv, bn): | ||||||
|     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/ |     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/ | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
|         # init |         # init | ||||||
|         fusedconv = torch.nn.Conv2d( |         fusedconv = torch.nn.Conv2d(conv.in_channels, | ||||||
|             conv.in_channels, |  | ||||||
|                                     conv.out_channels, |                                     conv.out_channels, | ||||||
|                                     kernel_size=conv.kernel_size, |                                     kernel_size=conv.kernel_size, | ||||||
|                                     stride=conv.stride, |                                     stride=conv.stride, | ||||||
|                                     padding=conv.padding, |                                     padding=conv.padding, | ||||||
|             bias=True |                                     bias=True) | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         # prepare filters |         # prepare filters | ||||||
|         w_conv = conv.weight.clone().view(conv.out_channels, -1) |         w_conv = conv.weight.clone().view(conv.out_channels, -1) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue