updates
This commit is contained in:
		
							parent
							
								
									8b9aae484b
								
							
						
					
					
						commit
						8b88e50f2f
					
				
							
								
								
									
										14
									
								
								detect.py
								
								
								
								
							
							
						
						
									
										14
									
								
								detect.py
								
								
								
								
							|  | @ -11,7 +11,7 @@ from utils import torch_utils | |||
| def detect( | ||||
|         net_config_path, | ||||
|         data_config_path, | ||||
|         weights_file_path, | ||||
|         weights_path, | ||||
|         images_path, | ||||
|         output='output', | ||||
|         batch_size=16, | ||||
|  | @ -32,14 +32,14 @@ def detect( | |||
|     # Load model | ||||
|     model = Darknet(net_config_path, img_size) | ||||
| 
 | ||||
|     if weights_file_path.endswith('.pt'):  # pytorch format | ||||
|         if weights_file_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_file_path): | ||||
|             os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_file_path) | ||||
|         checkpoint = torch.load(weights_file_path, map_location='cpu') | ||||
|     if weights_path.endswith('.pt'):  # pytorch format | ||||
|         if weights_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_path): | ||||
|             os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_path) | ||||
|         checkpoint = torch.load(weights_path, map_location='cpu') | ||||
|         model.load_state_dict(checkpoint['model']) | ||||
|         del checkpoint | ||||
|     else:  # darknet format | ||||
|         load_weights(model, weights_file_path) | ||||
|         load_darknet_weights(model, weights_path) | ||||
| 
 | ||||
|     model.to(device).eval() | ||||
| 
 | ||||
|  | @ -136,8 +136,6 @@ def detect( | |||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser() | ||||
|     # Get data configuration | ||||
| 
 | ||||
|     parser.add_argument('--image-folder', type=str, default='data/samples', help='path to images') | ||||
|     parser.add_argument('--output-folder', type=str, default='output', help='path to outputs') | ||||
|     parser.add_argument('--plot-flag', type=bool, default=True) | ||||
|  |  | |||
							
								
								
									
										18
									
								
								models.py
								
								
								
								
							
							
						
						
									
										18
									
								
								models.py
								
								
								
								
							|  | @ -1,5 +1,6 @@ | |||
| from collections import defaultdict | ||||
| 
 | ||||
| import os | ||||
| import torch.nn as nn | ||||
| 
 | ||||
| from utils.parse_config import * | ||||
|  | @ -333,13 +334,22 @@ class Darknet(nn.Module): | |||
|         return sum(output) if is_training else torch.cat(output, 1) | ||||
| 
 | ||||
| 
 | ||||
| def load_weights(self, weights_path, cutoff=-1): | ||||
| def load_darknet_weights(self, weights_path, cutoff=-1): | ||||
|     # Parses and loads the weights stored in 'weights_path' | ||||
|     # @:param cutoff  - save layers between 0 and cutoff (cutoff = -1 -> all are saved) | ||||
|     # cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved) | ||||
|     weights_file = weights_path.split(os.sep)[-1] | ||||
| 
 | ||||
|     if weights_path.endswith('darknet53.conv.74'): | ||||
|     # Try to download weights if not available locally | ||||
|     if not os.path.isfile(weights_path): | ||||
|         try: | ||||
|             os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -P ' + weights_path) | ||||
|         except: | ||||
|             assert os.path.isfile(weights_path) | ||||
| 
 | ||||
|     # Establish cutoffs | ||||
|     if weights_file == 'darknet53.conv.74': | ||||
|         cutoff = 75 | ||||
|     elif weights_path.endswith('yolov3-tiny.conv.15'): | ||||
|     elif weights_file == 'yolov3-tiny.conv.15': | ||||
|         cutoff = 16 | ||||
| 
 | ||||
|     # Open the weights file | ||||
|  |  | |||
							
								
								
									
										8
									
								
								test.py
								
								
								
								
							
							
						
						
									
										8
									
								
								test.py
								
								
								
								
							|  | @ -10,7 +10,7 @@ from utils import torch_utils | |||
| def test( | ||||
|         net_config_path, | ||||
|         data_config_path, | ||||
|         weights_file_path, | ||||
|         weights_path, | ||||
|         batch_size=16, | ||||
|         img_size=416, | ||||
|         iou_thres=0.5, | ||||
|  | @ -30,12 +30,12 @@ def test( | |||
|     model = Darknet(net_config_path, img_size) | ||||
| 
 | ||||
|     # Load weights | ||||
|     if weights_file_path.endswith('.pt'):  # pytorch format | ||||
|         checkpoint = torch.load(weights_file_path, map_location='cpu') | ||||
|     if weights_path.endswith('.pt'):  # pytorch format | ||||
|         checkpoint = torch.load(weights_path, map_location='cpu') | ||||
|         model.load_state_dict(checkpoint['model']) | ||||
|         del checkpoint | ||||
|     else:  # darknet format | ||||
|         load_weights(model, weights_file_path) | ||||
|         load_darknet_weights(model, weights_path) | ||||
| 
 | ||||
|     model.to(device).eval() | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										11
									
								
								train.py
								
								
								
								
							
							
						
						
									
										11
									
								
								train.py
								
								
								
								
							|  | @ -10,9 +10,6 @@ from utils import torch_utils | |||
| # Import test.py to get mAP after each epoch | ||||
| import test | ||||
| 
 | ||||
| DARKNET_WEIGHTS_FILENAME = 'darknet53.conv.74' | ||||
| DARKNET_WEIGHTS_URL = 'https://pjreddie.com/media/files/{}'.format(DARKNET_WEIGHTS_FILENAME) | ||||
| 
 | ||||
| 
 | ||||
| def train( | ||||
|         net_config_path, | ||||
|  | @ -83,13 +80,7 @@ def train( | |||
|         best_loss = float('inf') | ||||
| 
 | ||||
|         # Initialize model with darknet53 weights (optional) | ||||
|         def_weight_file = os.path.join(weights_path, DARKNET_WEIGHTS_FILENAME) | ||||
|         if not os.path.isfile(def_weight_file): | ||||
|             os.system('wget {} -P {}'.format( | ||||
|                 DARKNET_WEIGHTS_URL, | ||||
|                 weights_path)) | ||||
|         assert os.path.isfile(def_weight_file) | ||||
|         load_weights(model, def_weight_file) | ||||
|         load_darknet_weights(model, os.path.join(weights_path, 'darknet53.conv.74')) | ||||
| 
 | ||||
|         if torch.cuda.device_count() > 1: | ||||
|             raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21') | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue