openai[patch]: use max_completion_tokens in place of max_tokens #26917

Merged (14 commits) on Nov 26, 2024
30 changes: 28 additions & 2 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -435,7 +435,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
max_tokens: Optional[int] = None
max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
Collaborator comment:

Do we need to update Azure OpenAI as well? If someone is passing max_completion_tokens to Azure today, this change will mean max_tokens is passed in instead.

"""Maximum number of tokens to generate."""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
@@ -717,6 +717,7 @@ def _get_request_payload(
messages = self._convert_input(input_).to_messages()
if stop is not None:
kwargs["stop"] = stop

return {
"messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
@@ -889,7 +890,9 @@ def _get_ls_params(
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
if ls_max_tokens := params.get("max_tokens", self.max_tokens) or params.get(
"max_completion_tokens", self.max_tokens
):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
@@ -1993,6 +1996,29 @@ def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
params = super()._default_params
if "max_tokens" in params:
params["max_completion_tokens"] = params.pop("max_tokens")

return params

def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
# max_tokens was deprecated in favor of max_completion_tokens
# in September 2024 release
if "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload

def _should_stream_usage(
self, stream_usage: Optional[bool] = None, **kwargs: Any
) -> bool:
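Side note on the two ChatOpenAI overrides above: together they rename the deprecated key just before the request payload is sent. A hedged, standalone sketch of that remapping (the `remap_max_tokens` helper is hypothetical and not part of langchain_openai):

```python
# Hypothetical standalone version of the remapping done in
# ChatOpenAI._default_params and _get_request_payload above: if the deprecated
# max_tokens key is present, rename it to max_completion_tokens.
from typing import Any, Dict


def remap_max_tokens(payload: Dict[str, Any]) -> Dict[str, Any]:
    if "max_tokens" in payload:
        payload["max_completion_tokens"] = payload.pop("max_tokens")
    return payload


assert remap_max_tokens({"model": "gpt-4o-mini", "max_tokens": 10}) == {
    "model": "gpt-4o-mini",
    "max_completion_tokens": 10,
}
```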
2 changes: 1 addition & 1 deletion libs/partners/openai/poetry.lock

(Generated file; diff not rendered.)

@@ -45,7 +45,7 @@ def test_chat_openai() -> None:
max_retries=3,
http_client=None,
n=1,
max_tokens=10,
max_completion_tokens=10,
default_headers=None,
default_query=None,
)
@@ -65,7 +65,7 @@ def test_chat_openai_model() -> None:

def test_chat_openai_system_message() -> None:
"""Test ChatOpenAI wrapper with system message."""
chat = ChatOpenAI(max_tokens=10)
chat = ChatOpenAI(max_completion_tokens=10)
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat.invoke([system_message, human_message])
@@ -76,7 +76,7 @@ def test_chat_openai_system_message() -> None:
@pytest.mark.scheduled
def test_chat_openai_generate() -> None:
"""Test ChatOpenAI wrapper with generate."""
chat = ChatOpenAI(max_tokens=10, n=2)
chat = ChatOpenAI(max_completion_tokens=10, n=2)
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
@@ -93,7 +93,7 @@ def test_chat_openai_generate() -> None:
@pytest.mark.scheduled
def test_chat_openai_multiple_completions() -> None:
"""Test ChatOpenAI wrapper with multiple completions."""
chat = ChatOpenAI(max_tokens=10, n=5)
chat = ChatOpenAI(max_completion_tokens=10, n=5)
message = HumanMessage(content="Hello")
response = chat._generate([message])
assert isinstance(response, ChatResult)
@@ -109,7 +109,7 @@ def test_chat_openai_streaming() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatOpenAI(
max_tokens=10,
max_completion_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
@@ -134,7 +134,9 @@ def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:

callback = _FakeCallback()
callback_manager = CallbackManager([callback])
chat = ChatOpenAI(max_tokens=2, temperature=0, callback_manager=callback_manager)
chat = ChatOpenAI(
max_completion_tokens=2, temperature=0, callback_manager=callback_manager
)
list(chat.stream("hi"))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned
@@ -143,7 +145,7 @@ def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:

def test_chat_openai_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
chat = ChatOpenAI(max_tokens=10)
chat = ChatOpenAI(max_completion_tokens=10)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
@@ -152,7 +154,7 @@ def test_chat_openai_llm_output_contains_model_name() -> None:

def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
chat = ChatOpenAI(max_tokens=10, streaming=True)
chat = ChatOpenAI(max_completion_tokens=10, streaming=True)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
@@ -162,13 +164,13 @@ def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
def test_chat_openai_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError):
ChatOpenAI(max_tokens=10, streaming=True, temperature=0, n=5)
ChatOpenAI(max_completion_tokens=10, streaming=True, temperature=0, n=5)


@pytest.mark.scheduled
async def test_async_chat_openai() -> None:
"""Test async generation."""
chat = ChatOpenAI(max_tokens=10, n=2)
chat = ChatOpenAI(max_completion_tokens=10, n=2)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
@@ -188,7 +190,7 @@ async def test_async_chat_openai_streaming() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatOpenAI(
max_tokens=10,
max_completion_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
@@ -220,7 +222,7 @@ class Person(BaseModel):
default=None, title="Fav Food", description="The person's favorite food"
)

chat = ChatOpenAI(max_tokens=30, n=1, streaming=True).bind_functions(
chat = ChatOpenAI(max_completion_tokens=30, n=1, streaming=True).bind_functions(
functions=[Person], function_call="Person"
)

@@ -242,7 +244,7 @@ class Person(BaseModel):
@pytest.mark.scheduled
def test_openai_streaming() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI(max_tokens=10)
llm = ChatOpenAI(max_completion_tokens=10)

for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@@ -251,7 +253,7 @@ def test_openai_streaming() -> None:
@pytest.mark.scheduled
async def test_openai_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI(max_tokens=10)
llm = ChatOpenAI(max_completion_tokens=10)

async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@@ -260,7 +262,7 @@ async def test_openai_astream() -> None:
@pytest.mark.scheduled
async def test_openai_abatch() -> None:
"""Test streaming tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
llm = ChatOpenAI(max_completion_tokens=10)

result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
@@ -270,7 +272,7 @@ async def test_openai_abatch() -> None:
@pytest.mark.scheduled
async def test_openai_abatch_tags() -> None:
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
llm = ChatOpenAI(max_completion_tokens=10)

result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
@@ -282,7 +284,7 @@ async def test_openai_abatch_tags() -> None:
@pytest.mark.scheduled
def test_openai_batch() -> None:
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
llm = ChatOpenAI(max_completion_tokens=10)

result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
@@ -292,7 +294,7 @@ def test_openai_batch() -> None:
@pytest.mark.scheduled
async def test_openai_ainvoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
llm = ChatOpenAI(max_completion_tokens=10)

result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
@@ -301,7 +303,7 @@ async def test_openai_ainvoke() -> None:
@pytest.mark.scheduled
def test_openai_invoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
llm = ChatOpenAI(max_completion_tokens=10)

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
@@ -386,23 +388,23 @@ async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
assert chunks_with_token_counts == 0
assert full.usage_metadata is None

llm = ChatOpenAI(temperature=0, max_tokens=5)
llm = ChatOpenAI(temperature=0, max_completion_tokens=5)
await _test_stream(llm.astream("Hello"), expect_usage=False)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True
)
await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True)
llm = ChatOpenAI(
temperature=0,
max_tokens=5,
max_completion_tokens=5,
model_kwargs={"stream_options": {"include_usage": True}},
)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": False}),
expect_usage=False,
)
llm = ChatOpenAI(temperature=0, max_tokens=5, stream_usage=True)
llm = ChatOpenAI(temperature=0, max_completion_tokens=5, stream_usage=True)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)

@@ -667,15 +669,15 @@ def test_openai_response_headers() -> None:
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(include_response_headers=True)
query = "I'm Pickle Rick"
result = chat_openai.invoke(query, max_tokens=10)
result = chat_openai.invoke(query, max_completion_tokens=10)
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers

# Stream
full: Optional[BaseMessageChunk] = None
for chunk in chat_openai.stream(query, max_tokens=10):
for chunk in chat_openai.stream(query, max_completion_tokens=10):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
@@ -688,15 +690,15 @@ async def test_openai_response_headers_async() -> None:
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(include_response_headers=True)
query = "I'm Pickle Rick"
result = await chat_openai.ainvoke(query, max_tokens=10)
result = await chat_openai.ainvoke(query, max_completion_tokens=10)
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers

# Stream
full: Optional[BaseMessageChunk] = None
async for chunk in chat_openai.astream(query, max_tokens=10):
async for chunk in chat_openai.astream(query, max_completion_tokens=10):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
@@ -36,6 +36,11 @@ def test_openai_model_param() -> None:
llm = ChatOpenAI(model_name="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"

llm = ChatOpenAI(max_tokens=10) # type: ignore[call-arg]
assert llm.max_tokens == 10
llm = ChatOpenAI(max_completion_tokens=10)
assert llm.max_tokens == 10


def test_openai_o1_temperature() -> None:
llm = ChatOpenAI(model="o1-preview")