Skip to content

Commit

Permalink
Ec/support functions (#52)
Browse files Browse the repository at this point in the history
* Support functions

* Temporarily remove RoleTag validation

* Correct role assignment

* Remove duplicate environment validator

* Lint

* More flexible dependencies

* Reuse enabled
  • Loading branch information
Enias Cailliau authored Jun 16, 2023
1 parent 956de57 commit 2550d9e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 22 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
steamship~=2.17.0
langchain==0.0.168
steamship~=2.17.4
langchain==0.0.200
3 changes: 3 additions & 0 deletions src/steamship_langchain/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from steamship_langchain.chat_models.openai import ChatOpenAI

__all__ = ["ChatOpenAI"]
83 changes: 63 additions & 20 deletions src/steamship_langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,44 @@
"""OpenAI chat wrapper."""
from __future__ import annotations

import json
import logging
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple

import tiktoken
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatMessage,
ChatResult,
FunctionMessage,
HumanMessage,
LLMResult,
SystemMessage,
)
from pydantic import Extra, Field, root_validator
from pydantic import Extra, Field, ValidationError
from steamship import Block, File, MimeTypes, PluginInstance, Steamship, Tag
from steamship.data.tags.tag_constants import RoleTag, TagKind
from steamship.data.tags.tag_constants import TagKind

logger = logging.getLogger(__file__)


def _convert_dict_to_message(_dict: dict) -> BaseMessage:
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
content = _dict["content"]
if "function_call" in content:
try:
return AIMessage(content="", additional_kwargs=json.loads(content))
except Exception:
pass
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
Expand All @@ -42,16 +52,24 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict


class ChatOpenAI(BaseChatModel):
class ChatOpenAI(ChatOpenAI, BaseChatModel):
"""Wrapper around OpenAI Chat large language models.
To use, you should have the ``openai`` python package installed, and the
Expand All @@ -68,7 +86,7 @@ class ChatOpenAI(BaseChatModel):
"""

client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo"
model_name: str = "gpt-3.5-turbo-0613"
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
Expand All @@ -95,11 +113,30 @@ class Config:
def __init__(
self,
client: Steamship,
model_name: str = "gpt-3.5-turbo",
model_name: str = "gpt-3.5-turbo-0613",
moderate_output: bool = True,
**kwargs,
):
super().__init__(client=client, model_name=model_name, **kwargs)
try:

class OpenAI(object):
class ChatCompletion:
pass

import sys

sys.modules["openai"] = OpenAI

dummy_api_key = False
if "openai_api_key" not in kwargs:
kwargs["openai_api_key"] = "DUMMY"
dummy_api_key = True
super().__init__(client=client, model_name=model_name, **kwargs)
if dummy_api_key:
self.openai_api_key = None
except ValidationError as e:
print(e)
self.client = client
plugin_config = {"model": self.model_name, "moderate_output": moderate_output}
if self.openai_api_key:
plugin_config["openai_api_key"] = self.openai_api_key
Expand All @@ -122,14 +159,6 @@ def __init__(
fetch_if_exists=True,
)

@classmethod
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
Expand All @@ -154,12 +183,16 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
name = msg.get("name", "")
if len(content) > 0:
role_tag = RoleTag(role)
tags = [Tag(kind=TagKind.ROLE, name=role)]
if name:
tags.append(Tag(kind="name", name=name))

blocks.append(
Block(
text=content,
tags=[Tag(kind=TagKind.ROLE, name=role_tag)],
tags=tags,
mime_type=MimeTypes.TXT,
)
)
Expand All @@ -169,14 +202,24 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
generate_task.wait()

return [
_convert_dict_to_message({"content": block.text, "role": RoleTag.USER.value})
_convert_dict_to_message(
{
"content": block.text,
"role": [tag for tag in block.tags if tag.kind == TagKind.ROLE.value][0].name,
}
)
for block in generate_task.output.blocks
]

def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
messages = self._complete(messages=message_dicts, **params)
return ChatResult(
generations=[ChatGeneration(message=message) for message in messages],
Expand Down

0 comments on commit 2550d9e

Please sign in to comment.