Skip to content

Commit

Permalink
feat: chat history (#286)
Browse files Browse the repository at this point in the history
* feat: chat history

* streaming

* Liskov Substitution Principle

* rm unused decorator

* docstrs
  • Loading branch information
JarbasAl authored Nov 20, 2024
1 parent 3170871 commit 59b81d6
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 31 deletions.
29 changes: 22 additions & 7 deletions ovos_plugin_manager/solvers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from ovos_plugin_manager.utils import normalize_lang, \
PluginTypes, PluginConfigTypes
from ovos_plugin_manager.templates.solvers import QuestionSolver, TldrSolver, \
EntailmentSolver, MultipleChoiceSolver, EvidenceSolver
from ovos_utils.log import LOG
EntailmentSolver, MultipleChoiceSolver, EvidenceSolver, ChatMessageSolver
from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes


def find_chat_solver_plugins() -> dict:
"""
Find all installed plugins
@return: dict plugin names to entrypoints
"""
from ovos_plugin_manager.utils import find_plugins
return find_plugins(PluginTypes.CHAT_SOLVER)


def load_chat_solver_plugin(module_name: str) -> type(ChatMessageSolver):
"""
Get an uninstantiated class for the requested module_name
@param module_name: Plugin entrypoint name to load
@return: Uninstantiated class
"""
from ovos_plugin_manager.utils import load_plugin
return load_plugin(module_name, PluginTypes.CHAT_SOLVER)


def find_question_solver_plugins() -> dict:
"""
Expand Down Expand Up @@ -172,7 +188,7 @@ def get_entailment_solver_module_configs(module_name: str) -> dict:


def get_entailment_solver_lang_configs(lang: str,
include_dialects: bool = False) -> dict:
include_dialects: bool = False) -> dict:
"""
Get a dict of plugin names to list valid configurations for the requested
lang.
Expand Down Expand Up @@ -303,7 +319,7 @@ def get_reading_comprehension_solver_module_configs(module_name: str) -> dict:


def get_reading_comprehension_solver_lang_configs(lang: str,
include_dialects: bool = False) -> dict:
include_dialects: bool = False) -> dict:
"""
Get a dict of plugin names to list valid configurations for the requested
lang.
Expand All @@ -324,4 +340,3 @@ def get_reading_comprehension_solver_supported_langs() -> dict:
from ovos_plugin_manager.utils.config import get_plugin_supported_languages
return get_plugin_supported_languages(
PluginTypes.READING_COMPREHENSION_SOLVER)

69 changes: 68 additions & 1 deletion ovos_plugin_manager/templates/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Optional, List, Iterable, Tuple, Dict, Union, Any

from json_database import JsonStorageXDG
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.lang import standardize_lang_tag
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.xdg_utils import xdg_cache_home

from ovos_plugin_manager.templates.language import LanguageTranslator, LanguageDetector
Expand Down Expand Up @@ -398,6 +398,73 @@ def long_answer(self, query: str,
return steps


class ChatMessageSolver(QuestionSolver):
"""A solver that processes chat history in LLM-style format to generate contextual responses.
This class extends QuestionSolver to handle multi-turn conversations, maintaining
context across messages. It expects chat messages in a format similar to LLM APIs:
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Knock knock."},
{"role": "assistant", "content": "Who's there?"},
{"role": "user", "content": "Orange."},
]
"""

@abc.abstractmethod
def continue_chat(self, messages: List[Dict[str, str]],
lang: Optional[str],
units: Optional[str] = None) -> Optional[str]:
"""Generate a response based on the chat history.
Args:
messages (List[Dict[str, str]]): List of chat messages, each containing 'role' and 'content'.
lang (Optional[str]): The language code for the response. If None, will be auto-detected.
units (Optional[str]): Optional unit system for numerical values.
Returns:
Optional[str]: The generated response or None if no response could be generated.
"""

@auto_detect_lang(text_keys=["messages"])
@auto_translate(translate_keys=["messages"])
def get_chat_completion(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
return self.continue_chat(messages=messages, lang=lang, units=units)

def stream_chat_utterances(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Iterable[str]:
"""
Stream utterances for the given chat history as they become available.
Args:
messages: The chat messages.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.
Returns:
Iterable[str]: An iterable of utterances.
"""
ans = _call_with_sanitized_kwargs(self.get_chat_completion, messages, lang=lang, units=units)
for utt in self.sentence_split(ans):
yield utt

def get_spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
"""Override of QuestionSolver.get_spoken_answer for API compatibility.
This implementation converts the single query into a chat message format
and delegates to continue_chat. While functional, direct use of chat-specific
methods is recommended for chat-based interactions.
"""
# just for api compat since it's a subclass, shouldn't be directly used
return self.continue_chat(messages=[{"role": "user", "content": query}], lang=lang, units=units)


class CorpusSolver(QuestionSolver):
"""Retrieval based question solver"""

Expand Down
49 changes: 26 additions & 23 deletions ovos_plugin_manager/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,37 @@ class PluginTypes(str, Enum):
FACE_EMBEDDINGS = "opm.embeddings.face"
VOICE_EMBEDDINGS = "opm.embeddings.voice"
TEXT_EMBEDDINGS = "opm.embeddings.text"
GUI = "ovos.plugin.gui"
PHAL = "ovos.plugin.phal"
ADMIN = "ovos.plugin.phal.admin"
SKILL = "ovos.plugin.skill"
MIC = "ovos.plugin.microphone"
VAD = "ovos.plugin.VAD"
PHONEME = "ovos.plugin.g2p"
AUDIO2IPA = "ovos.plugin.audio2ipa"
GUI = "ovos.plugin.gui" # TODO rename "opm.gui"
PHAL = "ovos.plugin.phal" # TODO rename "opm.phal"
ADMIN = "ovos.plugin.phal.admin" # TODO rename "opm.phal.admin"
SKILL = "ovos.plugin.skill" # TODO rename "opm.skill"
MIC = "ovos.plugin.microphone" # TODO rename "opm.microphone"
VAD = "ovos.plugin.VAD" # TODO rename "opm.vad"
PHONEME = "ovos.plugin.g2p" # TODO rename "opm.g2p"
AUDIO2IPA = "ovos.plugin.audio2ipa" # TODO rename "opm.audio2ipa"
AUDIO = 'mycroft.plugin.audioservice' # DEPRECATED
STT = 'mycroft.plugin.stt'
TTS = 'mycroft.plugin.tts'
WAKEWORD = 'mycroft.plugin.wake_word'
TRANSLATE = "neon.plugin.lang.translate"
LANG_DETECT = "neon.plugin.lang.detect"
UTTERANCE_TRANSFORMER = "neon.plugin.text"
METADATA_TRANSFORMER = "neon.plugin.metadata"
AUDIO_TRANSFORMER = "neon.plugin.audio"
STT = 'mycroft.plugin.stt' # TODO rename "opm.stt"
TTS = 'mycroft.plugin.tts' # TODO rename "opm.tts"
WAKEWORD = 'mycroft.plugin.wake_word' # TODO rename "opm.wake_word"
TRANSLATE = "neon.plugin.lang.translate" # TODO rename "opm.lang.translate"
LANG_DETECT = "neon.plugin.lang.detect" # TODO rename "opm.lang.detect"
UTTERANCE_TRANSFORMER = "neon.plugin.text" # TODO rename "opm.transformer.text"
METADATA_TRANSFORMER = "neon.plugin.metadata" # TODO rename "opm.transformer.metadata"
AUDIO_TRANSFORMER = "neon.plugin.audio" # TODO rename "opm.transformer.audio"
DIALOG_TRANSFORMER = "opm.transformer.dialog"
TTS_TRANSFORMER = "opm.transformer.tts"
QUESTION_SOLVER = "neon.plugin.solver"
QUESTION_SOLVER = "neon.plugin.solver" # TODO rename "opm.solver.question"
CHAT_SOLVER = "opm.solver.chat"
TLDR_SOLVER = "opm.solver.summarization"
ENTAILMENT_SOLVER = "opm.solver.entailment"
MULTIPLE_CHOICE_SOLVER = "opm.solver.multiple_choice"
READING_COMPREHENSION_SOLVER = "opm.solver.reading_comprehension"
COREFERENCE_SOLVER = "intentbox.coreference"
KEYWORD_EXTRACTION = "intentbox.keywords"
UTTERANCE_SEGMENTATION = "intentbox.segmentation"
TOKENIZATION = "intentbox.tokenization"
POSTAG = "intentbox.postag"
STREAM_EXTRACTOR = "ovos.ocp.extractor"
COREFERENCE_SOLVER = "intentbox.coreference" # TODO rename "opm.coreference"
KEYWORD_EXTRACTION = "intentbox.keywords" # TODO rename "opm.keywords"
UTTERANCE_SEGMENTATION = "intentbox.segmentation" # TODO rename "opm.segmentation"
TOKENIZATION = "intentbox.tokenization" # TODO rename "opm.tokenization"
POSTAG = "intentbox.postag" # TODO rename "opm.postag"
STREAM_EXTRACTOR = "ovos.ocp.extractor" # TODO rename "opm.ocp.extractor"
AUDIO_PLAYER = "opm.media.audio"
VIDEO_PLAYER = "opm.media.video"
WEB_PLAYER = "opm.media.web"
Expand Down Expand Up @@ -91,6 +92,7 @@ class PluginConfigTypes(str, Enum):
DIALOG_TRANSFORMER = "opm.transformer.dialog.config"
TTS_TRANSFORMER = "opm.transformer.tts.config"
QUESTION_SOLVER = "neon.plugin.solver.config"
CHAT_SOLVER = "opm.solver.chat.config"
TLDR_SOLVER = "opm.solver.summarization.config"
ENTAILMENT_SOLVER = "opm.solver.entailment.config"
MULTIPLE_CHOICE_SOLVER = "opm.solver.multiple_choice.config"
Expand Down Expand Up @@ -173,6 +175,7 @@ def load_plugin(plug_name: str, plug_type: Optional[PluginTypes] = None):
LOG.warning(f'Could not find the plugin {plug_type}.{plug_name}')
return None


@deprecated("normalize_lang has been deprecated! update to 'from ovos_utils.lang import standardize_lang_tag'", "1.0.0")
def normalize_lang(lang):
from ovos_utils.lang import standardize_lang_tag
Expand Down

0 comments on commit 59b81d6

Please sign in to comment.