Skip to content

Commit

Permalink
chore: Consistent naming for VectorIO providers (#1023)
Browse files Browse the repository at this point in the history
# What does this PR do?

This changes all VectorIO providers classes to follow the pattern
`<ProviderName>VectorIOConfig` and `<ProviderName>VectorIOAdapter`. All
API endpoints for VectorIOs are currently consistent with `/vector-io`.

Note that API endpoint for VectorDB stay unchanged as `/vector-dbs`. 

## Test Plan

I don't have a way to test all providers. This is a simple renaming so
things should work as expected.

---------

Signed-off-by: Yuan Tang <[email protected]>
  • Loading branch information
terrytangyuan authored Feb 13, 2025
1 parent e4a1579 commit 8ff27b5
Show file tree
Hide file tree
Showing 34 changed files with 85 additions and 86 deletions.
2 changes: 1 addition & 1 deletion llama_stack/apis/telemetry/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)

from llama_models.llama3.api.datatypes import Primitive
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/providers/inline/vector_io/chroma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import ChromaInlineImplConfig
from .config import ChromaVectorIOConfig


async def get_provider_impl(config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/providers/inline/vector_io/chroma/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel


class ChromaInlineImplConfig(BaseModel):
class ChromaVectorIOConfig(BaseModel):
db_path: str

@classmethod
Expand Down
10 changes: 5 additions & 5 deletions llama_stack/providers/inline/vector_io/faiss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import FaissImplConfig
from .config import FaissVectorIOConfig


async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissVectorIOImpl
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissVectorIOAdapter

assert isinstance(config, FaissImplConfig), f"Unexpected config type: {type(config)}"
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

impl = FaissVectorIOImpl(config, deps[Api.inference])
impl = FaissVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
2 changes: 1 addition & 1 deletion llama_stack/providers/inline/vector_io/faiss/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@json_schema_type
class FaissImplConfig(BaseModel):
class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions llama_stack/providers/inline/vector_io/faiss/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
VectorDBWithIndex,
)

from .config import FaissImplConfig
from .config import FaissVectorIOConfig

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,8 +112,8 @@ async def query(self, embedding: NDArray, k: int, score_threshold: float) -> Que
return QueryChunksResponse(chunks=chunks, scores=scores)


class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache = {}
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/vector_io/chroma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import ChromaRemoteImplConfig
from .config import ChromaVectorIOConfig


async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter

impl = ChromaVectorIOAdapter(config, deps[Api.inference])
Expand Down
7 changes: 3 additions & 4 deletions llama_stack/providers/remote/vector_io/chroma/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)

from .config import ChromaRemoteImplConfig
from .config import ChromaVectorIOConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,7 +88,7 @@ async def delete(self):
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
config: Union[ChromaVectorIOConfig, ChromaVectorIOConfig],
inference_api: Api.inference,
) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
Expand All @@ -100,7 +99,7 @@ def __init__(
self.cache = {}

async def initialize(self) -> None:
if isinstance(self.config, ChromaRemoteImplConfig):
if isinstance(self.config, ChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/chroma/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel


class ChromaRemoteImplConfig(BaseModel):
class ChromaVectorIOConfig(BaseModel):
url: str

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/pgvector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import PGVectorConfig
from .config import PGVectorVectorIOConfig


async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorDBAdapter
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter

impl = PGVectorVectorDBAdapter(config, deps[Api.inference])
impl = PGVectorVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/pgvector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@json_schema_type
class PGVectorConfig(BaseModel):
class PGVectorVectorIOConfig(BaseModel):
host: str = Field(default="localhost")
port: int = Field(default=5432)
db: str = Field(default="postgres")
Expand Down
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
VectorDBWithIndex,
)

from .config import PGVectorConfig
from .config import PGVectorVectorIOConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -121,8 +121,8 @@ async def delete(self):
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")


class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.conn = None
Expand Down
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/qdrant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import QdrantConfig
from .config import QdrantVectorIOConfig


async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorDBAdapter
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter

impl = QdrantVectorDBAdapter(config, deps[Api.inference])
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/qdrant/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@json_schema_type
class QdrantConfig(BaseModel):
class QdrantVectorIOConfig(BaseModel):
location: Optional[str] = None
url: Optional[str] = None
port: Optional[int] = 6333
Expand Down
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
VectorDBWithIndex,
)

