Skip to content

Commit

Permalink
Add option to use query lists from options in TAPQueryRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
stvoutsin authored and rra committed Dec 14, 2023
1 parent 8833179 commit 70b4760
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 24 deletions.
11 changes: 8 additions & 3 deletions src/mobu/models/business/tapqueryrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@ class TAPQueryRunnerOptions(BusinessOptions):
example=True,
)

queries: list | None = Field(
None,
title="Which query list to use for the TapQueryRunner",
description="List of queries to be run instead of a query_set",
example=True,
)


class TAPQueryRunnerConfig(BusinessConfig):
"""Configuration specialization for TAPQueryRunner."""

type: Literal["TAPQueryRunner"] = Field(
..., title="Type of business to run"
)
type: Literal["TAPQueryRunner"] = Field(..., title="Type of business to run")

options: TAPQueryRunnerOptions = Field(
default_factory=TAPQueryRunnerOptions,
Expand Down
203 changes: 186 additions & 17 deletions src/mobu/services/business/tapqueryrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import asyncio
import importlib.resources
import math
from typing import Any, Protocol, Union, runtime_checkable
from concurrent.futures import ThreadPoolExecutor
from random import SystemRandom

from enum import Enum
import jinja2
import pyvo
import requests
Expand All @@ -26,6 +27,95 @@
from .base import Business


@runtime_checkable
class TAPQueryContext(Protocol):
"""Query Context Protocol
Defines the methods that should be implemented for various query context implementations.
Query context: Where/how the collection of queries to be run is generated from.
"""

def __init__(self, **_arg: Any) -> None:
...

class ContextTypes(Enum):
"""Define different types of query contexts"""

QUERY_LIST = "QUERY_LIST" # List of queries passed in as options
TEMPLATES = "TEMPLATES" # Templates from local filepath

@property
def context_type(self) -> TAPQueryContext.ContextTypes:
"""Get the context type"""
...

def get_next_query(self) -> str:
"""Get the next query"""
...


class TAPQueryContextTemplates:
"""Context is template based here, i.e. the queries are read from local filepath as
Jinja templates.
"""

def __init__(self, taprunner: TAPQueryRunner) -> None:
self.taprunner = taprunner

@property
def context_type(self) -> TAPQueryContext.ContextTypes:
"""Get Context Type
Returns:
TAPQueryContext.ContextTypes: The context type
"""
return TAPQueryContext.ContextTypes.TEMPLATES

def get_next_query(self) -> str:
"""Get a query from the query_set randomly, using the random_engine of the TAP Runner
Render query from template, using generated parameters
Returns:
str: The next query string
"""
template_name = self.taprunner.random_engine.choice(
self.taprunner.env.list_templates(["sql"])
)
template = self.taprunner.env.get_template(template_name)
query = template.render(self.taprunner.generated_params)
return query


class TAPQueryContextQueryList:
"""Context for generating queries from a given list of query strings."""

def __init__(self, taprunner: TAPQueryRunner) -> None:
self.taprunner = taprunner

@property
def context_type(self) -> TAPQueryContext.ContextTypes:
"""Get Context Type
Returns:
TAPQueryContext.ContextTypes: The context type
"""
return TAPQueryContext.ContextTypes.QUERY_LIST

def get_next_query(self) -> str:
"""Get a query from the list randomly, using the random_engine of the TAP Runner
Returns:
str: The next query string
"""
return self.taprunner.random_engine.choice(self.taprunner.queries)


# Mapping of context types to TAPQueryContext class type
TAP_QUERY_CONTEXTS = {
TAPQueryContext.ContextTypes.QUERY_LIST: TAPQueryContextQueryList,
TAPQueryContext.ContextTypes.TEMPLATES: TAPQueryContextTemplates,
}


class TAPQueryRunner(Business):
"""Run queries against TAP.
Expand Down Expand Up @@ -53,33 +143,104 @@ def __init__(
self._client: pyvo.dal.TAPService | None = None
self._random = SystemRandom()
self._pool = ThreadPoolExecutor(max_workers=1)
self._context = self._get_context(options)
self._env = self._get_environment()
self._params = self._get_params()
self._queries = options.queries

async def startup(self) -> None:
if self._context.context_type is TAPQueryContext.ContextTypes.TEMPLATES:
templates = self._env.list_templates(["sql"])
self.logger.info("Query templates to choose from: %s", templates)
with self.timings.start("make_client"):
self._client = self._make_client(self.user.token)

def get_next_query(self) -> str:
"""Get the next query string from the context
Returns:
str: The next query string
"""
return self._context.get_next_query()

@property
def queries(self):
return self._queries

@property
def env(self):
return self._env

@property
def random_engine(self):
return self._random

@property
def params(self):
return self._params

@property
def generated_params(self):
return self._generate_parameters()

def _get_context(
self,
options,
) -> Union[TAPQueryContextQueryList, TAPQueryContextTemplates]:
"""Get the context for this TAP query runner
Parameters:
options (TAPQueryRunnerOptions): The runner options based on which to get the context
"""
if options.queries:
return TAP_QUERY_CONTEXTS[TAPQueryContext.ContextTypes.QUERY_LIST](
taprunner=self
)
return TAP_QUERY_CONTEXTS[TAPQueryContext.ContextTypes.TEMPLATES](
taprunner=self
)

def _get_environment(self) -> Union[jinja2.Environment, None]:
"""Get the jinha2 template if applicable else return None
Returns:
Union[jinja2.Environment, None]: Return the jinja2 Environment, or None
"""
if self._context.context_type is not TAPQueryContext.ContextTypes.TEMPLATES:
return None

# Load templates and parameters. The path has to be specified in two
# different ways: as a relative path for Jinja's PackageLoader, and as
# a sequence of joinpath operations for importlib.resources.
template_path = ("data", "tapqueryrunner", options.query_set)
self._env = jinja2.Environment(
template_path = ("data", "tapqueryrunner", self.options.query_set)
env = jinja2.Environment(
loader=jinja2.PackageLoader("mobu", "/".join(template_path)),
undefined=jinja2.StrictUndefined,
autoescape=jinja2.select_autoescape(disabled_extensions=["sql"]),
)
return env

def _get_params(self) -> Union[dict, None]:
"""Get the parameters as a dictionary if applicable else return None
Returns:
Union[dict, None]: Return the parameters as a dict, or None
"""
if self._context.context_type is not TAPQueryContext.ContextTypes.TEMPLATES:
return None
template_path = ("data", "tapqueryrunner", self.options.query_set)
files = importlib.resources.files("mobu")
for directory in template_path:
files = files.joinpath(directory)
with files.joinpath("params.yaml").open("r") as f:
self._params = yaml.safe_load(f)

async def startup(self) -> None:
templates = self._env.list_templates(["sql"])
self.logger.info("Query templates to choose from: %s", templates)
with self.timings.start("make_client"):
self._client = self._make_client(self.user.token)
params = yaml.safe_load(f)
return params

async def execute(self) -> None:
template_name = self._random.choice(self._env.list_templates(["sql"]))
template = self._env.get_template(template_name)
query = template.render(self._generate_parameters())

"""Get and execute the next query from the context, synchronously or asynchronously
depending on options
"""
query = self.get_next_query()
with self.timings.start("execute_query", {"query": query}) as sw:
self._running_query = query

Expand All @@ -102,6 +263,11 @@ async def execute(self) -> None:
self.logger.info(f"Query finished after {elapsed} seconds")

async def run_async_query(self, query: str) -> None:
"""Run the query asynchronously
Parameters:
query (str): The query string to execute
"""
if not self._client:
raise RuntimeError("TAPQueryRunner startup never ran")
self.logger.info("Running (async): %s", query)
Expand All @@ -114,6 +280,11 @@ async def run_async_query(self, query: str) -> None:
job.delete()

async def run_sync_query(self, query: str) -> None:
"""Run the query synchronously
Parameters:
query (str): The query string to execute
"""
if not self._client:
raise RuntimeError("TAPQueryRunner startup never ran")
self.logger.info("Running (sync): %s", query)
Expand Down Expand Up @@ -191,9 +362,7 @@ def _generate_parameters(self) -> dict[str, int | float | str]:
radius_range=radius_range,
),
"radius": min_radius + self._random.random() * radius_range,
"radius_near": (
min_radius + self._random.random() * radius_near_range
),
"radius_near": (min_radius + self._random.random() * radius_near_range),
"username": self.user.username,
"query_id": "mobu-" + shortuuid.uuid(),
}
Expand Down
21 changes: 17 additions & 4 deletions tests/business/tapqueryrunner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,7 @@ async def test_alert(
"type": "section",
"text": {
"type": "mrkdwn",
"text": (
"*Error*\n"
"```\nException: some error\n```"
),
"text": ("*Error*\n" "```\nException: some error\n```"),
"verbatim": True,
},
},
Expand Down Expand Up @@ -263,3 +260,19 @@ async def test_random_object() -> None:
assert len(random_objects) == 12
for obj in random_objects:
assert obj in objects


@pytest.mark.asyncio
async def test_query_list() -> None:
queries = [
"SELECT TOP 10 * FROM TAP_SCHEMA.tables",
"SELECT TOP 10 * FROM TAP_SCHEMA.columns",
]
user = AuthenticatedUser(username="user", scopes=["read:tap"], token="blah blah")
logger = structlog.get_logger(__file__)
options = TAPQueryRunnerOptions(queries=queries)
http_client = await http_client_dependency()
with patch.object(pyvo.dal, "TAPService"):
runner = TAPQueryRunner(options, user, http_client, logger)
generated_query = runner.get_next_query()
assert generated_query in queries

0 comments on commit 70b4760

Please sign in to comment.