Commit
Merge branch 'main' into ODSC-63451/oci_odsc_embedding
mrDzurb authored Dec 12, 2024
2 parents 2d198e6 + d36ce67 commit a01f809
Showing 11 changed files with 48 additions and 16 deletions.
@@ -48,7 +48,7 @@
to_tgi_messages,
force_single_tool_call,
resolve_tgi_function_call,
get_max_input_length,
get_max_total_tokens,
resolve_tool_choice,
)
from transformers import (
@@ -137,7 +137,7 @@ def completion_to_prompt(completion):
)
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description="The maximum number of tokens available for input.",
description=(LLMMetadata.model_fields["context_window"].description),
gt=0,
)
max_new_tokens: int = Field(
@@ -752,7 +752,10 @@ class TextGenerationInference(FunctionCallingLLM):

context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description=("Maximum input length in tokens returned from TGI endpoint"),
description=(
LLMMetadata.model_fields["context_window"].description
+ " Maximum total tokens returned from TGI endpoint."
),
)
is_chat_model: bool = Field(
default=True,
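A note on the description changes above: instead of hard-coding the wording, the field now reuses the canonical context_window description from LLMMetadata and appends a TGI-specific remark. A minimal standalone sketch of the Pydantic v2 pattern being relied on (the Canonical and Derived models below are illustrative, not part of the package):

from pydantic import BaseModel, Field

class Canonical(BaseModel):
    # Canonical definition of the field, including its documented meaning.
    context_window: int = Field(
        default=3900,
        description="Total number of tokens the model can attend to.",
    )

class Derived(BaseModel):
    # Pydantic v2 exposes field metadata via model_fields, so the canonical
    # description can be reused and extended rather than duplicated.
    context_window: int = Field(
        default=3900,
        description=(
            Canonical.model_fields["context_window"].description
            + " Maximum total tokens returned from the endpoint."
        ),
        gt=0,
    )

print(Derived.model_fields["context_window"].description)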
@@ -819,7 +822,7 @@ def __init__(
logger.warning(f"TGI client has no function call support: {e}")
is_function_calling_model = False

context_window = get_max_input_length(model_url) or DEFAULT_CONTEXT_WINDOW
context_window = get_max_total_tokens(model_url) or DEFAULT_CONTEXT_WINDOW

super().__init__(
context_window=context_window,
@@ -24,10 +24,20 @@ def resolve_tgi_function_call(url: str) -> bool:
)


def get_max_input_length(url: str) -> Union[int, None]:
def get_max_input_tokens(url: str) -> Union[int, None]:
url = f"{url}/info"
model_info = dict(requests.get(url).json())
return model_info.get("max_input_length", None)
tgi_version = model_info.get("version", None)
if version.parse(tgi_version) >= version.parse("2.1.0"):
return model_info.get("max_input_tokens", None)
else:
return model_info.get("max_input_length", None)


def get_max_total_tokens(url: str) -> Union[int, None]:
url = f"{url}/info"
model_info = dict(requests.get(url).json())
return model_info.get("max_total_tokens", None)


def to_tgi_messages(messages: Sequence[ChatMessage]) -> Sequence[Message]:
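For context, both new helpers read the TGI server's /info endpoint: the diff switches on the reported TGI version because max_input_length was renamed to max_input_tokens in TGI 2.1.0, and max_total_tokens is the server's overall token budget (prompt plus generated tokens), which is what context_window should reflect. A self-contained sketch of the same logic, assuming requests and packaging are installed (the timeout and the guard for a missing version field are added here for safety and are not part of the diff):

from typing import Union

import requests
from packaging import version


def get_max_input_tokens(url: str) -> Union[int, None]:
    # TGI >= 2.1.0 reports max_input_tokens; older servers report max_input_length.
    model_info = requests.get(f"{url}/info", timeout=10).json()
    tgi_version = model_info.get("version")
    if tgi_version is None:  # extra guard, not in the diff
        return model_info.get("max_input_length")
    if version.parse(tgi_version) >= version.parse("2.1.0"):
        return model_info.get("max_input_tokens")
    return model_info.get("max_input_length")


def get_max_total_tokens(url: str) -> Union[int, None]:
    # Total token budget (prompt plus generation) reported by the server.
    model_info = requests.get(f"{url}/info", timeout=10).json()
    return model_info.get("max_total_tokens")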
@@ -28,7 +28,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-huggingface"
readme = "README.md"
version = "0.4.0"
version = "0.4.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
@@ -14,6 +14,7 @@
"meta/codellama-70b": 1024,
"meta/llama3-70b-instruct": 8192,
"meta/llama3-8b-instruct": 8192,
"meta/llama-3.3-70b-instruct": 128000,
"microsoft/phi-3-medium-4k-instruct": 1024,
"microsoft/phi-3-mini-128k-instruct": 2048,
"microsoft/phi-3-mini-4k-instruct": 2048,
@@ -37,6 +38,7 @@
"meta/llama-3.1-8b-instruct",
"meta/llama-3.1-70b-instruct",
"meta/llama-3.1-405b-instruct",
"meta/llama-3.3-70b-instruct",
"mistralai/mistral-large-2-instruct",
)

@@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-llms-nvidia"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.3.0"
version = "0.3.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
@@ -285,6 +285,10 @@ def to_openai_message_dict(
msg = f"Unsupported content block type: {type(block).__name__}"
raise ValueError(msg)

# NOTE: Sending a blank string to openai will cause an error.
# This will commonly happen with tool calls.
content_txt = None if content_txt == "" else content_txt

# NOTE: Despite what the openai docs say, if the role is ASSISTANT, SYSTEM
# or TOOL, 'content' cannot be a list and must be string instead.
# Furthermore, if all blocks are text blocks, we can use the content_txt
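The note above also explains the fixture change further down in this commit: an assistant message that carries only a tool or function call should be sent with content set to None, since the OpenAI API rejects an empty string there. A hedged sketch of the message shape being produced (an illustrative dict only, not the library's exact output):

# Assistant turn that only issues a function call: no text content.
assistant_message = {
    "role": "assistant",
    "content": None,  # an empty string "" here can trigger an API error
    "function_call": {
        "name": "get_current_weather",
        "arguments": '{ "location": "Boston, MA" }',
    },
}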
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-openai"
readme = "README.md"
version = "0.3.9"
version = "0.3.10"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
@@ -67,7 +67,7 @@ def openi_message_dicts_with_function_calling() -> List[ChatCompletionMessagePar
},
{
"role": "assistant",
"content": "",
"content": None,
"function_call": {
"name": "get_current_weather",
"arguments": '{ "location": "Boston, MA"}',
@@ -38,7 +38,7 @@
to_tgi_messages,
force_single_tool_call,
resolve_tgi_function_call,
get_max_input_length,
get_max_total_tokens,
resolve_tool_choice,
)
from text_generation import (
@@ -101,7 +101,10 @@ class TextGenerationInference(FunctionCallingLLM):

context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description=("Maximum input length in tokens returned from TGI endpoint"),
description=(
LLMMetadata.model_fields["context_window"].description
+ " Maximum total tokens returned from TGI endpoint."
),
)
is_chat_model: bool = Field(
default=True,
@@ -155,7 +158,7 @@ def __init__(
logger.warning(f"TGI client has no function call support: {e}")
is_function_calling_model = False

context_window = get_max_input_length(model_url) or DEFAULT_CONTEXT_WINDOW
context_window = get_max_total_tokens(model_url) or DEFAULT_CONTEXT_WINDOW

super().__init__(
context_window=context_window,
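Net effect in this file: when the endpoint is reachable, context_window is now taken from the server-reported max_total_tokens rather than its maximum input length, falling back to DEFAULT_CONTEXT_WINDOW otherwise. A hedged usage sketch (the URL is a placeholder, and the constructor is assumed to accept model_url as shown in the diff):

from llama_index.llms.text_generation_inference import TextGenerationInference

llm = TextGenerationInference(
    model_url="http://localhost:8080",  # placeholder TGI endpoint
)

# context_window now reflects the server's max_total_tokens
# (or DEFAULT_CONTEXT_WINDOW if nothing usable comes back from /info).
print(llm.metadata.context_window)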
@@ -24,10 +24,20 @@ def resolve_tgi_function_call(url: str) -> bool:
)


def get_max_input_length(url: str) -> Union[int, None]:
def get_max_input_tokens(url: str) -> Union[int, None]:
url = f"{url}/info"
model_info = dict(requests.get(url).json())
return model_info.get("max_input_length", None)
tgi_version = model_info.get("version", None)
if version.parse(tgi_version) >= version.parse("2.1.0"):
return model_info.get("max_input_tokens", None)
else:
return model_info.get("max_input_length", None)


def get_max_total_tokens(url: str) -> Union[int, None]:
url = f"{url}/info"
model_info = dict(requests.get(url).json())
return model_info.get("max_total_tokens", None)


def to_tgi_messages(messages: Sequence[ChatMessage]) -> Sequence[Message]:
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-text-generation-inference"
readme = "README.md"
version = "0.3.0"
version = "0.3.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"