Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: chat history #286

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions ovos_plugin_manager/solvers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from ovos_plugin_manager.utils import normalize_lang, \
PluginTypes, PluginConfigTypes
from ovos_plugin_manager.templates.solvers import QuestionSolver, TldrSolver, \
EntailmentSolver, MultipleChoiceSolver, EvidenceSolver
from ovos_utils.log import LOG
EntailmentSolver, MultipleChoiceSolver, EvidenceSolver, ChatMessageSolver
from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes


def find_chat_solver_plugins() -> dict:
"""
Find all installed plugins
@return: dict plugin names to entrypoints
"""
from ovos_plugin_manager.utils import find_plugins
return find_plugins(PluginTypes.CHAT_SOLVER)


def load_chat_solver_plugin(module_name: str) -> type(ChatMessageSolver):
"""
Get an uninstantiated class for the requested module_name
@param module_name: Plugin entrypoint name to load
@return: Uninstantiated class
"""
from ovos_plugin_manager.utils import load_plugin
return load_plugin(module_name, PluginTypes.CHAT_SOLVER)


JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
def find_question_solver_plugins() -> dict:
"""
Expand Down Expand Up @@ -172,7 +188,7 @@ def get_entailment_solver_module_configs(module_name: str) -> dict:


def get_entailment_solver_lang_configs(lang: str,
include_dialects: bool = False) -> dict:
include_dialects: bool = False) -> dict:
"""
Get a dict of plugin names to list valid configurations for the requested
lang.
Expand Down Expand Up @@ -303,7 +319,7 @@ def get_reading_comprehension_solver_module_configs(module_name: str) -> dict:


def get_reading_comprehension_solver_lang_configs(lang: str,
include_dialects: bool = False) -> dict:
include_dialects: bool = False) -> dict:
"""
Get a dict of plugin names to list valid configurations for the requested
lang.
Expand All @@ -324,4 +340,3 @@ def get_reading_comprehension_solver_supported_langs() -> dict:
from ovos_plugin_manager.utils.config import get_plugin_supported_languages
return get_plugin_supported_languages(
PluginTypes.READING_COMPREHENSION_SOLVER)

50 changes: 49 additions & 1 deletion ovos_plugin_manager/templates/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Optional, List, Iterable, Tuple, Dict, Union, Any

from json_database import JsonStorageXDG
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.lang import standardize_lang_tag
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.xdg_utils import xdg_cache_home

from ovos_plugin_manager.templates.language import LanguageTranslator, LanguageDetector
Expand Down Expand Up @@ -398,6 +398,54 @@ def long_answer(self, query: str,
return steps


class ChatMessageSolver(QuestionSolver):
"""take chat history as input LLM style
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Knock knock."},
{"role": "assistant", "content": "Who's there?"},
{"role": "user", "content": "Orange."},
]
"""

@abc.abstractmethod
def continue_chat(self, messages: List[Dict[str, str]],
lang: Optional[str],
units: Optional[str] = None) -> Optional[str]:
pass

@auto_detect_lang(text_keys=["messages"])
@auto_translate(translate_keys=["messages"])
def get_chat_completion(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
return self.continue_chat(messages=messages, lang=lang, units=units)

def stream_chat_utterances(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Iterable[str]:
"""
Stream utterances for the given chat history as they become available.

Args:
messages: The chat messages.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.

