From cc2cc691b19fb356fcbfe7f0202fa04717ee000b Mon Sep 17 00:00:00 2001
From: Will <651833+WillBeebe@users.noreply.github.com>
Date: Tue, 16 Jul 2024 00:17:59 -0700
Subject: [PATCH] add streaming to cohere and ollama

---
 pyproject.toml        |  2 +-
 src/abcs/anthropic.py |  2 --
 src/abcs/cohere.py    | 76 ++++++++++++++++++++++++++++++++++++++++++-
 src/abcs/ollama.py    | 64 +++++++++++++++++++++++++++++++++++-
 src/abcs/openai.py    |  4 ---
 5 files changed, 139 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5d5f3a0..e5b66fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "ada-python"
-version = "0.4.0"
+version = "0.5.0"
 description = "Ada, making LLMs easier to work with."
 authors = ["Will Beebe"]
 packages = [
diff --git a/src/abcs/anthropic.py b/src/abcs/anthropic.py
index 32536d7..27bb2f4 100644
--- a/src/abcs/anthropic.py
+++ b/src/abcs/anthropic.py
@@ -198,6 +198,4 @@ async def content_generator():
             raise e
 
     async def handle_tool_call(self, tool_calls, combined_history, tools):
-        # This is a placeholder for handling tool calls in streaming context
-        # You'll need to implement the logic to execute the tool call and generate a response
         pass
diff --git a/src/abcs/cohere.py b/src/abcs/cohere.py
index aa2b6b1..2eab6e0 100644
--- a/src/abcs/cohere.py
+++ b/src/abcs/cohere.py
@@ -1,8 +1,10 @@
+import asyncio
 import logging
+from typing import Any, Dict, List, Optional
 
 import cohere
 from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
 from tools.tool_manager import ToolManager
 
 logging.basicConfig(level=logging.INFO)
@@ -121,3 +123,75 @@ def _translate_response(self, response) -> PromptResponse:
             )
             raise e
 
+    # https://github.com/cohere-ai/cohere-python/blob/main/src/cohere/types/streamed_chat_response.py
+    # https://docs.cohere.com/docs/streaming#stream-events
+    # https://docs.cohere.com/docs/streaming#example-responses
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        try:
+            # Cohere expects CHATBOT/USER roles, so translate the past messages.
+            combined_history = []
+            for msg in past_messages:
+                combined_history.append({
+                    "role": 'CHATBOT' if msg['role'] == 'assistant' else 'USER',
+                    "message": msg['content'],
+                })
+            stream = self.client.chat_stream(
+                chat_history=combined_history,
+                message=prompt,
+                tools=tools,
+                model=self.model,
+                # perform web search before answering the question. You can also use your own custom connector.
+                # connectors=[{"id": "web-search"}],
+            )
+
+            async def content_generator():
+                for event in stream:
+                    if isinstance(event, cohere.types.StreamedChatResponse_StreamStart):
+                        # Message start event, we can ignore this
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_TextGeneration):
+                        # This is the event that contains the actual text
+                        if event.text:
+                            yield event.text
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsGeneration):
+                        # todo: call tool
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_CitationGeneration):
+                        # todo: not sure, but seems useful
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsChunk):
+                        # todo: tool response
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchQueriesGeneration):
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchResults):
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_StreamEnd):
+                        # Message stop event, we can ignore this
+                        pass
+                    # Small delay to allow for cooperative multitasking
+                    await asyncio.sleep(0)
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # These will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.exception(f"An error occurred while streaming from Cohere: {e}")
+            raise e
+
+    async def handle_tool_call(self, tool_calls, combined_history, tools):
+        pass
diff --git a/src/abcs/ollama.py b/src/abcs/ollama.py
index 39e65b2..1cee6fe 100644
--- a/src/abcs/ollama.py
+++ b/src/abcs/ollama.py
@@ -1,9 +1,15 @@
+import asyncio
 import logging
 import os
 from typing import Any, Dict, List, Optional
 
 from abcs.llm import LLM
-from abcs.models import OllamaResponse, PromptResponse, UsageStats
+from abcs.models import (
+    OllamaResponse,
+    PromptResponse,
+    StreamingPromptResponse,
+    UsageStats,
+)
 from ollama import Client
 from tools.tool_manager import ToolManager
 
@@ -110,3 +116,59 @@ def _translate_response(self, response) -> PromptResponse:
         except Exception as e:
             logger.exception(f"An error occurred while translating Ollama response: {e}")
             raise e
+
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        # Build the chat history without mutating the caller's list.
+        combined_history = past_messages + [{"role": "user", "content": prompt}]
+
+        try:
+            # https://github.com/ollama/ollama-python
+            # client = Client(host="https://120d-2606-40-15c-13ba-00-460-7bae.ngrok-free.app",)
+
+            # todo: generate vs chat
+            # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
+            stream = self.client.chat(
+                model=self.model,
+                messages=combined_history,
+                stream=True,
+                # num_predict=4000
+                # todo
+                # system=self.system_prompt
+            )
+
+            async def content_generator():
+                for chunk in stream:
+                    if chunk['message']['content']:
+                        yield chunk['message']['content']
+                    # Small delay to allow for cooperative multitasking
+                    await asyncio.sleep(0)
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # These will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.exception(f"An error occurred while streaming from Ollama: {e}")
+            raise e
+
+    async def handle_tool_call(self, tool_calls, combined_history, tools):
+        pass
diff --git a/src/abcs/openai.py b/src/abcs/openai.py
index 14a0b9e..271db8e 100644
--- a/src/abcs/openai.py
+++ b/src/abcs/openai.py
@@ -242,8 +242,4 @@ async def content_generator():
             raise e
 
     async def handle_tool_call(self, collected_content, combined_history, tools):
-        # This is a placeholder for handling tool calls in streaming context
-        # You'll need to implement the logic to parse the tool call, execute it,
-        # and generate a response based on the tool's output
-        # This might involve breaking the streaming and making a new API call
         pass
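
A minimal consumption sketch for the new streaming interface. The class name CohereLLM and its constructor arguments are assumptions for illustration (the patch never shows the class definition); the generate_text_stream signature and the StreamingPromptResponse.content async generator follow the code added above, and the same pattern applies to the Ollama client.

import asyncio

from abcs.cohere import CohereLLM  # hypothetical class name; use the class actually defined in src/abcs/cohere.py


async def main():
    # Constructor arguments and the prompt text are illustrative only.
    llm = CohereLLM(model="command-r")
    response = await llm.generate_text_stream(
        prompt="Explain the benefits of streaming responses in one paragraph.",
        past_messages=[],
        tools=None,
    )
    # StreamingPromptResponse.content is the async generator built by the patch;
    # iterate it to emit text chunks as they arrive.
    async for chunk in response.content:
        print(chunk, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())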