add streaming to cohere and ollama
WillBeebe committed Jul 16, 2024
1 parent 9670aff commit cc2cc69
Showing 5 changed files with 139 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ada-python"
version = "0.4.0"
version = "0.5.0"
description = "Ada, making LLMs easier to work with."
authors = ["Will Beebe"]
packages = [
2 changes: 0 additions & 2 deletions src/abcs/anthropic.py
@@ -198,6 +198,4 @@ async def content_generator():
raise e

async def handle_tool_call(self, tool_calls, combined_history, tools):
-# This is a placeholder for handling tool calls in streaming context
-# You'll need to implement the logic to execute the tool call and generate a response
pass
76 changes: 75 additions & 1 deletion src/abcs/cohere.py
@@ -1,8 +1,10 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional

import cohere
from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
from tools.tool_manager import ToolManager

logging.basicConfig(level=logging.INFO)
@@ -121,3 +123,75 @@ def _translate_response(self, response) -> PromptResponse:
)
raise e

    # https://github.com/cohere-ai/cohere-python/blob/main/src/cohere/types/streamed_chat_response.py
    # https://docs.cohere.com/docs/streaming#stream-events
    # https://docs.cohere.com/docs/streaming#example-responses
    async def generate_text_stream(
        self,
        prompt: str,
        past_messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> StreamingPromptResponse:
        try:
            # Cohere uses its own role names, so translate the past messages.
            combined_history = []
            for msg in past_messages:
                combined_history.append({
                    "role": 'CHATBOT' if msg['role'] == 'assistant' else 'USER',
                    "message": msg['content'],
                })
            stream = self.client.chat_stream(
                chat_history=combined_history,
                message=prompt,
                tools=tools,
                model=self.model,
                # perform web search before answering the question. You can also use your own custom connector.
                # connectors=[{"id": "web-search"}],
            )

            async def content_generator():
                for event in stream:
                    if isinstance(event, cohere.types.StreamedChatResponse_StreamStart):
                        # Message start event, we can ignore this
                        pass
                    elif isinstance(event, cohere.types.StreamedChatResponse_TextGeneration):
                        # This is the event that contains the actual text
                        if event.text:
                            yield event.text
                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsGeneration):
                        # todo: call tool
                        pass
                    elif isinstance(event, cohere.types.StreamedChatResponse_CitationGeneration):
                        # todo: not sure, but seems useful
                        pass
                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsChunk):
                        # todo: tool response
                        pass
                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchQueriesGeneration):
                        pass
                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchResults):
                        pass
                    elif isinstance(event, cohere.types.StreamedChatResponse_StreamEnd):
                        # Message stop event, we can ignore this
                        pass
                    # Small delay to allow for cooperative multitasking
                    await asyncio.sleep(0)

            return StreamingPromptResponse(
                content=content_generator(),
                raw_response=stream,
                error={},
                usage=UsageStats(
                    input_tokens=0,  # These will need to be updated after streaming
                    output_tokens=0,
                    extra={},
                ),
            )
        except Exception as e:
            logger.exception(f"An error occurred while streaming from Cohere: {e}")
            raise e

    async def handle_tool_call(self, tool_calls, combined_history, tools):
        # Placeholder: tool calls surfaced during streaming are not handled yet.
        pass
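
As a quick orientation for reviewers, here is a minimal sketch of how the new streaming method could be consumed. The wrapper class name, module path, and constructor arguments are assumptions (the class definition sits outside this diff); only generate_text_stream and StreamingPromptResponse.content come from the change itself.

import asyncio

from abcs.cohere import CohereLLM  # class name assumed; not shown in this diff


async def main():
    # Constructor arguments are assumed for illustration.
    llm = CohereLLM(model="command-r-plus")
    response = await llm.generate_text_stream(
        prompt="Explain streaming responses in one paragraph.",
        past_messages=[],
    )
    # StreamingPromptResponse.content is the async generator built above;
    # iterate it to print text chunks as Cohere emits them.
    async for chunk in response.content:
        print(chunk, end="", flush=True)


asyncio.run(main())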
64 changes: 63 additions & 1 deletion src/abcs/ollama.py
@@ -1,9 +1,15 @@
import asyncio
import logging
import os
from typing import Any, Dict, List, Optional

from abcs.llm import LLM
-from abcs.models import OllamaResponse, PromptResponse, UsageStats
+from abcs.models import (
+    OllamaResponse,
+    PromptResponse,
+    StreamingPromptResponse,
+    UsageStats,
+)
from ollama import Client
from tools.tool_manager import ToolManager

@@ -110,3 +116,59 @@ def _translate_response(self, response) -> PromptResponse:
except Exception as e:
logger.exception(f"An error occurred while translating Ollama response: {e}")
raise e

    async def generate_text_stream(
        self,
        prompt: str,
        past_messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> StreamingPromptResponse:
        # Build a new list rather than mutating the caller's history.
        combined_history = past_messages + [{"role": "user", "content": prompt}]

        try:
            # https://github.com/ollama/ollama-python
            # client = Client(host="https://120d-2606-40-15c-13ba-00-460-7bae.ngrok-free.app",)

            # todo: generate vs chat
            # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
            stream = self.client.chat(
                model=self.model,
                messages=combined_history,
                stream=True,
                # num_predict=4000
                # todo
                # system=self.system_prompt
            )

            async def content_generator():
                for chunk in stream:
                    if chunk['message']['content']:
                        yield chunk['message']['content']
                    # Small delay to allow for cooperative multitasking
                    await asyncio.sleep(0)

            return StreamingPromptResponse(
                content=content_generator(),
                raw_response=stream,
                error={},
                usage=UsageStats(
                    input_tokens=0,  # These will need to be updated after streaming
                    output_tokens=0,
                    extra={},
                ),
            )
        except Exception as e:
            logger.exception(f"An error occurred while streaming from Ollama: {e}")
            raise e

    async def handle_tool_call(self, tool_calls, combined_history, tools):
        # Placeholder: tool calls surfaced during streaming are not handled yet.
        pass
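
The Ollama path follows the same shape. A hedged usage sketch, again assuming a wrapper class name and constructor not shown in this diff, that collects the streamed chunks into a full reply:

import asyncio

from abcs.ollama import OllamaLLM  # class name assumed; not shown in this diff


async def main():
    # Constructor arguments are assumed for illustration.
    llm = OllamaLLM(model="llama3")
    response = await llm.generate_text_stream(
        prompt="Give me three facts about sea otters.",
        past_messages=[
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ],
    )
    # Accumulate the streamed chunks and print the assembled reply.
    parts = []
    async for chunk in response.content:
        parts.append(chunk)
    print("".join(parts))


asyncio.run(main())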
4 changes: 0 additions & 4 deletions src/abcs/openai.py
@@ -242,8 +242,4 @@ async def content_generator():
raise e

async def handle_tool_call(self, collected_content, combined_history, tools):
-# This is a placeholder for handling tool calls in streaming context
-# You'll need to implement the logic to parse the tool call, execute it,
-# and generate a response based on the tool's output
-# This might involve breaking the streaming and making a new API call
pass
