diff --git a/detect.py b/detect.py index 2e9ed61d..3e74c6aa 100755 --- a/detect.py +++ b/detect.py @@ -7,6 +7,7 @@ from utils.utils import * from utils import torch_utils + def detect( net_config_path, data_config_path, @@ -68,7 +69,8 @@ def detect( # cv2.imwrite('zidane_416.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # letterboxed img = torch.from_numpy(img).unsqueeze(0).to(device) if ONNX_EXPORT: - pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True); return # ONNX export + pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True); + return # ONNX export pred = model(img) pred = pred[pred[:, :, 4] > conf_thres] @@ -90,18 +92,17 @@ def detect( for img_i, (path, detections) in enumerate(zip(imgs, img_detections)): print("image %g: '%s'" % (img_i, path)) - if save_images: - img = cv2.imread(path) - - # The amount of padding that was added - pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape)) - pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape)) - # Image height and width after padding is removed - unpad_h = img_size - pad_y - unpad_w = img_size - pad_x - # Draw bounding boxes and labels of detections if detections is not None: + img = cv2.imread(path) + + # The amount of padding that was added + pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape)) + pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape)) + # Image height and width after padding is removed + unpad_h = img_size - pad_y + unpad_w = img_size - pad_x + unique_classes = detections[:, -1].cpu().unique() bbox_colors = random.sample(color_list, len(unique_classes)) @@ -136,9 +137,9 @@ def detect( color = bbox_colors[int(np.where(unique_classes == int(cls_pred))[0])] plot_one_box([x1, y1, x2, y2], img, label=label, color=color) - if save_images: - # Save generated image with detections - cv2.imwrite(results_img_path.replace('.bmp', '.jpg').replace('.tif', '.jpg'), img) + if save_images: + # Save generated image with detections + cv2.imwrite(results_img_path.replace('.bmp', '.jpg').replace('.tif', '.jpg'), img) if platform == 'darwin': # MacOS (local) os.system('open ' + output)