openai[patch]: use max_completion_tokens in place of max_tokens (#26917)
ccurme and baskaryan authored Nov 26, 2024
1 parent 869c8f5 commit 42b1882
Showing 3 changed files with 75 additions and 29 deletions.
35 changes: 32 additions & 3 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -435,7 +435,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
-max_tokens: Optional[int] = None
+max_tokens: Optional[int] = Field(default=None)
"""Maximum number of tokens to generate."""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
@@ -699,6 +699,7 @@ def _get_request_payload(
messages = self._convert_input(input_).to_messages()
if stop is not None:
kwargs["stop"] = stop
+
return {
"messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
@@ -853,7 +854,9 @@ def _get_ls_params(
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
-if ls_max_tokens := params.get("max_tokens", self.max_tokens):
+if ls_max_tokens := params.get("max_tokens", self.max_tokens) or params.get(
+    "max_completion_tokens", self.max_tokens
+):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
@@ -1501,7 +1504,7 @@ def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
return filtered


-class ChatOpenAI(BaseChatOpenAI):
+class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
"""OpenAI chat model integration.
.. dropdown:: Setup
@@ -1963,6 +1966,9 @@ class Joke(BaseModel):
message chunks will be generated during the stream including usage metadata.
"""

+max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
+"""Maximum number of tokens to generate."""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@@ -1992,6 +1998,29 @@ def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True

+@property
+def _default_params(self) -> Dict[str, Any]:
+    """Get the default parameters for calling OpenAI API."""
+    params = super()._default_params
+    if "max_tokens" in params:
+        params["max_completion_tokens"] = params.pop("max_tokens")
+
+    return params
+
+def _get_request_payload(
+    self,
+    input_: LanguageModelInput,
+    *,
+    stop: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> dict:
+    payload = super()._get_request_payload(input_, stop=stop, **kwargs)
+    # max_tokens was deprecated in favor of max_completion_tokens
+    # in September 2024 release
+    if "max_tokens" in payload:
+        payload["max_completion_tokens"] = payload.pop("max_tokens")
+    return payload

def _should_stream_usage(
self, stream_usage: Optional[bool] = None, **kwargs: Any
) -> bool:
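Taken together, the base.py changes above mean ChatOpenAI accepts either spelling at the constructor and always sends max_completion_tokens to the API. A minimal sketch of that behavior, assuming langchain-openai at this commit is installed; the api_key value is a placeholder and nothing is sent over the network, since only the locally built parameters (via the private _default_params helper changed above) are inspected:

    from langchain_openai import ChatOpenAI

    # Either spelling populates the same field: `max_completion_tokens` is now
    # an alias of `max_tokens` on ChatOpenAI.
    llm = ChatOpenAI(model="gpt-4o", max_completion_tokens=10, api_key="sk-placeholder")
    assert llm.max_tokens == 10

    # The overridden _default_params renames the key, so outgoing requests carry
    # max_completion_tokens rather than the deprecated max_tokens.
    params = llm._default_params
    assert params["max_completion_tokens"] == 10
    assert "max_tokens" not in params

The integration and unit tests below exercise both spellings against the same field.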
@@ -44,7 +44,7 @@ def test_chat_openai() -> None:
max_retries=3,
http_client=None,
n=1,
-max_tokens=10,
+max_completion_tokens=10,
default_headers=None,
default_query=None,
)
@@ -64,7 +64,7 @@ def test_chat_openai_model() -> None:

def test_chat_openai_system_message() -> None:
"""Test ChatOpenAI wrapper with system message."""
-chat = ChatOpenAI(max_tokens=10)
+chat = ChatOpenAI(max_completion_tokens=10)
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat.invoke([system_message, human_message])
@@ -75,7 +75,7 @@ def test_chat_openai_system_message() -> None:
@pytest.mark.scheduled
def test_chat_openai_generate() -> None:
"""Test ChatOpenAI wrapper with generate."""
-chat = ChatOpenAI(max_tokens=10, n=2)
+chat = ChatOpenAI(max_completion_tokens=10, n=2)
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
@@ -92,7 +92,7 @@ def test_chat_openai_generate() -> None:
@pytest.mark.scheduled
def test_chat_openai_multiple_completions() -> None:
"""Test ChatOpenAI wrapper with multiple completions."""
-chat = ChatOpenAI(max_tokens=10, n=5)
+chat = ChatOpenAI(max_completion_tokens=10, n=5)
message = HumanMessage(content="Hello")
response = chat._generate([message])
assert isinstance(response, ChatResult)
@@ -108,7 +108,7 @@ def test_chat_openai_streaming() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatOpenAI(
-max_tokens=10,
+max_completion_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
@@ -133,7 +133,9 @@ def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:

callback = _FakeCallback()
callback_manager = CallbackManager([callback])
-chat = ChatOpenAI(max_tokens=2, temperature=0, callback_manager=callback_manager)
+chat = ChatOpenAI(
+    max_completion_tokens=2, temperature=0, callback_manager=callback_manager
+)
list(chat.stream("hi"))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned
@@ -142,7 +144,7 @@ def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:

def test_chat_openai_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
-chat = ChatOpenAI(max_tokens=10)
+chat = ChatOpenAI(max_completion_tokens=10)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
@@ -151,7 +153,7 @@ def test_chat_openai_llm_output_contains_model_name() -> None:

def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
-chat = ChatOpenAI(max_tokens=10, streaming=True)
+chat = ChatOpenAI(max_completion_tokens=10, streaming=True)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
@@ -161,13 +163,13 @@ def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
def test_chat_openai_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError):
-ChatOpenAI(max_tokens=10, streaming=True, temperature=0, n=5)
+ChatOpenAI(max_completion_tokens=10, streaming=True, temperature=0, n=5)


@pytest.mark.scheduled
async def test_async_chat_openai() -> None:
"""Test async generation."""
-chat = ChatOpenAI(max_tokens=10, n=2)
+chat = ChatOpenAI(max_completion_tokens=10, n=2)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
@@ -187,7 +189,7 @@ async def test_async_chat_openai_streaming() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatOpenAI(
-max_tokens=10,
+max_completion_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
@@ -219,7 +221,7 @@ class Person(BaseModel):
default=None, title="Fav Food", description="The person's favorite food"
)

-chat = ChatOpenAI(max_tokens=30, n=1, streaming=True).bind_functions(
+chat = ChatOpenAI(max_completion_tokens=30, n=1, streaming=True).bind_functions(
functions=[Person], function_call="Person"
)

@@ -241,7 +243,7 @@ class Person(BaseModel):
@pytest.mark.scheduled
def test_openai_streaming() -> None:
"""Test streaming tokens from OpenAI."""
-llm = ChatOpenAI(max_tokens=10)
+llm = ChatOpenAI(max_completion_tokens=10)

for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@@ -250,7 +252,7 @@ def test_openai_streaming() -> None:
@pytest.mark.scheduled
async def test_openai_astream() -> None:
"""Test streaming tokens from OpenAI."""
-llm = ChatOpenAI(max_tokens=10)
+llm = ChatOpenAI(max_completion_tokens=10)

async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@@ -259,7 +261,7 @@ async def test_openai_astream() -> None:
@pytest.mark.scheduled
async def test_openai_abatch() -> None:
"""Test streaming tokens from ChatOpenAI."""
-llm = ChatOpenAI(max_tokens=10)
+llm = ChatOpenAI(max_completion_tokens=10)

result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
@@ -269,7 +271,7 @@ async def test_openai_abatch() -> None:
@pytest.mark.scheduled
async def test_openai_abatch_tags() -> None:
"""Test batch tokens from ChatOpenAI."""
-llm = ChatOpenAI(max_tokens=10)
+llm = ChatOpenAI(max_completion_tokens=10)

result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
@@ -281,7 +283,7 @@ async def test_openai_abatch_tags() -> None:
@pytest.mark.scheduled
def test_openai_batch() -> None:
"""Test batch tokens from ChatOpenAI."""
-llm = ChatOpenAI(max_tokens=10)
+llm = ChatOpenAI(max_completion_tokens=10)

result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
@@ -291,7 +293,7 @@ def test_openai_batch() -> None:
@pytest.mark.scheduled
async def test_openai_ainvoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
-llm = ChatOpenAI(max_tokens=10)
+llm = ChatOpenAI(max_completion_tokens=10)

result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
@@ -300,7 +302,7 @@ async def test_openai_ainvoke() -> None:
@pytest.mark.scheduled
def test_openai_invoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
-llm = ChatOpenAI(max_tokens=10)
+llm = ChatOpenAI(max_completion_tokens=10)

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
@@ -385,23 +387,23 @@ async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
assert chunks_with_token_counts == 0
assert full.usage_metadata is None

-llm = ChatOpenAI(temperature=0, max_tokens=5)
+llm = ChatOpenAI(temperature=0, max_completion_tokens=5)
await _test_stream(llm.astream("Hello"), expect_usage=False)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True
)
await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True)
llm = ChatOpenAI(
temperature=0,
-max_tokens=5,
+max_completion_tokens=5,
model_kwargs={"stream_options": {"include_usage": True}},
)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": False}),
expect_usage=False,
)
-llm = ChatOpenAI(temperature=0, max_tokens=5, stream_usage=True)
+llm = ChatOpenAI(temperature=0, max_completion_tokens=5, stream_usage=True)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)

@@ -666,15 +668,15 @@ def test_openai_response_headers() -> None:
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(include_response_headers=True)
query = "I'm Pickle Rick"
-result = chat_openai.invoke(query, max_tokens=10)
+result = chat_openai.invoke(query, max_completion_tokens=10)
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers

# Stream
full: Optional[BaseMessageChunk] = None
-for chunk in chat_openai.stream(query, max_tokens=10):
+for chunk in chat_openai.stream(query, max_completion_tokens=10):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
@@ -687,15 +689,15 @@ async def test_openai_response_headers_async() -> None:
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(include_response_headers=True)
query = "I'm Pickle Rick"
-result = await chat_openai.ainvoke(query, max_tokens=10)
+result = await chat_openai.ainvoke(query, max_completion_tokens=10)
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers

# Stream
full: Optional[BaseMessageChunk] = None
-async for chunk in chat_openai.astream(query, max_tokens=10):
+async for chunk in chat_openai.astream(query, max_completion_tokens=10):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
@@ -1085,3 +1087,13 @@ async def test_astream_response_format() -> None:
"how are ya", response_format=Foo
):
pass


+def test_o1_max_tokens() -> None:
+    response = ChatOpenAI(model="o1-mini", max_tokens=10).invoke("how are you")  # type: ignore[call-arg]
+    assert isinstance(response, AIMessage)
+
+    response = ChatOpenAI(model="gpt-4o", max_completion_tokens=10).invoke(
+        "how are you"
+    )
+    assert isinstance(response, AIMessage)
@@ -36,6 +36,11 @@ def test_openai_model_param() -> None:
llm = ChatOpenAI(model_name="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"

+llm = ChatOpenAI(max_tokens=10) # type: ignore[call-arg]
+assert llm.max_tokens == 10
+llm = ChatOpenAI(max_completion_tokens=10)
+assert llm.max_tokens == 10


def test_openai_o1_temperature() -> None:
llm = ChatOpenAI(model="o1-preview")
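A companion sketch of the request-payload translation added in ChatOpenAI._get_request_payload above: callers that still pass the deprecated max_tokens end up sending max_completion_tokens. Same assumptions as before: langchain-openai at this commit, a placeholder api_key, and no network call, since the request dict is only built locally via a private helper used here purely for illustration:

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o", max_tokens=10, api_key="sk-placeholder")  # type: ignore[call-arg]

    # The override swaps the deprecated key for the one the Chat Completions
    # API now expects before any request would be sent.
    payload = llm._get_request_payload("hello")
    assert payload["max_completion_tokens"] == 10
    assert "max_tokens" not in payload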
