Skip to content

Commit

Permalink
Fix model name passing in OpenAIChat (#60)
Browse files Browse the repository at this point in the history
Co-authored-by: Douglas Reid <[email protected]>
  • Loading branch information
douglas-reid and Douglas Reid authored Sep 12, 2023
1 parent 2550d9e commit 9648440
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/steamship_langchain/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class Config:
def __init__(
self, client: Steamship, model_name: str = "gpt-4", moderate_output: bool = True, **kwargs
):
super().__init__(client=client, **kwargs)
super().__init__(client=client, model_name=model_name, **kwargs)
plugin_config = {"model": model_name, "moderate_output": moderate_output}
if self.openai_api_key:
plugin_config["openai_api_key"] = self.openai_api_key
Expand Down
18 changes: 18 additions & 0 deletions tests/llms/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,21 @@ def test_openai_llm_with_chat_model_init(client: Steamship) -> None:
assert len(generation) == 1
text_response = generation[0].text
assert text_response.strip() == "What is the meaning of life?"


@pytest.mark.usefixtures("client")
def test_openai_large_context(client: Steamship):
"""Basic tests of the OpenAIChat plugin wrapper for large context models."""

llm_under_test = OpenAIChat(client=client, model_name="gpt-3.5-turbo-16k", temperature=0.8)

long_prompt = (
'Complete the following short story. The child screamed "'
+ "AHHHHH" * 5000
+ '" when they saw the'
)

prompts = [long_prompt]
generated = llm_under_test.generate(prompts=prompts)
assert len(generated.generations) != 0
assert len(generated.generations[0]) > 0
2 changes: 1 addition & 1 deletion tests/tools/test_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ def test_search_tool(client: Steamship):

answer = tool_under_test.search("Who won the 2019 World Series?")
assert len(answer) != 0
assert "Washington Nationals" in answer
assert ("Nationals" in answer) or ("Washington" in answer)

0 comments on commit 9648440

Please sign in to comment.