Extract seed and cuda initialization utils
This commit is contained in:
parent
45ee668fd7
commit
5a566454f5
|
@ -5,8 +5,6 @@ from models import *
|
|||
from utils.datasets import *
|
||||
from utils.utils import *
|
||||
|
||||
cuda = torch.cuda.is_available()
|
||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
||||
f_path = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -28,6 +26,10 @@ print(opt)
|
|||
|
||||
|
||||
def main(opt):
|
||||
|
||||
device = torch_utils.select_device()
|
||||
print("Using device: \"{}\"".format(device))
|
||||
|
||||
os.system('rm -rf ' + opt.output_folder)
|
||||
os.makedirs(opt.output_folder, exist_ok=True)
|
||||
|
||||
|
|
11
test.py
11
test.py
|
@ -4,6 +4,8 @@ from models import *
|
|||
from utils.datasets import *
|
||||
from utils.utils import *
|
||||
|
||||
from utils import torch_utils
|
||||
|
||||
parser = argparse.ArgumentParser(prog='test.py')
|
||||
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
|
||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
||||
|
@ -18,11 +20,11 @@ parser.add_argument('-img_size', type=int, default=416, help='size of each image
|
|||
opt = parser.parse_args()
|
||||
print(opt, end='\n\n')
|
||||
|
||||
cuda = torch.cuda.is_available()
|
||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
||||
|
||||
|
||||
def main(opt):
|
||||
device = torch_utils.select_device()
|
||||
print("Using device: \"{}\"".format(device))
|
||||
|
||||
# Configure run
|
||||
data_config = parse_data_config(opt.data_config_path)
|
||||
nC = int(data_config['classes']) # number of classes (80 for COCO)
|
||||
|
@ -128,4 +130,7 @@ def main(opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
init_seeds()
|
||||
|
||||
mAP = main(opt)
|
||||
|
|
20
train.py
20
train.py
|
@ -6,6 +6,8 @@ from models import *
|
|||
from utils.datasets import *
|
||||
from utils.utils import *
|
||||
|
||||
from utils import torch_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-epochs', type=int, default=100, help='number of epochs')
|
||||
parser.add_argument('-batch_size', type=int, default=16, help='size of each image batch')
|
||||
|
@ -26,20 +28,15 @@ print(opt)
|
|||
sys.argv[1:] = [] # delete any train.py command-line arguments before they reach test.py
|
||||
import test # must follow sys.argv[1:] = []
|
||||
|
||||
cuda = torch.cuda.is_available()
|
||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
||||
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
if cuda:
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.cuda.manual_seed_all(0)
|
||||
def main(opt):
|
||||
|
||||
device = torch_utils.select_device()
|
||||
print("Using device: \"{}\"".format(device))
|
||||
|
||||
if not opt.multi_scale:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def main(opt):
|
||||
os.makedirs('weights', exist_ok=True)
|
||||
|
||||
# Configure run
|
||||
|
@ -217,5 +214,8 @@ def main(opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
init_seeds()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
main(opt)
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
import torch
|
||||
|
||||
|
||||
def check_cuda():
|
||||
return torch.cuda.is_available()
|
||||
|
||||
|
||||
CUDA_AVAILABLE = check_cuda()
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
torch.manual_seed(seed)
|
||||
if CUDA_AVAILABLE:
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def select_device(force_cpu=False):
|
||||
if force_cpu:
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu')
|
||||
return device
|
|
@ -5,11 +5,19 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils import torch_utils
|
||||
|
||||
# Set printoptions
|
||||
torch.set_printoptions(linewidth=1320, precision=5, profile='long')
|
||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch_utils.init_seeds(seed=seed)
|
||||
|
||||
|
||||
def load_classes(path):
|
||||
"""
|
||||
Loads class labels at 'path'
|
||||
|
|
Loading…
Reference in New Issue