updates
This commit is contained in:
		
							parent
							
								
									870020ed15
								
							
						
					
					
						commit
						c24702941f
					
				
							
								
								
									
										54
									
								
								models.py
								
								
								
								
							
							
						
						
									
										54
									
								
								models.py
								
								
								
								
							|  | @ -291,27 +291,9 @@ def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float | ||||||
| 
 | 
 | ||||||
| def load_darknet_weights(self, weights, cutoff=-1): | def load_darknet_weights(self, weights, cutoff=-1): | ||||||
|     # Parses and loads the weights stored in 'weights' |     # Parses and loads the weights stored in 'weights' | ||||||
|     # cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved) | 
 | ||||||
|  |     # Establish cutoffs (load layers between 0 and cutoff. if cutoff = -1 all are loaded) | ||||||
|     file = Path(weights).name |     file = Path(weights).name | ||||||
| 
 |  | ||||||
|     # Try to download weights if not available locally |  | ||||||
|     msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI' |  | ||||||
|     if not os.path.isfile(weights): |  | ||||||
|         if file == 'yolov3-spp.weights': |  | ||||||
|             gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights) |  | ||||||
|         elif file == 'darknet53.conv.74': |  | ||||||
|             gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights) |  | ||||||
|         else: |  | ||||||
|             try:  # download from pjreddie.com |  | ||||||
|                 url = 'https://pjreddie.com/media/files/' + file |  | ||||||
|                 print('Downloading ' + url) |  | ||||||
|                 os.system('curl -f ' + url + ' -o ' + weights) |  | ||||||
|             except IOError: |  | ||||||
|                 print(msg) |  | ||||||
|                 os.system('rm ' + weights)  # remove partial downloads |  | ||||||
|     assert os.path.exists(weights), msg  # download missing weights from Google Drive |  | ||||||
| 
 |  | ||||||
|     # Establish cutoffs |  | ||||||
|     if file == 'darknet53.conv.74': |     if file == 'darknet53.conv.74': | ||||||
|         cutoff = 75 |         cutoff = 75 | ||||||
|     elif file == 'yolov3-tiny.conv.15': |     elif file == 'yolov3-tiny.conv.15': | ||||||
|  | @ -417,3 +399,35 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'): | ||||||
| 
 | 
 | ||||||
|     else: |     else: | ||||||
|         print('Error: extension not supported.') |         print('Error: extension not supported.') | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def attempt_download(weights): | ||||||
|  |     # Attempt to download pretrained weights if not found locally | ||||||
|  | 
 | ||||||
|  |     msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI' | ||||||
|  |     if not os.path.isfile(weights): | ||||||
|  |         file = Path(weights).name | ||||||
|  | 
 | ||||||
|  |         if file == 'yolov3-spp.weights': | ||||||
|  |             gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights) | ||||||
|  |         elif file == 'yolov3-spp.pt': | ||||||
|  |             gdrive_download(id='1vFlbJ_dXPvtwaLLOu-twnjK4exdFiQ73', name=weights) | ||||||
|  |         elif file == 'yolov3.pt': | ||||||
|  |             gdrive_download(id='11uy0ybbOXA2hc-NJkJbbbkDwNX1QZDlz', name=weights) | ||||||
|  |         elif file == 'yolov3-tiny.pt': | ||||||
|  |             gdrive_download(id='1qKSgejNeNczgNNiCn9ZF_o55GFk1DjY_', name=weights) | ||||||
|  |         elif file == 'darknet53.conv.74': | ||||||
|  |             gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights) | ||||||
|  |         elif file == 'yolov3-tiny.conv.15': | ||||||
|  |             gdrive_download(id='140PnSedCsGGgu3rOD6Ez4oI6cdDzerLC', name=weights) | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             try:  # download from pjreddie.com | ||||||
|  |                 url = 'https://pjreddie.com/media/files/' + file | ||||||
|  |                 print('Downloading ' + url) | ||||||
|  |                 os.system('curl -f ' + url + ' -o ' + weights) | ||||||
|  |             except IOError: | ||||||
|  |                 print(msg) | ||||||
|  |                 os.system('rm ' + weights)  # remove partial downloads | ||||||
|  | 
 | ||||||
|  |     assert os.path.exists(weights), msg  # download missing weights from Google Drive | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								test.py
								
								
								
								
							
							
						
						
									
										1
									
								
								test.py
								
								
								
								
							|  | @ -27,6 +27,7 @@ def test(cfg, | ||||||
|         model = Darknet(cfg, img_size).to(device) |         model = Darknet(cfg, img_size).to(device) | ||||||
| 
 | 
 | ||||||
|         # Load weights |         # Load weights | ||||||
|  |         attempt_download(weights) | ||||||
|         if weights.endswith('.pt'):  # pytorch format |         if weights.endswith('.pt'):  # pytorch format | ||||||
|             model.load_state_dict(torch.load(weights, map_location=device)['model']) |             model.load_state_dict(torch.load(weights, map_location=device)['model']) | ||||||
|         else:  # darknet format |         else:  # darknet format | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								train.py
								
								
								
								
							
							
						
						
									
										1
									
								
								train.py
								
								
								
								
							|  | @ -100,6 +100,7 @@ def train(): | ||||||
|     cutoff = -1  # backbone reaches to cutoff layer |     cutoff = -1  # backbone reaches to cutoff layer | ||||||
|     start_epoch = 0 |     start_epoch = 0 | ||||||
|     best_fitness = 0. |     best_fitness = 0. | ||||||
|  |     attempt_download(weights) | ||||||
|     if weights.endswith('.pt'):  # pytorch format |     if weights.endswith('.pt'):  # pytorch format | ||||||
|         # possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc. |         # possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc. | ||||||
|         if opt.bucket: |         if opt.bucket: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue