diff --git a/detect.py b/detect.py index ce1389b2..39ce30f4 100755 --- a/detect.py +++ b/detect.py @@ -11,7 +11,7 @@ from utils import torch_utils def detect( net_config_path, data_config_path, - weights_file_path, + weights_path, images_path, output='output', batch_size=16, @@ -32,14 +32,14 @@ def detect( # Load model model = Darknet(net_config_path, img_size) - if weights_file_path.endswith('.pt'): # pytorch format - if weights_file_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_file_path): - os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_file_path) - checkpoint = torch.load(weights_file_path, map_location='cpu') + if weights_path.endswith('.pt'): # pytorch format + if weights_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_path): + os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_path) + checkpoint = torch.load(weights_path, map_location='cpu') model.load_state_dict(checkpoint['model']) del checkpoint else: # darknet format - load_weights(model, weights_file_path) + load_darknet_weights(model, weights_path) model.to(device).eval() @@ -136,8 +136,6 @@ def detect( if __name__ == '__main__': parser = argparse.ArgumentParser() - # Get data configuration - parser.add_argument('--image-folder', type=str, default='data/samples', help='path to images') parser.add_argument('--output-folder', type=str, default='output', help='path to outputs') parser.add_argument('--plot-flag', type=bool, default=True) diff --git a/models.py b/models.py index 8090b903..34c573e3 100755 --- a/models.py +++ b/models.py @@ -1,5 +1,6 @@ from collections import defaultdict +import os import torch.nn as nn from utils.parse_config import * @@ -333,13 +334,22 @@ class Darknet(nn.Module): return sum(output) if is_training else torch.cat(output, 1) -def load_weights(self, weights_path, cutoff=-1): +def load_darknet_weights(self, weights_path, cutoff=-1): # Parses and loads the weights stored in 'weights_path' - # @:param cutoff - save layers between 0 and cutoff (cutoff = -1 -> all are saved) + # cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved) + weights_file = weights_path.split(os.sep)[-1] - if weights_path.endswith('darknet53.conv.74'): + # Try to download weights if not available locally + if not os.path.isfile(weights_path): + try: + os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -P ' + weights_path) + except: + assert os.path.isfile(weights_path) + + # Establish cutoffs + if weights_file == 'darknet53.conv.74': cutoff = 75 - elif weights_path.endswith('yolov3-tiny.conv.15'): + elif weights_file == 'yolov3-tiny.conv.15': cutoff = 16 # Open the weights file diff --git a/test.py b/test.py index 51019b5f..fafd816d 100644 --- a/test.py +++ b/test.py @@ -10,7 +10,7 @@ from utils import torch_utils def test( net_config_path, data_config_path, - weights_file_path, + weights_path, batch_size=16, img_size=416, iou_thres=0.5, @@ -30,12 +30,12 @@ def test( model = Darknet(net_config_path, img_size) # Load weights - if weights_file_path.endswith('.pt'): # pytorch format - checkpoint = torch.load(weights_file_path, map_location='cpu') + if weights_path.endswith('.pt'): # pytorch format + checkpoint = torch.load(weights_path, map_location='cpu') model.load_state_dict(checkpoint['model']) del checkpoint else: # darknet format - load_weights(model, weights_file_path) + load_darknet_weights(model, weights_path) model.to(device).eval() diff --git a/train.py b/train.py index 9232139d..6ab18406 100644 --- a/train.py +++ b/train.py @@ -10,9 +10,6 @@ from utils import torch_utils # Import test.py to get mAP after each epoch import test -DARKNET_WEIGHTS_FILENAME = 'darknet53.conv.74' -DARKNET_WEIGHTS_URL = 'https://pjreddie.com/media/files/{}'.format(DARKNET_WEIGHTS_FILENAME) - def train( net_config_path, @@ -83,13 +80,7 @@ def train( best_loss = float('inf') # Initialize model with darknet53 weights (optional) - def_weight_file = os.path.join(weights_path, DARKNET_WEIGHTS_FILENAME) - if not os.path.isfile(def_weight_file): - os.system('wget {} -P {}'.format( - DARKNET_WEIGHTS_URL, - weights_path)) - assert os.path.isfile(def_weight_file) - load_weights(model, def_weight_file) + load_darknet_weights(model, os.path.join(weights_path, 'darknet53.conv.74')) if torch.cuda.device_count() > 1: raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21')