chore: Consistent naming for VectorIO providers #1023

Merged · 3 commits · Feb 13, 2025
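This PR renames the VectorIO provider classes to a consistent scheme: provider configs become `<Provider>VectorIOConfig` and adapter/impl classes become `<Provider>VectorIOAdapter` (or `...Impl` for the sample provider). A rough old-to-new mapping, reconstructed from the diff below purely as a reading aid:

```python
# Reconstructed from the diff; a reading aid, not an authoritative changelog.
RENAMES = {
    # inline providers
    "ChromaInlineImplConfig": "ChromaVectorIOConfig",
    "FaissImplConfig": "FaissVectorIOConfig",
    "FaissVectorIOImpl": "FaissVectorIOAdapter",
    # remote providers
    "ChromaRemoteImplConfig": "ChromaVectorIOConfig",
    "PGVectorConfig": "PGVectorVectorIOConfig",
    "PGVectorVectorDBAdapter": "PGVectorVectorIOAdapter",
    "QdrantConfig": "QdrantVectorIOConfig",
    "QdrantVectorDBAdapter": "QdrantVectorIOAdapter",
    "SampleConfig": "SampleVectorIOConfig",
    "SampleMemoryImpl": "SampleVectorIOImpl",
    "WeaviateConfig": "WeaviateVectorIOConfig",
    "WeaviateMemoryAdapter": "WeaviateVectorIOAdapter",
}
```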
Changes from all commits
2 changes: 1 addition & 1 deletion llama_stack/apis/telemetry/telemetry.py
@@ -13,8 +13,8 @@
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)

from llama_models.llama3.api.datatypes import Primitive
4 changes: 2 additions & 2 deletions llama_stack/providers/inline/vector_io/chroma/__init__.py
@@ -8,10 +8,10 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import ChromaInlineImplConfig
from .config import ChromaVectorIOConfig


async def get_provider_impl(config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)
2 changes: 1 addition & 1 deletion llama_stack/providers/inline/vector_io/chroma/config.py
@@ -9,7 +9,7 @@
from pydantic import BaseModel


class ChromaInlineImplConfig(BaseModel):
class ChromaVectorIOConfig(BaseModel):
db_path: str

@classmethod
10 changes: 5 additions & 5 deletions llama_stack/providers/inline/vector_io/faiss/__init__.py
@@ -8,14 +8,14 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import FaissImplConfig
from .config import FaissVectorIOConfig


async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissVectorIOImpl
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissVectorIOAdapter

assert isinstance(config, FaissImplConfig), f"Unexpected config type: {type(config)}"
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

impl = FaissVectorIOImpl(config, deps[Api.inference])
impl = FaissVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
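
For orientation, here is a minimal sketch of how the renamed Faiss provider could be wired up by hand; the SQLite kvstore path and the way the inference dependency is passed are illustrative assumptions (normally the stack resolver does this from a run config).

```python
# Hypothetical wiring sketch; in practice the stack builds providers from run configs.
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.vector_io.faiss import (
    FaissVectorIOConfig,
    get_provider_impl,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


async def build_faiss(inference_provider):
    # db_path is an arbitrary example value
    config = FaissVectorIOConfig(kvstore=SqliteKVStoreConfig(db_path="/tmp/faiss_store.db"))
    return await get_provider_impl(config, {Api.inference: inference_provider})
```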
2 changes: 1 addition & 1 deletion llama_stack/providers/inline/vector_io/faiss/config.py
@@ -16,7 +16,7 @@


@json_schema_type
class FaissImplConfig(BaseModel):
class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig

@classmethod
6 changes: 3 additions & 3 deletions llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -24,7 +24,7 @@
VectorDBWithIndex,
)

from .config import FaissImplConfig
from .config import FaissVectorIOConfig

logger = logging.getLogger(__name__)

@@ -112,8 +112,8 @@ async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
return QueryChunksResponse(chunks=chunks, scores=scores)


class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache = {}
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/vector_io/chroma/__init__.py
@@ -8,10 +8,10 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import ChromaRemoteImplConfig
from .config import ChromaVectorIOConfig


async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter

impl = ChromaVectorIOAdapter(config, deps[Api.inference])
7 changes: 3 additions & 4 deletions llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -16,13 +16,12 @@
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)

from .config import ChromaRemoteImplConfig
from .config import ChromaVectorIOConfig

log = logging.getLogger(__name__)

@@ -89,7 +88,7 @@ async def delete(self):
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
config: Union[ChromaVectorIOConfig, ChromaVectorIOConfig],
inference_api: Api.inference,
) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
@@ -100,7 +99,7 @@ def __init__(
self.cache = {}

async def initialize(self) -> None:
if isinstance(self.config, ChromaRemoteImplConfig):
if isinstance(self.config, ChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
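Note that after this rename the inline and remote Chroma configs share the class name `ChromaVectorIOConfig`, so any module that needs both (for example for the `Union` type and `isinstance` dispatch above) has to import them under distinct aliases; the test fixtures later in this diff do exactly that for the inline config. A sketch of such aliased imports (the alias names are illustrative):

```python
from typing import Union

# Alias names are hypothetical; both classes are called ChromaVectorIOConfig.
from llama_stack.providers.inline.vector_io.chroma import (
    ChromaVectorIOConfig as InlineChromaVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.chroma import (
    ChromaVectorIOConfig as RemoteChromaVectorIOConfig,
)

ChromaConfig = Union[RemoteChromaVectorIOConfig, InlineChromaVectorIOConfig]
```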
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/chroma/config.py
@@ -9,7 +9,7 @@
from pydantic import BaseModel


class ChromaRemoteImplConfig(BaseModel):
class ChromaVectorIOConfig(BaseModel):
url: str

@classmethod
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/pgvector/__init__.py
@@ -8,12 +8,12 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import PGVectorConfig
from .config import PGVectorVectorIOConfig


async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorDBAdapter
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter

impl = PGVectorVectorDBAdapter(config, deps[Api.inference])
impl = PGVectorVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
llama_stack/providers/remote/vector_io/pgvector/config.py
@@ -9,7 +9,7 @@


@json_schema_type
class PGVectorConfig(BaseModel):
class PGVectorVectorIOConfig(BaseModel):
host: str = Field(default="localhost")
port: int = Field(default=5432)
db: str = Field(default="postgres")
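As a usage sketch, the renamed PGVector config can be built from environment variables much like the test fixture further down; only the fields visible in this diff (host, port, db) are shown, and any credential fields the real config may require are omitted.

```python
import os

from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig

# Defaults mirror the Field(default=...) values above; any additional required
# fields (e.g. credentials) are not shown in this diff and are omitted here.
config = PGVectorVectorIOConfig(
    host=os.getenv("PGVECTOR_HOST", "localhost"),
    port=int(os.getenv("PGVECTOR_PORT", "5432")),
    db=os.getenv("PGVECTOR_DB", "postgres"),
)
```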
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -22,7 +22,7 @@
VectorDBWithIndex,
)

from .config import PGVectorConfig
from .config import PGVectorVectorIOConfig

log = logging.getLogger(__name__)

@@ -121,8 +121,8 @@ async def delete(self):
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")


class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.conn = None
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/qdrant/__init__.py
@@ -8,12 +8,12 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import QdrantConfig
from .config import QdrantVectorIOConfig


async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorDBAdapter
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter

impl = QdrantVectorDBAdapter(config, deps[Api.inference])
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/qdrant/config.py
@@ -11,7 +11,7 @@


@json_schema_type
class QdrantConfig(BaseModel):
class QdrantVectorIOConfig(BaseModel):
location: Optional[str] = None
url: Optional[str] = None
port: Optional[int] = 6333
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -21,7 +21,7 @@
VectorDBWithIndex,
)

from .config import QdrantConfig
from .config import QdrantVectorIOConfig

log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
@@ -98,8 +98,8 @@ async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)


class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {}
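Since the adapter above unpacks its config directly into the client (`AsyncQdrantClient(**self.config.model_dump(exclude_none=True))`), the renamed config's optional fields map one-to-one onto Qdrant client keyword arguments. A small sketch using only the fields visible in this diff:

```python
from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig

# Unset optional fields are dropped by model_dump(exclude_none=True) before
# being forwarded to AsyncQdrantClient; the URL below is an example value.
config = QdrantVectorIOConfig(url="http://localhost", port=6333)
```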
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/sample/__init__.py
@@ -6,12 +6,12 @@

from typing import Any

from .config import SampleConfig
from .config import SampleVectorIOConfig


async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleMemoryImpl
async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any:
from .sample import SampleVectorIOImpl

impl = SampleMemoryImpl(config)
impl = SampleVectorIOImpl(config)
await impl.initialize()
return impl
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/sample/config.py
@@ -7,6 +7,6 @@
from pydantic import BaseModel


class SampleConfig(BaseModel):
class SampleVectorIOConfig(BaseModel):
host: str = "localhost"
port: int = 9999
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/sample/sample.py
@@ -7,11 +7,11 @@
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import VectorIO

from .config import SampleConfig
from .config import SampleVectorIOConfig


class SampleMemoryImpl(VectorIO):
def __init__(self, config: SampleConfig):
class SampleVectorIOImpl(VectorIO):
def __init__(self, config: SampleVectorIOConfig):
self.config = config

async def register_vector_db(self, vector_db: VectorDB) -> None:
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/weaviate/__init__.py
@@ -8,12 +8,12 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig # noqa: F401


async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateMemoryAdapter
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateVectorIOAdapter

impl = WeaviateMemoryAdapter(config, deps[Api.inference])
impl = WeaviateVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
llama_stack/providers/remote/vector_io/weaviate/config.py
@@ -12,5 +12,5 @@ class WeaviateRequestProviderData(BaseModel):
weaviate_cluster_url: str


class WeaviateConfig(BaseModel):
class WeaviateVectorIOConfig(BaseModel):
pass
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -23,7 +23,7 @@
VectorDBWithIndex,
)

from .config import WeaviateConfig, WeaviateRequestProviderData
from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig

log = logging.getLogger(__name__)

@@ -85,12 +85,12 @@ async def delete(self, chunk_ids: List[str]) -> None:
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))


class WeaviateMemoryAdapter(
class WeaviateVectorIOAdapter(
VectorIO,
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.client_cache = {}
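`WeaviateVectorIOConfig` stays an empty model; the adapter mixes in `NeedsRequestProviderData`, so connection details arrive per request via `WeaviateRequestProviderData`. A sketch, assuming the cluster URL shown in this diff is the only required field (credentials, if any, would be supplied the same way):

```python
from llama_stack.providers.remote.vector_io.weaviate import (
    WeaviateRequestProviderData,
    WeaviateVectorIOConfig,
)

config = WeaviateVectorIOConfig()  # no static configuration

# Per-request data; only weaviate_cluster_url is visible in this diff, and any
# other fields (e.g. an API key) are assumed to be provided the same way.
provider_data = WeaviateRequestProviderData(
    weaviate_cluster_url="https://example-cluster.weaviate.network",
)
```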
20 changes: 10 additions & 10 deletions llama_stack/providers/tests/vector_io/fixtures.py
@@ -12,12 +12,12 @@

from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

@@ -45,7 +45,7 @@ def vector_io_faiss() -> ProviderFixture:
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissImplConfig(
config=FaissVectorIOConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),
)
@@ -76,7 +76,7 @@ def vector_io_pgvector() -> ProviderFixture:
Provider(
provider_id="pgvector",
provider_type="remote::pgvector",
config=PGVectorConfig(
config=PGVectorVectorIOConfig(
host=os.getenv("PGVECTOR_HOST", "localhost"),
port=os.getenv("PGVECTOR_PORT", 5432),
db=get_env_or_fail("PGVECTOR_DB"),
@@ -95,7 +95,7 @@ def vector_io_weaviate() -> ProviderFixture:
Provider(
provider_id="weaviate",
provider_type="remote::weaviate",
config=WeaviateConfig().model_dump(),
config=WeaviateVectorIOConfig().model_dump(),
)
],
provider_data=dict(
@@ -109,12 +109,12 @@
def vector_io_chroma() -> ProviderFixture:
url = os.getenv("CHROMA_URL")
if url:
config = ChromaRemoteImplConfig(url=url)
config = ChromaVectorIOConfig(url=url)
provider_type = "remote::chromadb"
else:
if not os.getenv("CHROMA_DB_PATH"):
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH"))
config = InlineChromaVectorIOConfig(db_path=os.getenv("CHROMA_DB_PATH"))
provider_type = "inline::chromadb"
return ProviderFixture(
providers=[
4 changes: 2 additions & 2 deletions llama_stack/templates/bedrock/bedrock.py
@@ -10,7 +10,7 @@

from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -37,7 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
)

core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}