This commit is contained in:
Glenn Jocher 2019-09-13 16:00:52 +02:00
parent 5452bb7036
commit 4286bba40f
3 changed files with 8 additions and 5 deletions

View File

@ -13,7 +13,7 @@ def detect(save_txt=False, save_img=False, stream_img=False):
streams = 'streams' in source and source.endswith('.txt')
# Initialize
device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device)
if os.path.exists(out):
shutil.rmtree(out) # delete output folder
os.makedirs(out) # make new output folder
@ -141,6 +141,7 @@ if __name__ == '__main__':
parser.add_argument('--nms-thres', type=float, default=0.5, help='iou threshold for non-maximum suppression')
parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)')
parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
opt = parser.parse_args()
print(opt)

View File

@ -388,7 +388,7 @@ if __name__ == '__main__':
parser.add_argument('--arc', type=str, default='defaultpw', help='yolo architecture') # defaultpw, uCE, uBCE
parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training')
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
parser.add_argument('--var', type=float, help='debug variable')
opt = parser.parse_args()

View File

@ -1,4 +1,5 @@
import os
import torch
@ -14,9 +15,10 @@ def init_seeds(seed=0):
torch.backends.cudnn.benchmark = False
def select_device(device=None, force_cpu=False, apex=False):
# Set environment variable if device is specified
if device:
def select_device(device=None, apex=False):
if device == 'cpu':
force_cpu = True
elif device: # Set environment variable if device is specified
os.environ['CUDA_VISIBLE_DEVICES'] = device
# apex if mixed precision training https://github.com/NVIDIA/apex