Returns:
Iterable[str]: An iterable of utterances.
"""
ans = _call_with_sanitized_kwargs(self.get_chat_completion, messages, lang=lang, units=units)
for utt in self.sentence_split(ans):
yield utt

def get_spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
# just for api compat since it's a subclass, shouldn't be directly used
return self.continue_chat(messages=[{"role": "user", "content": query}], lang=lang, units=units)


class CorpusSolver(QuestionSolver):
"""Retrieval based question solver"""

Expand Down
49 changes: 26 additions & 23 deletions ovos_plugin_manager/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,37 @@ class PluginTypes(str, Enum):
FACE_EMBEDDINGS = "opm.embeddings.face"
VOICE_EMBEDDINGS = "opm.embeddings.voice"
TEXT_EMBEDDINGS = "opm.embeddings.text"
GUI = "ovos.plugin.gui"
PHAL = "ovos.plugin.phal"
ADMIN = "ovos.plugin.phal.admin"
SKILL = "ovos.plugin.skill"
MIC = "ovos.plugin.microphone"
VAD = "ovos.plugin.VAD"
PHONEME = "ovos.plugin.g2p"
AUDIO2IPA = "ovos.plugin.audio2ipa"
GUI = "ovos.plugin.gui" # TODO rename "opm.gui"
PHAL = "ovos.plugin.phal" # TODO rename "opm.phal"
ADMIN = "ovos.plugin.phal.admin" # TODO rename "opm.phal.admin"
SKILL = "ovos.plugin.skill" # TODO rename "opm.skill"
MIC = "ovos.plugin.microphone" # TODO rename "opm.microphone"
VAD = "ovos.plugin.VAD" # TODO rename "opm.vad"
PHONEME = "ovos.plugin.g2p" # TODO rename "opm.g2p"
AUDIO2IPA = "ovos.plugin.audio2ipa" # TODO rename "opm.audio2ipa"
AUDIO = 'mycroft.plugin.audioservice' # DEPRECATED
STT = 'mycroft.plugin.stt'
TTS = 'mycroft.plugin.tts'
WAKEWORD = 'mycroft.plugin.wake_word'
TRANSLATE = "neon.plugin.lang.translate"
LANG_DETECT = "neon.plugin.lang.detect"
UTTERANCE_TRANSFORMER = "neon.plugin.text"
METADATA_TRANSFORMER = "neon.plugin.metadata"
AUDIO_TRANSFORMER = "neon.plugin.audio"
STT = 'mycroft.plugin.stt' # TODO rename "opm.stt"
TTS = 'mycroft.plugin.tts' # TODO rename "opm.tts"
WAKEWORD = 'mycroft.plugin.wake_word' # TODO rename "opm.wake_word"
TRANSLATE = "neon.plugin.lang.translate" # TODO rename "opm.lang.translate"
LANG_DETECT = "neon.plugin.lang.detect" # TODO rename "opm.lang.detect"
UTTERANCE_TRANSFORMER = "neon.plugin.text" # TODO rename "opm.transformer.text"
METADATA_TRANSFORMER = "neon.plugin.metadata" # TODO rename "opm.transformer.metadata"
AUDIO_TRANSFORMER = "neon.plugin.audio" # TODO rename "opm.transformer.audio"
DIALOG_TRANSFORMER = "opm.transformer.dialog"
TTS_TRANSFORMER = "opm.transformer.tts"
QUESTION_SOLVER = "neon.plugin.solver"
QUESTION_SOLVER = "neon.plugin.solver" # TODO rename "opm.solver.question"
CHAT_SOLVER = "opm.solver.chat"
TLDR_SOLVER = "opm.solver.summarization"
ENTAILMENT_SOLVER = "opm.solver.entailment"
MULTIPLE_CHOICE_SOLVER = "opm.solver.multiple_choice"
READING_COMPREHENSION_SOLVER = "opm.solver.reading_comprehension"
COREFERENCE_SOLVER = "intentbox.coreference"
KEYWORD_EXTRACTION = "intentbox.keywords"
UTTERANCE_SEGMENTATION = "intentbox.segmentation"
TOKENIZATION = "intentbox.tokenization"
POSTAG = "intentbox.postag"
STREAM_EXTRACTOR = "ovos.ocp.extractor"
COREFERENCE_SOLVER = "intentbox.coreference" # TODO rename "opm.coreference"
KEYWORD_EXTRACTION = "intentbox.keywords" # TODO rename "opm.keywords"
UTTERANCE_SEGMENTATION = "intentbox.segmentation" # TODO rename "opm.segmentation"
TOKENIZATION = "intentbox.tokenization" # TODO rename "opm.tokenization"
POSTAG = "intentbox.postag" # TODO rename "opm.postag"
STREAM_EXTRACTOR = "ovos.ocp.extractor" # TODO rename "opm.ocp.extractor"
AUDIO_PLAYER = "opm.media.audio"
VIDEO_PLAYER = "opm.media.video"
WEB_PLAYER = "opm.media.web"
Expand Down Expand Up @@ -91,6 +92,7 @@ class PluginConfigTypes(str, Enum):
DIALOG_TRANSFORMER = "opm.transformer.dialog.config"
TTS_TRANSFORMER = "opm.transformer.tts.config"
QUESTION_SOLVER = "neon.plugin.solver.config"
CHAT_SOLVER = "opm.solver.chat.config"
TLDR_SOLVER = "opm.solver.summarization.config"
ENTAILMENT_SOLVER = "opm.solver.entailment.config"
MULTIPLE_CHOICE_SOLVER = "opm.solver.multiple_choice.config"
Expand Down Expand Up @@ -173,6 +175,7 @@ def load_plugin(plug_name: str, plug_type: Optional[PluginTypes] = None):
LOG.warning(f'Could not find the plugin {plug_type}.{plug_name}')
return None


@deprecated("normalize_lang has been deprecated! update to 'from ovos_utils.lang import standardize_lang_tag'", "1.0.0")
def normalize_lang(lang):
from ovos_utils.lang import standardize_lang_tag
Expand Down
Loading