Add support for HuggingFace Text Embedding Inference endpoint for embeddings #524

Merged · 9 commits · Nov 28, 2023
Changes from all commits
memgpt/cli/cli_config.py (40 changes: 34 additions & 6 deletions)

@@ -81,6 +81,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
     if model_endpoint_type in DEFAULT_ENDPOINTS:
         default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
         model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
+    elif config.model_endpoint:
+        model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
     else:
         # default_model_endpoint = None
         model_endpoint = None
@@ -173,12 +175,10 @@ def configure_embedding_endpoint(config: MemGPTConfig):
     # configure embedding endpoint

     default_embedding_endpoint_type = config.embedding_endpoint_type
-    if config.embedding_endpoint_type is not None and config.embedding_endpoint_type not in ["openai", "azure"]:  # local model
-        default_embedding_endpoint_type = "local"

-    embedding_endpoint_type, embedding_endpoint, embedding_dim = None, None, None
+    embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = None, None, None, None
     embedding_provider = questionary.select(
-        "Select embedding provider:", choices=["openai", "azure", "local"], default=default_embedding_endpoint_type
+        "Select embedding provider:", choices=["openai", "azure", "hugging-face", "local"], default=default_embedding_endpoint_type
     ).ask()
     if embedding_provider == "openai":
         embedding_endpoint_type = "openai"
@@ -188,11 +188,38 @@ def configure_embedding_endpoint(config: MemGPTConfig):
embedding_endpoint_type = "azure"
_, _, _, _, embedding_endpoint = get_azure_credentials()
embedding_dim = 1536
elif embedding_provider == "hugging-face":
# configure hugging face embedding endpoint (https://github.com/huggingface/text-embeddings-inference)
# supports custom model/endpoints
embedding_endpoint_type = "hugging-face"
embedding_endpoint = None

# get endpoint
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
if "http://" not in embedding_endpoint and "https://" not in embedding_endpoint:
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
embedding_endpoint = None

# get model type
default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5"
embedding_model = questionary.text(
"Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):",
default=default_embedding_model,
).ask()

# get model dimentions
default_embedding_dim = config.embedding_dim if config.embedding_dim else "1024"
embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask()
try:
embedding_dim = int(embedding_dim)
except Exception as e:
raise ValueError(f"Failed to cast {embedding_dim} to integer.")
else: # local models
embedding_endpoint_type = "local"
embedding_endpoint = None
embedding_dim = 384
return embedding_endpoint_type, embedding_endpoint, embedding_dim

return embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model


def configure_cli(config: MemGPTConfig):
@@ -253,7 +280,7 @@ def configure():
     config = MemGPTConfig.load()
     model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
     model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
-    embedding_endpoint_type, embedding_endpoint, embedding_dim = configure_embedding_endpoint(config)
+    embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
     default_preset, default_persona, default_human, default_agent = configure_cli(config)
     archival_storage_type, archival_storage_uri = configure_archival_storage(config)

@@ -286,6 +313,7 @@ def configure():
         embedding_endpoint_type=embedding_endpoint_type,
         embedding_endpoint=embedding_endpoint,
         embedding_dim=embedding_dim,
+        embedding_model=embedding_model,
         # cli configs
         preset=default_preset,
         persona=default_persona,
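A note on the new prompts: the hugging-face branch asks for an endpoint, a HuggingFace model tag, and the embedding dimension. If the dimension is unknown, one way to find it is to ask a running text-embeddings-inference server directly. A minimal sketch, assuming TEI's POST /embed API and a local server (the URL is an example, not a value fixed by this PR):

import httpx

TEI_URL = "http://localhost:8080"  # hypothetical TEI endpoint

# TEI returns one embedding vector per input string.
resp = httpx.post(f"{TEI_URL}/embed", json={"inputs": "dimension probe"}, timeout=60)
resp.raise_for_status()
print(len(resp.json()[0]))  # e.g. 1024 for BAAI/bge-large-en-v1.5
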
memgpt/config.py (5 changes: 5 additions & 0 deletions)

@@ -92,6 +92,7 @@ class MemGPTConfig:
     # embedding parameters
     embedding_endpoint_type: str = "openai"  # openai, azure, local
     embedding_endpoint: str = None
+    embedding_model: str = None
     embedding_dim: int = 1536
     embedding_chunk_size: int = 300  # number of tokens

@@ -153,6 +154,7 @@ def load(cls) -> "MemGPTConfig":
"azure_deployment": get_field(config, "azure", "deployment"),
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
"embedding_endpoint": get_field(config, "embedding", "embedding_endpoint"),
"embedding_model": get_field(config, "embedding", "embedding_model"),
"embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"),
"embedding_dim": get_field(config, "embedding", "embedding_dim"),
"embedding_chunk_size": get_field(config, "embedding", "chunk_size"),
@@ -203,6 +205,7 @@ def save(self):
         # embeddings
         set_field(config, "embedding", "embedding_endpoint_type", self.embedding_endpoint_type)
         set_field(config, "embedding", "embedding_endpoint", self.embedding_endpoint)
+        set_field(config, "embedding", "embedding_model", self.embedding_model)
         set_field(config, "embedding", "embedding_dim", str(self.embedding_dim))
         set_field(config, "embedding", "embedding_chunk_size", str(self.embedding_chunk_size))

@@ -265,6 +268,7 @@ def __init__(
         # embedding info
         embedding_endpoint_type=None,
         embedding_endpoint=None,
+        embedding_model=None,
         embedding_dim=None,
         embedding_chunk_size=None,
         # other
@@ -292,6 +296,7 @@ def __init__(
         self.model_wrapper = config.model_wrapper if model_wrapper is None else model_wrapper
         self.embedding_endpoint_type = config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type
         self.embedding_endpoint = config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint
+        self.embedding_model = config.embedding_model if embedding_model is None else embedding_model
         self.embedding_dim = config.embedding_dim if embedding_dim is None else embedding_dim
         self.embedding_chunk_size = config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size

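The new embedding_model key is read in load() and written in save() like the other [embedding] fields. A sketch of the section a configured file would contain, assuming get_field/set_field wrap the standard-library configparser (their implementation is not shown in this diff; the endpoint and model tag are examples):

import configparser

config = configparser.ConfigParser()
config["embedding"] = {
    "embedding_endpoint_type": "hugging-face",
    "embedding_endpoint": "http://localhost:8080",  # example endpoint
    "embedding_model": "BAAI/bge-large-en-v1.5",    # example model tag
    "embedding_dim": "1024",
    "embedding_chunk_size": "300",
}

# Write an example config file to inspect the resulting [embedding] section.
with open("config.example", "w") as f:
    config.write(f)
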
memgpt/embeddings.py (10 changes: 9 additions & 1 deletion)

@@ -1,6 +1,7 @@
 import typer
 import os
 from llama_index.embeddings import OpenAIEmbedding
+from llama_index.embeddings import TextEmbeddingsInference


 def embedding_model():
@@ -24,8 +25,15 @@ def embedding_model():
api_type="azure",
api_version=config.azure_version,
)
elif endpoint == "hugging-face":
embed_model = TextEmbeddingsInference(
base_url=config.embedding_endpoint,
model_name=config.embedding_model,
timeout=60, # timeout in seconds
)
return embed_model
else:
# default to hugging face model
# default to hugging face model running local
from llama_index.embeddings import HuggingFaceEmbedding

os.environ["TOKENIZERS_PARALLELISM"] = "False"
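The new branch hands the configured endpoint and model tag to llama_index's TextEmbeddingsInference client. A self-contained sketch of the same call path (the URL and model tag are examples; nothing beyond the class and its arguments is fixed by this PR):

from llama_index.embeddings import TextEmbeddingsInference

embed_model = TextEmbeddingsInference(
    base_url="http://localhost:8080",     # a running TEI server
    model_name="BAAI/bge-large-en-v1.5",  # must match the model the server loaded
    timeout=60,
)

vector = embed_model.get_text_embedding("MemGPT archival memory test")
print(len(vector))  # 1024 for bge-large-en-v1.5
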
memgpt/local_llm/constants.py (1 change: 1 addition & 0 deletions)

@@ -7,6 +7,7 @@
"ollama": "http://localhost:11434",
"webui-legacy": "http://localhost:5000",
"webui": "http://localhost:5000",
"vllm": "http://localhost:8000",
}

DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
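The new "vllm" default presumably targets vLLM's OpenAI-compatible server, which listens on port 8000 out of the box. A quick reachability check under that assumption:

import httpx

# vLLM's OpenAI-compatible API exposes /v1/models; adjust host/port as needed.
resp = httpx.get("http://localhost:8000/v1/models", timeout=10)
print(resp.json())  # lists the model(s) the server is hosting
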
memgpt/memory.py (4 changes: 2 additions & 2 deletions)

@@ -708,8 +708,8 @@ def search(self, query_string, count=None, start=None):
             query_vec = self.embed_model.get_text_embedding(query_string)
             self.cache[query_string] = self.storage.query(query_string, query_vec, top_k=self.top_k)

-        start = start if start else 0
-        count = count if count else self.top_k
+        start = int(start if start else 0)
+        count = int(count if count else self.top_k)
         end = min(count + start, len(self.cache[query_string]))

         results = self.cache[query_string][start:end]
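The int() casts read as a guard against start/count arriving as strings (for example, from a parsed CLI command); without them the slice below would fail. An illustration of the failure mode, with hypothetical inputs rather than code from the PR:

cache = ["r0", "r1", "r2", "r3"]
start, count = "1", "2"  # hypothetical string inputs
# cache[start : start + count] raises TypeError: slice indices must be integers
start, count = int(start), int(count)
print(cache[start : start + count])  # ['r1', 'r2']
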
poetry.lock (86 changes: 68 additions & 18 deletions)

(Generated file; the diff is not rendered.)
pyproject.toml (1 change: 1 addition & 0 deletions)

@@ -48,6 +48,7 @@ torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
 websockets = "^12.0"
 docstring-parser = "^0.15"
 lancedb = {version = "^0.3.3", optional = true}
+httpx = "^0.25.2"

 [tool.poetry.extras]
 legacy = ["faiss-cpu", "numpy"]
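httpx becomes a required dependency here, presumably because the TextEmbeddingsInference client issues its HTTP requests through httpx; the PR itself does not state the reason.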
tests/utils.py (1 change: 1 addition & 0 deletions)

@@ -26,6 +26,7 @@ def configure_memgpt_localllm():
child.expect("Select embedding provider", timeout=TIMEOUT)
child.send("\x1b[B") # Send the down arrow key
child.send("\x1b[B") # Send the down arrow key
child.send("\x1b[B") # Send the down arrow key
child.sendline()

child.expect("Select default preset", timeout=TIMEOUT)