check_git_status() to train.py
This commit is contained in:
		
							parent
							
								
									2f636d5740
								
							
						
					
					
						commit
						77b3829d56
					
				
							
								
								
									
										3
									
								
								train.py
								
								
								
								
							
							
						
						
									
										3
									
								
								train.py
								
								
								
								
							|  | @ -23,7 +23,6 @@ best = wdir + 'best.pt' | |||
| results_file = 'results.txt' | ||||
| 
 | ||||
| # Hyperparameters https://github.com/ultralytics/yolov3/issues/310 | ||||
| 
 | ||||
| hyp = {'giou': 3.54,  # giou loss gain | ||||
|        'cls': 37.4,  # cls loss gain | ||||
|        'cls_pw': 1.0,  # cls BCELoss positive_weight | ||||
|  | @ -54,7 +53,6 @@ if f: | |||
| if hyp['fl_gamma']: | ||||
|     print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma']) | ||||
| 
 | ||||
| 
 | ||||
| def train(): | ||||
|     cfg = opt.cfg | ||||
|     data = opt.data | ||||
|  | @ -408,6 +406,7 @@ if __name__ == '__main__': | |||
|     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') | ||||
|     opt = parser.parse_args() | ||||
|     opt.weights = last if opt.resume else opt.weights | ||||
|     check_git_status() | ||||
|     print(opt) | ||||
|     opt.img_size.extend([opt.img_size[-1]] * (3 - len(opt.img_size)))  # extend to 3 sizes (min, max, test) | ||||
|     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size) | ||||
|  |  | |||
|  | @ -17,16 +17,10 @@ from tqdm import tqdm | |||
| 
 | ||||
| from . import torch_utils  # , google_utils | ||||
| 
 | ||||
| matplotlib.rc('font', **{'size': 11}) | ||||
| 
 | ||||
| # Suggest 'git pull' | ||||
| s = subprocess.check_output('if [ -d .git ]; then git status -uno; fi', shell=True).decode('utf-8') | ||||
| if 'Your branch is behind' in s: | ||||
|     print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') | ||||
| 
 | ||||
| # Set printoptions | ||||
| torch.set_printoptions(linewidth=320, precision=5, profile='long') | ||||
| np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5 | ||||
| matplotlib.rc('font', **{'size': 11}) | ||||
| 
 | ||||
| # Prevent OpenCV from multithreading (to use PyTorch DataLoader) | ||||
| cv2.setNumThreads(0) | ||||
|  | @ -38,6 +32,13 @@ def init_seeds(seed=0): | |||
|     torch_utils.init_seeds(seed=seed) | ||||
| 
 | ||||
| 
 | ||||
| def check_git_status(): | ||||
|     # Suggest 'git pull' if repo is out of date | ||||
|     s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8') | ||||
|     if 'Your branch is behind' in s: | ||||
|         print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') | ||||
| 
 | ||||
| 
 | ||||
| def load_classes(path): | ||||
|     # Loads *.names file at 'path' | ||||
|     with open(path, 'r') as f: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue