updates
This commit is contained in:
parent
8b9aae484b
commit
8b88e50f2f
14
detect.py
14
detect.py
|
@ -11,7 +11,7 @@ from utils import torch_utils
|
||||||
def detect(
|
def detect(
|
||||||
net_config_path,
|
net_config_path,
|
||||||
data_config_path,
|
data_config_path,
|
||||||
weights_file_path,
|
weights_path,
|
||||||
images_path,
|
images_path,
|
||||||
output='output',
|
output='output',
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
|
@ -32,14 +32,14 @@ def detect(
|
||||||
# Load model
|
# Load model
|
||||||
model = Darknet(net_config_path, img_size)
|
model = Darknet(net_config_path, img_size)
|
||||||
|
|
||||||
if weights_file_path.endswith('.pt'): # pytorch format
|
if weights_path.endswith('.pt'): # pytorch format
|
||||||
if weights_file_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_file_path):
|
if weights_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_path):
|
||||||
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_file_path)
|
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_path)
|
||||||
checkpoint = torch.load(weights_file_path, map_location='cpu')
|
checkpoint = torch.load(weights_path, map_location='cpu')
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
del checkpoint
|
del checkpoint
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
load_weights(model, weights_file_path)
|
load_darknet_weights(model, weights_path)
|
||||||
|
|
||||||
model.to(device).eval()
|
model.to(device).eval()
|
||||||
|
|
||||||
|
@ -136,8 +136,6 @@ def detect(
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
# Get data configuration
|
|
||||||
|
|
||||||
parser.add_argument('--image-folder', type=str, default='data/samples', help='path to images')
|
parser.add_argument('--image-folder', type=str, default='data/samples', help='path to images')
|
||||||
parser.add_argument('--output-folder', type=str, default='output', help='path to outputs')
|
parser.add_argument('--output-folder', type=str, default='output', help='path to outputs')
|
||||||
parser.add_argument('--plot-flag', type=bool, default=True)
|
parser.add_argument('--plot-flag', type=bool, default=True)
|
||||||
|
|
18
models.py
18
models.py
|
@ -1,5 +1,6 @@
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import os
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
|
@ -333,13 +334,22 @@ class Darknet(nn.Module):
|
||||||
return sum(output) if is_training else torch.cat(output, 1)
|
return sum(output) if is_training else torch.cat(output, 1)
|
||||||
|
|
||||||
|
|
||||||
def load_weights(self, weights_path, cutoff=-1):
|
def load_darknet_weights(self, weights_path, cutoff=-1):
|
||||||
# Parses and loads the weights stored in 'weights_path'
|
# Parses and loads the weights stored in 'weights_path'
|
||||||
# @:param cutoff - save layers between 0 and cutoff (cutoff = -1 -> all are saved)
|
# cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved)
|
||||||
|
weights_file = weights_path.split(os.sep)[-1]
|
||||||
|
|
||||||
if weights_path.endswith('darknet53.conv.74'):
|
# Try to download weights if not available locally
|
||||||
|
if not os.path.isfile(weights_path):
|
||||||
|
try:
|
||||||
|
os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -P ' + weights_path)
|
||||||
|
except:
|
||||||
|
assert os.path.isfile(weights_path)
|
||||||
|
|
||||||
|
# Establish cutoffs
|
||||||
|
if weights_file == 'darknet53.conv.74':
|
||||||
cutoff = 75
|
cutoff = 75
|
||||||
elif weights_path.endswith('yolov3-tiny.conv.15'):
|
elif weights_file == 'yolov3-tiny.conv.15':
|
||||||
cutoff = 16
|
cutoff = 16
|
||||||
|
|
||||||
# Open the weights file
|
# Open the weights file
|
||||||
|
|
8
test.py
8
test.py
|
@ -10,7 +10,7 @@ from utils import torch_utils
|
||||||
def test(
|
def test(
|
||||||
net_config_path,
|
net_config_path,
|
||||||
data_config_path,
|
data_config_path,
|
||||||
weights_file_path,
|
weights_path,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
img_size=416,
|
img_size=416,
|
||||||
iou_thres=0.5,
|
iou_thres=0.5,
|
||||||
|
@ -30,12 +30,12 @@ def test(
|
||||||
model = Darknet(net_config_path, img_size)
|
model = Darknet(net_config_path, img_size)
|
||||||
|
|
||||||
# Load weights
|
# Load weights
|
||||||
if weights_file_path.endswith('.pt'): # pytorch format
|
if weights_path.endswith('.pt'): # pytorch format
|
||||||
checkpoint = torch.load(weights_file_path, map_location='cpu')
|
checkpoint = torch.load(weights_path, map_location='cpu')
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
del checkpoint
|
del checkpoint
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
load_weights(model, weights_file_path)
|
load_darknet_weights(model, weights_path)
|
||||||
|
|
||||||
model.to(device).eval()
|
model.to(device).eval()
|
||||||
|
|
||||||
|
|
11
train.py
11
train.py
|
@ -10,9 +10,6 @@ from utils import torch_utils
|
||||||
# Import test.py to get mAP after each epoch
|
# Import test.py to get mAP after each epoch
|
||||||
import test
|
import test
|
||||||
|
|
||||||
DARKNET_WEIGHTS_FILENAME = 'darknet53.conv.74'
|
|
||||||
DARKNET_WEIGHTS_URL = 'https://pjreddie.com/media/files/{}'.format(DARKNET_WEIGHTS_FILENAME)
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
net_config_path,
|
net_config_path,
|
||||||
|
@ -83,13 +80,7 @@ def train(
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
|
|
||||||
# Initialize model with darknet53 weights (optional)
|
# Initialize model with darknet53 weights (optional)
|
||||||
def_weight_file = os.path.join(weights_path, DARKNET_WEIGHTS_FILENAME)
|
load_darknet_weights(model, os.path.join(weights_path, 'darknet53.conv.74'))
|
||||||
if not os.path.isfile(def_weight_file):
|
|
||||||
os.system('wget {} -P {}'.format(
|
|
||||||
DARKNET_WEIGHTS_URL,
|
|
||||||
weights_path))
|
|
||||||
assert os.path.isfile(def_weight_file)
|
|
||||||
load_weights(model, def_weight_file)
|
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21')
|
raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21')
|
||||||
|
|
Loading…
Reference in New Issue