Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/units_kwarg_solvers #247

Merged
merged 1 commit into from
Aug 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 148 additions & 59 deletions ovos_plugin_manager/templates/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import abc
import inspect
from functools import wraps, lru_cache
from typing import Optional, List, Iterable, Tuple, Dict, Union
from typing import Optional, List, Iterable, Tuple, Dict, Union, Any

from json_database import JsonStorageXDG
from ovos_utils import flatten_list
Expand Down Expand Up @@ -282,12 +282,14 @@ def __init__(self, config: Optional[Dict] = None,
"""
Initialize the QuestionSolver.

:param config: Optional configuration dictionary.
:param translator: Optional language translator.
:param detector: Optional language detector.
:param priority: Priority of the solver.
:param enable_tx: Flag to enable translation.
:param enable_cache: Flag to enable caching.
Args:
config (Optional[Dict]): Optional configuration dictionary.
translator (Optional[LanguageTranslator]): Optional language translator.
detector (Optional[LanguageDetector]): Optional language detector.
priority (int): Priority of the solver.
enable_tx (bool): Flag to enable translation.
enable_cache (bool): Flag to enable caching.
internal_lang (Optional[str]): Internal language code. Defaults to None.
"""
super().__init__(config, translator, detector, priority,
enable_tx, enable_cache, internal_lang,
Expand All @@ -307,69 +309,101 @@ def __init__(self, config: Optional[Dict] = None,

# plugin methods to override
@abc.abstractmethod
def get_spoken_answer(self, query: str, lang: Optional[str] = None) -> str:
def get_spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> str:
"""
Obtain the spoken answer for a given query.

:param query: The query text.
:param lang: Optional language code.
:return: The spoken answer as a text response.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
str: The spoken answer as a text response.
"""
raise NotImplementedError

@_deprecate_context2lang()
def stream_utterances(self, query: str, lang: Optional[str] = None) -> Iterable[str]:
def stream_utterances(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Iterable[str]:
"""
Stream utterances for the given query as they become available.

:param query: The query text.
:param lang: Optional language code.
:return: An iterable of utterances.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
Iterable[str]: An iterable of utterances.
"""
ans = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang)
ans = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang, units=units)
for utt in self.sentence_split(ans):
yield utt

@_deprecate_context2lang()
def get_data(self, query: str, lang: Optional[str] = None) -> Optional[dict]:
def get_data(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[Dict]:
"""
Retrieve data for the given query.

:param query: The query text.
:param lang: Optional language code.
:return: A dictionary containing the answer.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
Optional[Dict]: A dictionary containing the answer.
"""
return {"answer": _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang)}
return {"answer": _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang, units=units)}

@_deprecate_context2lang()
def get_image(self, query: str, lang: Optional[str] = None) -> Optional[str]:
def get_image(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
"""
Get the path or URL to an image associated with the query.

:param query: The query text
:param lang: Optional language code.
:return: The path or URL to a single image.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
Optional[str]: The path or URL to a single image.
"""
return None

@_deprecate_context2lang()
def get_expanded_answer(self, query: str, lang: Optional[str] = None) -> List[dict]:
def get_expanded_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> List[dict]:
"""
Get an expanded list of steps to elaborate on the answer.

:param query: The query text
:param lang: Optional language code.
:return: A list of dictionaries with each step containing a title, summary, and optional image.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
List[Dict]: A list of dictionaries with each step containing a title, summary, and optional image.
"""
return [{"title": query,
"summary": _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang),
"img": _call_with_sanitized_kwargs(self.get_image, query, lang=lang)}]
"summary": _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang, units=units),
"img": _call_with_sanitized_kwargs(self.get_image, query, lang=lang, units=units)}]

