diff --git a/detect.py b/detect.py index d4fed6b5..ec53c697 100644 --- a/detect.py +++ b/detect.py @@ -13,7 +13,7 @@ def detect(save_txt=False, save_img=False, stream_img=False): streams = 'streams' in source and source.endswith('.txt') # Initialize - device = torch_utils.select_device(force_cpu=ONNX_EXPORT) + device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device) if os.path.exists(out): shutil.rmtree(out) # delete output folder os.makedirs(out) # make new output folder @@ -141,6 +141,7 @@ if __name__ == '__main__': parser.add_argument('--nms-thres', type=float, default=0.5, help='iou threshold for non-maximum suppression') parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)') parser.add_argument('--half', action='store_true', help='half precision FP16 inference') + parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu') opt = parser.parse_args() print(opt) diff --git a/train.py b/train.py index cb77f2d5..55e82af4 100644 --- a/train.py +++ b/train.py @@ -388,7 +388,7 @@ if __name__ == '__main__': parser.add_argument('--arc', type=str, default='defaultpw', help='yolo architecture') # defaultpw, uCE, uBCE parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training') parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') - parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1') + parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu') parser.add_argument('--adam', action='store_true', help='use adam optimizer') parser.add_argument('--var', type=float, help='debug variable') opt = parser.parse_args() diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 4805c23b..6b4b2624 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -1,4 +1,5 @@ import os + import torch @@ -14,9 +15,10 @@ def init_seeds(seed=0): torch.backends.cudnn.benchmark = False -def select_device(device=None, force_cpu=False, apex=False): - # Set environment variable if device is specified - if device: +def select_device(device=None, apex=False): + if device == 'cpu': + force_cpu = True + elif device: # Set environment variable if device is specified os.environ['CUDA_VISIBLE_DEVICES'] = device # apex if mixed precision training https://github.com/NVIDIA/apex