Extract seed and cuda initialization utils
This commit is contained in:
parent
45ee668fd7
commit
5a566454f5
|
@ -5,8 +5,6 @@ from models import *
|
||||||
from utils.datasets import *
|
from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
cuda = torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
f_path = os.path.dirname(os.path.realpath(__file__)) + '/'
|
f_path = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -28,6 +26,10 @@ print(opt)
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
|
|
||||||
|
device = torch_utils.select_device()
|
||||||
|
print("Using device: \"{}\"".format(device))
|
||||||
|
|
||||||
os.system('rm -rf ' + opt.output_folder)
|
os.system('rm -rf ' + opt.output_folder)
|
||||||
os.makedirs(opt.output_folder, exist_ok=True)
|
os.makedirs(opt.output_folder, exist_ok=True)
|
||||||
|
|
||||||
|
|
11
test.py
11
test.py
|
@ -4,6 +4,8 @@ from models import *
|
||||||
from utils.datasets import *
|
from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
|
from utils import torch_utils
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(prog='test.py')
|
parser = argparse.ArgumentParser(prog='test.py')
|
||||||
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
|
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
|
||||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
||||||
|
@ -18,11 +20,11 @@ parser.add_argument('-img_size', type=int, default=416, help='size of each image
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt, end='\n\n')
|
print(opt, end='\n\n')
|
||||||
|
|
||||||
cuda = torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
|
device = torch_utils.select_device()
|
||||||
|
print("Using device: \"{}\"".format(device))
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
data_config = parse_data_config(opt.data_config_path)
|
data_config = parse_data_config(opt.data_config_path)
|
||||||
nC = int(data_config['classes']) # number of classes (80 for COCO)
|
nC = int(data_config['classes']) # number of classes (80 for COCO)
|
||||||
|
@ -128,4 +130,7 @@ def main(opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
init_seeds()
|
||||||
|
|
||||||
mAP = main(opt)
|
mAP = main(opt)
|
||||||
|
|
20
train.py
20
train.py
|
@ -6,6 +6,8 @@ from models import *
|
||||||
from utils.datasets import *
|
from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
|
from utils import torch_utils
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-epochs', type=int, default=100, help='number of epochs')
|
parser.add_argument('-epochs', type=int, default=100, help='number of epochs')
|
||||||
parser.add_argument('-batch_size', type=int, default=16, help='size of each image batch')
|
parser.add_argument('-batch_size', type=int, default=16, help='size of each image batch')
|
||||||
|
@ -26,20 +28,15 @@ print(opt)
|
||||||
sys.argv[1:] = [] # delete any train.py command-line arguments before they reach test.py
|
sys.argv[1:] = [] # delete any train.py command-line arguments before they reach test.py
|
||||||
import test # must follow sys.argv[1:] = []
|
import test # must follow sys.argv[1:] = []
|
||||||
|
|
||||||
cuda = torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
|
||||||
|
|
||||||
random.seed(0)
|
def main(opt):
|
||||||
np.random.seed(0)
|
|
||||||
torch.manual_seed(0)
|
device = torch_utils.select_device()
|
||||||
if cuda:
|
print("Using device: \"{}\"".format(device))
|
||||||
torch.cuda.manual_seed(0)
|
|
||||||
torch.cuda.manual_seed_all(0)
|
|
||||||
if not opt.multi_scale:
|
if not opt.multi_scale:
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
|
||||||
os.makedirs('weights', exist_ok=True)
|
os.makedirs('weights', exist_ok=True)
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
|
@ -217,5 +214,8 @@ def main(opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
init_seeds()
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
main(opt)
|
main(opt)
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def check_cuda():
|
||||||
|
return torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
|
CUDA_AVAILABLE = check_cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def init_seeds(seed=0):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if CUDA_AVAILABLE:
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def select_device(force_cpu=False):
|
||||||
|
if force_cpu:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
else:
|
||||||
|
device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu')
|
||||||
|
return device
|
|
@ -5,11 +5,19 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from utils import torch_utils
|
||||||
|
|
||||||
# Set printoptions
|
# Set printoptions
|
||||||
torch.set_printoptions(linewidth=1320, precision=5, profile='long')
|
torch.set_printoptions(linewidth=1320, precision=5, profile='long')
|
||||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||||
|
|
||||||
|
|
||||||
|
def init_seeds(seed=0):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch_utils.init_seeds(seed=seed)
|
||||||
|
|
||||||
|
|
||||||
def load_classes(path):
|
def load_classes(path):
|
||||||
"""
|
"""
|
||||||
Loads class labels at 'path'
|
Loads class labels at 'path'
|
||||||
|
|
Loading…
Reference in New Issue