This commit is contained in:
Glenn Jocher 2019-01-08 19:37:23 +01:00
parent fcda9a2fa0
commit acfe4aaf94
1 changed files with 15 additions and 14 deletions

View File

@ -7,6 +7,7 @@ from utils.utils import *
from utils import torch_utils from utils import torch_utils
def detect( def detect(
net_config_path, net_config_path,
data_config_path, data_config_path,
@ -68,7 +69,8 @@ def detect(
# cv2.imwrite('zidane_416.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # letterboxed # cv2.imwrite('zidane_416.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # letterboxed
img = torch.from_numpy(img).unsqueeze(0).to(device) img = torch.from_numpy(img).unsqueeze(0).to(device)
if ONNX_EXPORT: if ONNX_EXPORT:
pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True); return # ONNX export pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True);
return # ONNX export
pred = model(img) pred = model(img)
pred = pred[pred[:, :, 4] > conf_thres] pred = pred[pred[:, :, 4] > conf_thres]
@ -90,7 +92,8 @@ def detect(
for img_i, (path, detections) in enumerate(zip(imgs, img_detections)): for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):
print("image %g: '%s'" % (img_i, path)) print("image %g: '%s'" % (img_i, path))
if save_images: # Draw bounding boxes and labels of detections
if detections is not None:
img = cv2.imread(path) img = cv2.imread(path)
# The amount of padding that was added # The amount of padding that was added
@ -100,8 +103,6 @@ def detect(
unpad_h = img_size - pad_y unpad_h = img_size - pad_y
unpad_w = img_size - pad_x unpad_w = img_size - pad_x
# Draw bounding boxes and labels of detections
if detections is not None:
unique_classes = detections[:, -1].cpu().unique() unique_classes = detections[:, -1].cpu().unique()
bbox_colors = random.sample(color_list, len(unique_classes)) bbox_colors = random.sample(color_list, len(unique_classes))