# user facing methods
@_deprecate_context2lang()
@auto_detect_lang(text_keys=["query"])
@auto_translate(translate_keys=["query"])
def search(self, query: str, lang: Optional[str] = None) -> dict:
def search(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> dict:
"""
Perform a search with automatic translation and caching.

Expand All @@ -378,17 +412,21 @@ def search(self, query: str, lang: Optional[str] = None) -> dict:
If translations happens, the returned value of this method will also
be automatically translated back

:param query: The query text.
:param lang: Optional language code.
:return: The data dictionary retrieved from the cache or computed anew.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
Dict: The data dictionary retrieved from the cache or computed anew.
"""
# read from cache
if self.enable_cache and query in self.cache:
data = self.cache[query]
else:
# search data
try:
data = _call_with_sanitized_kwargs(self.get_data, query, lang=lang)
data = _call_with_sanitized_kwargs(self.get_data, query, lang=lang, units=units)
except:
return {}

Expand All @@ -401,7 +439,9 @@ def search(self, query: str, lang: Optional[str] = None) -> dict:
@_deprecate_context2lang()
@auto_detect_lang(text_keys=["query"])
@auto_translate(translate_keys=["query"])
def visual_answer(self, query: str, lang: Optional[str] = None) -> str:
def visual_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> str:
"""
Retrieve the image associated with the query with automatic translation and caching.

Expand All @@ -410,16 +450,22 @@ def visual_answer(self, query: str, lang: Optional[str] = None) -> str:
If translations happens, the returned value of this method will also
be automatically translated back

:param query: The query text.
:param lang: Optional language code.
:return: The path or URL to the image.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
str: The path or URL to the image.
"""
return _call_with_sanitized_kwargs(self.get_image, query, lang=lang)
return _call_with_sanitized_kwargs(self.get_image, query, lang=lang, units=units)

@_deprecate_context2lang()
@auto_detect_lang(text_keys=["query"])
@auto_translate(translate_keys=["query"])
def spoken_answer(self, query: str, lang: Optional[str] = None) -> str:
def spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> str:
"""
Retrieve the spoken answer for the query with automatic translation and caching.

Expand All @@ -428,17 +474,21 @@ def spoken_answer(self, query: str, lang: Optional[str] = None) -> str:
If translations happens, the returned value of this method will also
be automatically translated back

:param query: The query text.
:param lang: Optional language code.
:return: The spoken answer as a text response.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
str: The spoken answer as a text response.
"""
# get answer
if self.enable_cache and query in self.spoken_cache:
# read from cache
summary = self.spoken_cache[query]
else:

summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang)
summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang, units=units)
# save to cache
if self.enable_cache:
self.spoken_cache[query] = summary
Expand All @@ -448,7 +498,9 @@ def spoken_answer(self, query: str, lang: Optional[str] = None) -> str:
@_deprecate_context2lang()
@auto_detect_lang(text_keys=["query"])
@auto_translate(translate_keys=["query"])
def long_answer(self, query: str, lang: Optional[str] = None) -> List[dict]:
def long_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> List[Dict]:
"""
Retrieve a detailed list of steps to expand the answer.

Expand All @@ -457,18 +509,21 @@ def long_answer(self, query: str, lang: Optional[str] = None) -> List[dict]:
If translations happens, the returned value of this method will also
be automatically translated back

:param query: The query text.
:param lang: Optional language code.
:return: A list of steps to elaborate on the answer, with each step containing a title, summary, and optional image.
Args:
query (str): The query text.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
List[Dict]: A list of steps to elaborate on the answer, with each step containing a title, summary, and optional image.
"""
steps = _call_with_sanitized_kwargs(self.get_expanded_answer, query, lang=lang)
steps = _call_with_sanitized_kwargs(self.get_expanded_answer, query, lang=lang, units=units)
# use spoken_answer as last resort
if not steps:
summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang)
summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, lang=lang, units=units)
if summary:
img = _call_with_sanitized_kwargs(self.get_image, query, lang=lang)
steps = [{"title": query, "summary": step0, "img": img}
for step0 in self.sentence_split(summary, -1)]
img = _call_with_sanitized_kwargs(self.get_image, query, lang=lang, units=units)
steps = [{"title": query, "summary": step, "img": img} for step in self.sentence_split(summary, -1)]
return steps


Expand Down Expand Up @@ -711,7 +766,19 @@ def entails(self, premise: str, hypothesis: str, lang: Optional[str] = None) ->
return self.check_entailment(premise, hypothesis, lang=lang)


def _do_tx(solver, data, source_lang, target_lang):
def _do_tx(solver, data: Any, source_lang: str, target_lang: str) -> Any:
"""
Translate the given data from source language to target language using the provided solver.

Args:
solver: The translation solver.
data (Any): The data to translate. Can be a string, list, dictionary, or tuple.
source_lang (str): The source language code.
target_lang (str): The target language code.

Returns:
Any: The translated data in the same structure as the input data.
"""
if isinstance(data, str):
return solver.translate(data,
source_lang=source_lang, target_lang=target_lang)
Expand All @@ -734,14 +801,36 @@ def _do_tx(solver, data, source_lang, target_lang):
return data


def _call_with_sanitized_kwargs(func, *args, lang: Optional[str] = None):
# Inspect the function signature to ensure it has both 'lang' and 'context' parameters
def _call_with_sanitized_kwargs(func, *args: Any,
lang: Optional[str] = None,
units: Optional[str] = None) -> Any:
"""
Call a function with sanitized keyword arguments for language and units.

Args:
func: The function to call.
args (Any): Positional arguments to pass to the function.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
Any: The result of the function call.
"""
params = inspect.signature(func).parameters
kwargs = {}

if "lang" in params:
# new style - only lang is passed
# new style - only lang/units is passed
kwargs["lang"] = lang
elif "context" in kwargs:
# old style - when plugins received context only
kwargs["context"]["lang"] = lang

if "units" in params:
# new style - only lang/units is passed
kwargs["units"] = units
elif "context" in kwargs:
# old style - when plugins received context only
kwargs["context"]["units"] = units

return func(*args, **kwargs)
Loading