From 97909df1a600e0f950aa36cc6c44cedbed2ab3fc Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 10 Feb 2019 21:06:22 +0100 Subject: [PATCH] updates --- detect.py | 81 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/detect.py b/detect.py index a707b9e0..36727f8f 100755 --- a/detect.py +++ b/detect.py @@ -8,10 +8,18 @@ from utils.utils import * from utils import torch_utils -def detect(cfg, weights, images, output='output', img_size=416, conf_thres=0.3, nms_thres=0.45, save_txt=False, - save_images=True): +def detect( + cfg, + weights, + images, + output='output', + img_size=416, + conf_thres=0.3, + nms_thres=0.45, + save_txt=False, + save_images=True +): device = torch_utils.select_device() - os.system('rm -rf ' + output) os.makedirs(output, exist_ok=True) @@ -39,43 +47,42 @@ def detect(cfg, weights, images, output='output', img_size=416, conf_thres=0.3, t = time.time() # Get detections - with torch.no_grad(): - img = torch.from_numpy(img).unsqueeze(0).to(device) - if ONNX_EXPORT: - pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True) - return # ONNX export - pred = model(img) - pred = pred[pred[:, :, 4] > conf_thres] + img = torch.from_numpy(img).unsqueeze(0).to(device) + if ONNX_EXPORT: + pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True) + return # ONNX export + pred = model(img) + pred = pred[pred[:, :, 4] > conf_thres] - if len(pred) > 0: - detections = non_max_suppression(pred.unsqueeze(0), conf_thres, nms_thres)[0] + if len(pred) > 0: + detections = non_max_suppression(pred.unsqueeze(0), conf_thres, nms_thres)[0] - # Draw bounding boxes and labels of detections - if detections is not None: - save_img_path = os.path.join(output, path.split('/')[-1]) - save_txt_path = save_img_path + '.txt' + # Draw bounding boxes and labels of detections + if detections is not None: + save_img_path = os.path.join(output, path.split('/')[-1]) + save_txt_path = save_img_path + '.txt' - # Rescale boxes from 416 to true image size - detections[:, :4] = scale_coords(img_size, detections[:, :4], im0.shape) + # Rescale boxes from 416 to true image size + detections[:, :4] = scale_coords(img_size, detections[:, :4], im0.shape) - unique_classes = detections[:, -1].cpu().unique() - for i in unique_classes: - n = (detections[:, -1].cpu() == i).sum() - print('%g %ss' % (n, classes[int(i)]), end=', ') + unique_classes = detections[:, -1].cpu().unique() + for i in unique_classes: + n = (detections[:, -1].cpu() == i).sum() + print('%g %ss' % (n, classes[int(i)]), end=', ') - for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections: - if save_txt: # Write to file - with open(save_txt_path, 'a') as file: - file.write('%g %g %g %g %g %g\n' % (x1, y1, x2, y2, cls_pred, cls_conf * conf)) + for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections: + if save_txt: # Write to file + with open(save_txt_path, 'a') as file: + file.write('%g %g %g %g %g %g\n' % (x1, y1, x2, y2, cls_pred, cls_conf * conf)) - if save_images: # Add bbox to the image - label = '%s %.2f' % (classes[int(cls_pred)], conf) - plot_one_box([x1, y1, x2, y2], im0, label=label, color=colors[int(cls_pred)]) + if save_images: # Add bbox to the image + label = '%s %.2f' % (classes[int(cls_pred)], conf) + plot_one_box([x1, y1, x2, y2], im0, label=label, color=colors[int(cls_pred)]) - if save_images: # Save generated image with detections - cv2.imwrite(save_img_path, im0) + if save_images: # Save generated image with detections + cv2.imwrite(save_img_path, im0) - print(' Done. (%.3fs)' % (time.time() - t)) + print(' Done. (%.3fs)' % (time.time() - t)) if platform == 'darwin': # MacOS os.system('open ' + output + '&& open ' + save_img_path) @@ -92,4 +99,12 @@ if __name__ == '__main__': opt = parser.parse_args() print(opt) - detect(opt.cfg, opt.weights, opt.images, img_size=opt.img_size, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres) + with torch.no_grad(): + detect( + opt.cfg, + opt.weights, + opt.images, + img_size=opt.img_size, + conf_thres=opt.conf_thres, + nms_thres=opt.nms_thres + )