from .config import QdrantConfig
from .config import QdrantVectorIOConfig

log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
Expand Down Expand Up @@ -98,8 +98,8 @@ async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)


class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {}
Expand Down
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/sample/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

from typing import Any

from .config import SampleConfig
from .config import SampleVectorIOConfig


async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleMemoryImpl
async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any:
from .sample import SampleVectorIOImpl

impl = SampleMemoryImpl(config)
impl = SampleVectorIOImpl(config)
await impl.initialize()
return impl
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/sample/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
from pydantic import BaseModel


class SampleConfig(BaseModel):
class SampleVectorIOConfig(BaseModel):
host: str = "localhost"
port: int = 9999
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import VectorIO

from .config import SampleConfig
from .config import SampleVectorIOConfig


class SampleMemoryImpl(VectorIO):
def __init__(self, config: SampleConfig):
class SampleVectorIOImpl(VectorIO):
def __init__(self, config: SampleVectorIOConfig):
self.config = config

async def register_vector_db(self, vector_db: VectorDB) -> None:
Expand Down
8 changes: 4 additions & 4 deletions llama_stack/providers/remote/vector_io/weaviate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig # noqa: F401


async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateMemoryAdapter
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateVectorIOAdapter

impl = WeaviateMemoryAdapter(config, deps[Api.inference])
impl = WeaviateVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/vector_io/weaviate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ class WeaviateRequestProviderData(BaseModel):
weaviate_cluster_url: str


class WeaviateConfig(BaseModel):
class WeaviateVectorIOConfig(BaseModel):
pass
6 changes: 3 additions & 3 deletions llama_stack/providers/remote/vector_io/weaviate/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
VectorDBWithIndex,
)

from .config import WeaviateConfig, WeaviateRequestProviderData
from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,12 +85,12 @@ async def delete(self, chunk_ids: List[str]) -> None:
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))


class WeaviateMemoryAdapter(
class WeaviateVectorIOAdapter(
VectorIO,
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.client_cache = {}
Expand Down
20 changes: 10 additions & 10 deletions llama_stack/providers/tests/vector_io/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@

from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

Expand Down Expand Up @@ -45,7 +45,7 @@ def vector_io_faiss() -> ProviderFixture:
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissImplConfig(
config=FaissVectorIOConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),
)
Expand Down Expand Up @@ -76,7 +76,7 @@ def vector_io_pgvector() -> ProviderFixture:
Provider(
provider_id="pgvector",
provider_type="remote::pgvector",
config=PGVectorConfig(
config=PGVectorVectorIOConfig(
host=os.getenv("PGVECTOR_HOST", "localhost"),
port=os.getenv("PGVECTOR_PORT", 5432),
db=get_env_or_fail("PGVECTOR_DB"),
Expand All @@ -95,7 +95,7 @@ def vector_io_weaviate() -> ProviderFixture:
Provider(
provider_id="weaviate",
provider_type="remote::weaviate",
config=WeaviateConfig().model_dump(),
config=WeaviateVectorIOConfig().model_dump(),
)
],
provider_data=dict(
Expand All @@ -109,12 +109,12 @@ def vector_io_weaviate() -> ProviderFixture:
def vector_io_chroma() -> ProviderFixture:
url = os.getenv("CHROMA_URL")
if url:
config = ChromaRemoteImplConfig(url=url)
config = ChromaVectorIOConfig(url=url)
provider_type = "remote::chromadb"
else:
if not os.getenv("CHROMA_DB_PATH"):
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH"))
config = InlineChromaVectorIOConfig(db_path=os.getenv("CHROMA_DB_PATH"))
provider_type = "inline::chromadb"
return ProviderFixture(
providers=[
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/templates/bedrock/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

Expand All @@ -37,7 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
)

core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
Expand Down
Loading

0 comments on commit 8ff27b5

Please sign in to comment.