diff --git a/.gitignore b/.gitignore index 316c32cb664..0ee3678ced8 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ index_data # Default configuration for persist_directory in chromadb/config.py # Currently it's located in "./chroma/" chroma/ -chroma_test_data +chroma_test_data/ server.htpasswd .venv diff --git a/chroma_data/chroma.sqlite3 b/chroma_data/chroma.sqlite3 new file mode 100644 index 00000000000..5885d1523cf Binary files /dev/null and b/chroma_data/chroma.sqlite3 differ diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 9c0b8000a14..599fb94dd45 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -35,8 +35,8 @@ "QueryResult", "GetResult", ] -from chromadb.telemetry.events import ClientStartEvent -from chromadb.telemetry import Telemetry +from chromadb.telemetry.product.events import ClientStartEvent +from chromadb.telemetry.product import ProductTelemetryClient logger = logging.getLogger(__name__) @@ -56,12 +56,14 @@ is_client = False try: from chromadb.is_thin_client import is_thin_client # type: ignore + is_client = is_thin_client except ImportError: is_client = False if not is_client: import sqlite3 + if sqlite3.sqlite_version_info < (3, 35, 0): if IN_COLAB: # In Colab, hotswap to pysqlite-binary if it's too old @@ -75,8 +77,11 @@ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") else: raise RuntimeError( - "\033[91mYour system has an unsupported version of sqlite3. Chroma requires sqlite3 >= 3.35.0.\033[0m\n" - "\033[94mPlease visit https://docs.trychroma.com/troubleshooting#sqlite to learn how to upgrade.\033[0m" + "\033[91mYour system has an unsupported version of sqlite3. Chroma \ + requires sqlite3 >= 3.35.0.\033[0m\n" + "\033[94mPlease visit \ + https://docs.trychroma.com/troubleshooting#sqlite to learn how \ + to upgrade.\033[0m" ) @@ -147,12 +152,11 @@ def Client(settings: Settings = __settings) -> API: system = System(settings) - telemetry_client = system.instance(Telemetry) + product_telemetry_client = system.instance(ProductTelemetryClient) api = system.instance(API) system.start() - # Submit event for client start - telemetry_client.capture(ClientStartEvent()) + product_telemetry_client.capture(ClientStartEvent()) return api diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 2ddd537ebff..8db5bf889f7 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -31,7 +31,12 @@ from chromadb.auth.providers import RequestsClientAuthProtocolAdapter from chromadb.auth.registry import resolve_provider from chromadb.config import Settings, System -from chromadb.telemetry import Telemetry +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) +from chromadb.telemetry.product import ProductTelemetryClient from urllib.parse import urlparse, urlunparse, quote logger = logging.getLogger(__name__) @@ -51,7 +56,8 @@ def _validate_host(host: str) -> None: if "/" in host and (not host.startswith("http")): raise ValueError( "Invalid URL. " - "Seems that you are trying to pass URL as a host but without specifying the protocol. " + "Seems that you are trying to pass URL as a host but without \ + specifying the protocol. " "Please add http:// or https:// to the host." ) @@ -92,7 +98,8 @@ def __init__(self, system: System): system.settings.require("chroma_server_host") system.settings.require("chroma_server_http_port") - self._telemetry_client = self.require(Telemetry) + self._opentelemetry_client = self.require(OpenTelemetryClient) + self._product_telemetry_client = self.require(ProductTelemetryClient) self._settings = system.settings self._api_url = FastAPI.resolve_url( @@ -127,6 +134,7 @@ def __init__(self, system: System): if self._header is not None: self._session.headers.update(self._header) + @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION) @override def heartbeat(self) -> int: """Returns the current server time in nanoseconds to check if the server is alive""" @@ -134,6 +142,7 @@ def heartbeat(self) -> int: raise_chroma_error(resp) return int(resp.json()["nanosecond heartbeat"]) + @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) @override def list_collections(self) -> Sequence[Collection]: """Returns a list of all collections""" @@ -146,6 +155,7 @@ def list_collections(self) -> Sequence[Collection]: return collections + @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) @override def create_collection( self, @@ -171,6 +181,7 @@ def create_collection( metadata=resp_json["metadata"], ) + @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION) @override def get_collection( self, @@ -189,6 +200,9 @@ def get_collection( metadata=resp_json["metadata"], ) + @trace_method( + "FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION + ) @override def get_or_create_collection( self, @@ -200,6 +214,7 @@ def get_or_create_collection( name, metadata, embedding_function, get_or_create=True ) + @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION) @override def _modify( self, @@ -214,12 +229,14 @@ def _modify( ) raise_chroma_error(resp) + @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) @override def delete_collection(self, name: str) -> None: """Deletes a collection""" resp = self._session.delete(self._api_url + "/collections/" + name) raise_chroma_error(resp) + @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION) @override def _count(self, collection_id: UUID) -> int: """Returns the number of embeddings in the database""" @@ -229,6 +246,7 @@ def _count(self, collection_id: UUID) -> int: raise_chroma_error(resp) return cast(int, resp.json()) + @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION) @override def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: return self._get( @@ -237,6 +255,7 @@ def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: include=["embeddings", "documents", "metadatas"], ) + @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION) @override def _get( self, @@ -279,6 +298,7 @@ def _get( documents=body.get("documents", None), ) + @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION) @override def _delete( self, @@ -298,6 +318,7 @@ def _delete( raise_chroma_error(resp) return cast(IDs, resp.json()) + @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL) def _submit_batch( self, batch: Tuple[ @@ -321,6 +342,7 @@ def _submit_batch( ) return resp + @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL) @override def _add( self, @@ -340,6 +362,7 @@ def _add( raise_chroma_error(resp) return True + @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL) @override def _update( self, @@ -361,6 +384,7 @@ def _update( resp.raise_for_status() return True + @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL) @override def _upsert( self, @@ -382,6 +406,7 @@ def _upsert( resp.raise_for_status() return True + @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL) @override def _query( self, @@ -417,6 +442,7 @@ def _query( documents=body.get("documents", None), ) + @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL) @override def reset(self) -> bool: """Resets the database""" @@ -424,6 +450,7 @@ def reset(self) -> bool: raise_chroma_error(resp) return cast(bool, resp.json()) + @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION) @override def get_version(self) -> str: """Returns the version of the server""" @@ -437,6 +464,7 @@ def get_settings(self) -> Settings: return self._settings @property + @trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION) @override def max_batch_size(self) -> int: if self._max_batch_size == -1: diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index cfe1300e76e..45dcefc6697 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -2,7 +2,13 @@ from chromadb.config import Settings, System from chromadb.db.system import SysDB from chromadb.segment import SegmentManager, MetadataReader, VectorReader -from chromadb.telemetry import Telemetry +from chromadb.telemetry.opentelemetry import ( + add_attributes_to_current_span, + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) +from chromadb.telemetry.product import ProductTelemetryClient from chromadb.ingest import Producer from chromadb.api.models.Collection import Collection from chromadb import __version__ @@ -28,7 +34,7 @@ validate_where_document, validate_batch, ) -from chromadb.telemetry.events import ( +from chromadb.telemetry.product.events import ( CollectionAddEvent, CollectionDeleteEvent, CollectionGetEvent, @@ -78,7 +84,10 @@ class SegmentAPI(API): _sysdb: SysDB _manager: SegmentManager _producer: Producer - _telemetry_client: Telemetry + _product_telemetry_client: ProductTelemetryClient + _opentelemetry_client: OpenTelemetryClient + _tenant_id: str + _topic_ns: str _collection_cache: Dict[UUID, t.Collection] def __init__(self, system: System): @@ -86,7 +95,8 @@ def __init__(self, system: System): self._settings = system.settings self._sysdb = self.require(SysDB) self._manager = self.require(SegmentManager) - self._telemetry_client = self.require(Telemetry) + self._product_telemetry_client = self.require(ProductTelemetryClient) + self._opentelemetry_client = self.require(OpenTelemetryClient) self._producer = self.require(Producer) self._collection_cache = {} @@ -97,6 +107,7 @@ def heartbeat(self) -> int: # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. + @trace_method("SegmentAPI.create_collection", OpenTelemetryGranularity.OPERATION) @override def create_collection( self, @@ -127,12 +138,13 @@ def create_collection( self._sysdb.create_segment(segment) # TODO: This event doesn't capture the get_or_create case appropriately - self._telemetry_client.capture( + self._product_telemetry_client.capture( ClientCreateCollectionEvent( collection_uuid=str(id), embedding_function=embedding_function.__class__.__name__, ) ) + add_attributes_to_current_span({"collection_uuid": str(id)}) return Collection( client=self, @@ -142,6 +154,9 @@ def create_collection( embedding_function=embedding_function, ) + @trace_method( + "SegmentAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION + ) @override def get_or_create_collection( self, @@ -149,7 +164,7 @@ def get_or_create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), ) -> Collection: - return self.create_collection( + return self.create_collection( # type: ignore name=name, metadata=metadata, embedding_function=embedding_function, @@ -159,6 +174,7 @@ def get_or_create_collection( # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings + @trace_method("SegmentAPI.get_collection", OpenTelemetryGranularity.OPERATION) @override def get_collection( self, @@ -178,6 +194,7 @@ def get_collection( else: raise ValueError(f"Collection {name} does not exist.") + @trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION) @override def list_collections(self) -> Sequence[Collection]: collections = [] @@ -193,6 +210,7 @@ def list_collections(self) -> Sequence[Collection]: ) return collections + @trace_method("SegmentAPI._modify", OpenTelemetryGranularity.OPERATION) @override def _modify( self, @@ -216,6 +234,7 @@ def _modify( elif new_metadata: self._sysdb.update_collection(id, metadata=new_metadata) + @trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION) @override def delete_collection(self, name: str) -> None: existing = self._sysdb.get_collections(name=name) @@ -229,6 +248,7 @@ def delete_collection(self, name: str) -> None: else: raise ValueError(f"Collection {name} does not exist.") + @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION) @override def _add( self, @@ -256,7 +276,7 @@ def _add( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionAddEvent( collection_uuid=str(collection_id), add_amount=len(ids), @@ -266,6 +286,7 @@ def _add( ) return True + @trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION) @override def _update( self, @@ -293,7 +314,7 @@ def _update( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionUpdateEvent( collection_uuid=str(collection_id), update_amount=len(ids), @@ -305,6 +326,7 @@ def _update( return True + @trace_method("SegmentAPI._upsert", OpenTelemetryGranularity.OPERATION) @override def _upsert( self, @@ -334,6 +356,7 @@ def _upsert( return True + @trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION) @override def _get( self, @@ -348,6 +371,13 @@ def _get( where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], ) -> GetResult: + add_attributes_to_current_span( + { + "collection_id": str(collection_id), + "ids_count": len(ids) if ids else 0, + } + ) + where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( validate_where_document(where_document) @@ -388,7 +418,7 @@ def _get( documents = [_doc(m) for m in metadatas] ids_amount = len(ids) if ids else 0 - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionGetEvent( collection_uuid=str(collection_id), ids_count=ids_amount, @@ -407,6 +437,7 @@ def _get( documents=documents if "documents" in include else None, # type: ignore ) + @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION) @override def _delete( self, @@ -415,6 +446,13 @@ def _delete( where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, ) -> IDs: + add_attributes_to_current_span( + { + "collection_id": str(collection_id), + "ids_count": len(ids) if ids else 0, + } + ) + where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( validate_where_document(where_document) @@ -461,18 +499,21 @@ def _delete( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionDeleteEvent( collection_uuid=str(collection_id), delete_amount=len(ids_to_delete) ) ) return ids_to_delete + @trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION) @override def _count(self, collection_id: UUID) -> int: + add_attributes_to_current_span({"collection_id": str(collection_id)}) metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() + @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION) @override def _query( self, @@ -483,6 +524,13 @@ def _query( where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], ) -> QueryResult: + add_attributes_to_current_span( + { + "collection_id": str(collection_id), + "n_results": n_results, + "where": str(where), + } + ) where = validate_where(where) if where is not None and len(where) > 0 else where where_document = ( validate_where_document(where_document) @@ -552,7 +600,7 @@ def _query( documents.append(doc_list) # type: ignore query_amount = len(query_embeddings) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionQueryEvent( collection_uuid=str(collection_id), query_amount=query_amount, @@ -573,9 +621,11 @@ def _query( documents=documents if documents else None, ) + @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION) @override def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: - return self._get(collection_id, limit=n) + add_attributes_to_current_span({"collection_id": str(collection_id)}) + return self._get(collection_id, limit=n) # type: ignore @override def get_version(self) -> str: @@ -601,20 +651,24 @@ def max_batch_size(self) -> int: # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. + # TODO: promote collection -> topic to a base class method so that it can be + # used for channel assignment in the distributed version of the system. + @trace_method("SegmentAPI._validate_embedding_record", OpenTelemetryGranularity.ALL) def _validate_embedding_record( self, collection: t.Collection, record: t.SubmitEmbeddingRecord ) -> None: """Validate the dimension of an embedding record before submitting it to the system.""" + add_attributes_to_current_span({"collection_id": str(collection["id"])}) if record["embedding"]: self._validate_dimension(collection, len(record["embedding"]), update=True) + @trace_method("SegmentAPI._validate_dimension", OpenTelemetryGranularity.ALL) def _validate_dimension( self, collection: t.Collection, dim: int, update: bool ) -> None: """Validate that a collection supports records of the given dimension. If update is true, update the collection if the collection doesn't already have a dimension.""" - if collection["dimension"] is None: if update: id = collection["id"] @@ -627,6 +681,7 @@ def _validate_dimension( else: return # all is well + @trace_method("SegmentAPI._get_collection", OpenTelemetryGranularity.ALL) def _get_collection(self, collection_id: UUID) -> t.Collection: """Read-through cache for collection data""" if collection_id not in self._collection_cache: diff --git a/chromadb/auth/basic/__init__.py b/chromadb/auth/basic/__init__.py index a03d195e8ae..a9888598a22 100644 --- a/chromadb/auth/basic/__init__.py +++ b/chromadb/auth/basic/__init__.py @@ -17,6 +17,11 @@ ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.utils import get_class logger = logging.getLogger(__name__) @@ -84,6 +89,7 @@ def __init__(self, system: System) -> None: ), ) + @trace_method("BasicAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL) @override def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool: try: diff --git a/chromadb/auth/fastapi.py b/chromadb/auth/fastapi.py index a488ef5f2b3..14b531e48e8 100644 --- a/chromadb/auth/fastapi.py +++ b/chromadb/auth/fastapi.py @@ -17,6 +17,11 @@ ChromaAuthMiddleware, ) from chromadb.auth.registry import resolve_provider +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) logger = logging.getLogger(__name__) @@ -72,6 +77,9 @@ def __init__(self, system: System) -> None: ) self._auth_provider = cast(ServerAuthProvider, self.require(_cls)) + @trace_method( + "FastAPIChromaAuthMiddleware.authenticate", OpenTelemetryGranularity.ALL + ) @override def authenticate( self, request: ServerAuthenticationRequest[Any] diff --git a/chromadb/auth/providers.py b/chromadb/auth/providers.py index eceee3bc2ab..2982b9e15a6 100644 --- a/chromadb/auth/providers.py +++ b/chromadb/auth/providers.py @@ -15,6 +15,11 @@ ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) T = TypeVar("T") @@ -34,6 +39,10 @@ def __init__(self, system: System) -> None: "The bcrypt python package is not installed. Please install it with `pip install bcrypt`" ) + @trace_method( + "HtpasswdServerAuthCredentialsProvider.validate_credentials", + OpenTelemetryGranularity.ALL, + ) @override def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) diff --git a/chromadb/auth/token/__init__.py b/chromadb/auth/token/__init__.py index 5132fa35798..6dfa8635942 100644 --- a/chromadb/auth/token/__init__.py +++ b/chromadb/auth/token/__init__.py @@ -19,6 +19,11 @@ ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.utils import get_class T = TypeVar("T") @@ -86,6 +91,10 @@ def __init__(self, system: System) -> None: check_token(token_str) self._token = SecretStr(token_str) + @trace_method( + "TokenConfigServerAuthCredentialsProvider.validate_credentials", + OpenTelemetryGranularity.ALL, + ) @override def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) @@ -150,6 +159,7 @@ def __init__(self, system: System) -> None: str(system.settings.chroma_server_auth_token_transport_header) ] + @trace_method("TokenAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL) @override def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool: try: @@ -189,6 +199,7 @@ def __init__(self, system: System) -> None: str(system.settings.chroma_client_auth_token_transport_header) ] + @trace_method("TokenAuthClientProvider.authenticate", OpenTelemetryGranularity.ALL) @override def authenticate(self) -> ClientAuthResponse: _token = self._credentials_provider.get_credentials() diff --git a/chromadb/config.py b/chromadb/config.py index eb7bca93ef5..6731db255b7 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -64,7 +64,7 @@ # TODO: Don't use concrete types here to avoid circular deps. Strings are fine for right here! _abstract_type_keys: Dict[str, str] = { "chromadb.api.API": "chroma_api_impl", - "chromadb.telemetry.Telemetry": "chroma_telemetry_impl", + "chromadb.telemetry.product.ProductTelemetryClient": "chroma_product_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", "chromadb.ingest.Consumer": "chroma_consumer_impl", "chromadb.ingest.CollectionAssignmentPolicy": "chroma_collection_assignment_policy_impl", # noqa @@ -83,7 +83,9 @@ class Settings(BaseSettings): # type: ignore chroma_db_impl: Optional[str] = None chroma_api_impl: str = "chromadb.api.segment.SegmentAPI" # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" - chroma_telemetry_impl: str = "chromadb.telemetry.posthog.Posthog" + chroma_product_telemetry_impl: str = "chromadb.telemetry.product.posthog.Posthog" + # Required for backwards compatibility + chroma_telemetry_impl: str = chroma_product_telemetry_impl # New architecture components chroma_sysdb_impl: str = "chromadb.db.impl.sqlite.SqliteDB" @@ -174,6 +176,11 @@ def chroma_server_auth_credentials_file_non_empty_file_exists( anonymized_telemetry: bool = True + chroma_otel_collection_endpoint: Optional[str] = "" + chroma_otel_service_name: Optional[str] = "chromadb" + chroma_otel_collection_headers: Dict[str, str] = {} + chroma_otel_granularity: Optional[str] = "none" + allow_reset: bool = False migrations: Literal["none", "validate", "apply"] = "apply" diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index aed14deb8e2..6652d21333a 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -4,6 +4,11 @@ import chromadb.db.base as base from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue from chromadb.db.mixins.sysdb import SqlSysDB +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.utils.delete_file import delete_file import sqlite3 from overrides import override @@ -67,6 +72,7 @@ def __init__(self, system: System): files("chromadb.migrations.metadb"), ] self._is_persistent = self._settings.require("is_persistent") + self._opentelemetry_client = system.require(OpenTelemetryClient) if not self._is_persistent: # In order to allow sqlite to be shared between multiple threads, we need to use a # URI connection string with shared cache. @@ -84,6 +90,7 @@ def __init__(self, system: System): self._tx_stack = local() super().__init__(system) + @trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL) @override def start(self) -> None: super().start() @@ -92,6 +99,7 @@ def start(self) -> None: cur.execute("PRAGMA case_sensitive_like = ON") self.initialize_migrations() + @trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL) @override def stop(self) -> None: super().stop() @@ -122,6 +130,7 @@ def tx(self) -> TxWrapper: self._tx_stack.stack = [] return TxWrapper(self._conn_pool, stack=self._tx_stack) + @trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL) @override def reset_state(self) -> None: if not self._settings.require("allow_reset"): @@ -132,9 +141,9 @@ def reset_state(self) -> None: # Drop all tables cur.execute( """ - SELECT name FROM sqlite_master - WHERE type='table' - """ + SELECT name FROM sqlite_master + WHERE type='table' + """ ) for row in cur.fetchall(): cur.execute(f"DROP TABLE IF EXISTS {row[0]}") @@ -144,28 +153,30 @@ def reset_state(self) -> None: self.start() super().reset_state() + @trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL) @override def setup_migrations(self) -> None: with self.tx() as cur: cur.execute( """ - CREATE TABLE IF NOT EXISTS migrations ( - dir TEXT NOT NULL, - version INTEGER NOT NULL, - filename TEXT NOT NULL, - sql TEXT NOT NULL, - hash TEXT NOT NULL, - PRIMARY KEY (dir, version) - ) - """ + CREATE TABLE IF NOT EXISTS migrations ( + dir TEXT NOT NULL, + version INTEGER NOT NULL, + filename TEXT NOT NULL, + sql TEXT NOT NULL, + hash TEXT NOT NULL, + PRIMARY KEY (dir, version) + ) + """ ) + @trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL) @override def migrations_initialized(self) -> bool: with self.tx() as cur: cur.execute( """SELECT count(*) FROM sqlite_master - WHERE type='table' AND name='migrations'""" + WHERE type='table' AND name='migrations'""" ) if cur.fetchone()[0] == 0: @@ -173,6 +184,7 @@ def migrations_initialized(self) -> bool: else: return True + @trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL) @override def db_migrations(self, dir: Traversable) -> Sequence[Migration]: with self.tx() as cur: diff --git a/chromadb/db/migrations.py b/chromadb/db/migrations.py index af2ecce4375..76476502d1b 100644 --- a/chromadb/db/migrations.py +++ b/chromadb/db/migrations.py @@ -6,6 +6,11 @@ from chromadb.db.base import SqlDB, Cursor from abc import abstractmethod from chromadb.config import System, Settings +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) class MigrationFile(TypedDict): @@ -82,6 +87,7 @@ class MigratableDB(SqlDB): def __init__(self, system: System) -> None: self._settings = system.settings + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) @staticmethod @@ -127,6 +133,7 @@ def initialize_migrations(self) -> None: if migrate == "apply": self.apply_migrations() + @trace_method("MigratableDB.validate_migrations", OpenTelemetryGranularity.ALL) def validate_migrations(self) -> None: """Validate all migrations and throw an exception if there are any unapplied migrations in the source repo.""" @@ -142,6 +149,7 @@ def validate_migrations(self) -> None: version = unapplied_migrations[0]["version"] raise UnappliedMigrationsError(dir=dir.name, version=version) + @trace_method("MigratableDB.apply_migrations", OpenTelemetryGranularity.ALL) def apply_migrations(self) -> None: """Validate existing migrations, and apply all new ones.""" self.setup_migrations() diff --git a/chromadb/db/mixins/embeddings_queue.py b/chromadb/db/mixins/embeddings_queue.py index 472e0254283..f926d608e05 100644 --- a/chromadb/db/mixins/embeddings_queue.py +++ b/chromadb/db/mixins/embeddings_queue.py @@ -14,6 +14,11 @@ Operation, ) from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from overrides import override from collections import defaultdict from typing import Sequence, Tuple, Optional, Dict, Set, cast @@ -79,8 +84,10 @@ def __init__( def __init__(self, system: System): self._subscriptions = defaultdict(set) self._max_batch_size = None + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) + @trace_method("SqlEmbeddingsQueue.reset_state", OpenTelemetryGranularity.ALL) @override def reset_state(self) -> None: super().reset_state() @@ -91,6 +98,7 @@ def create_topic(self, topic_name: str) -> None: # Topic creation is implicit for this impl pass + @trace_method("SqlEmbeddingsQueue.delete_topic", OpenTelemetryGranularity.ALL) @override def delete_topic(self, topic_name: str) -> None: t = Table("embeddings_queue") @@ -104,6 +112,7 @@ def delete_topic(self, topic_name: str) -> None: sql, params = get_sql(q, self.parameter_format()) cur.execute(sql, params) + @trace_method("SqlEmbeddingsQueue.submit_embedding", OpenTelemetryGranularity.ALL) @override def submit_embedding( self, topic_name: str, embedding: SubmitEmbeddingRecord @@ -113,6 +122,7 @@ def submit_embedding( return self.submit_embeddings(topic_name, [embedding])[0] + @trace_method("SqlEmbeddingsQueue.submit_embeddings", OpenTelemetryGranularity.ALL) @override def submit_embeddings( self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord] @@ -126,10 +136,10 @@ def submit_embeddings( if len(embeddings) > self.max_batch_size: raise ValueError( f""" - Cannot submit more than {self.max_batch_size:,} embeddings at once. - Please submit your embeddings in batches of size - {self.max_batch_size:,} or less. - """ + Cannot submit more than {self.max_batch_size:,} embeddings at once. + Please submit your embeddings in batches of size + {self.max_batch_size:,} or less. + """ ) t = Table("embeddings_queue") @@ -182,6 +192,7 @@ def submit_embeddings( self._notify_all(topic_name, embedding_records) return seq_ids + @trace_method("SqlEmbeddingsQueue.subscribe", OpenTelemetryGranularity.ALL) @override def subscribe( self, @@ -207,6 +218,7 @@ def subscribe( return subscription_id + @trace_method("SqlEmbeddingsQueue.unsubscribe", OpenTelemetryGranularity.ALL) @override def unsubscribe(self, subscription_id: UUID) -> None: for topic_name, subscriptions in self._subscriptions.items(): @@ -226,6 +238,7 @@ def max_seqid(self) -> SeqId: return 2**63 - 1 @property + @trace_method("SqlEmbeddingsQueue.max_batch_size", OpenTelemetryGranularity.ALL) @override def max_batch_size(self) -> int: if self._max_batch_size is None: @@ -247,6 +260,10 @@ def max_batch_size(self) -> int: self._max_batch_size = 999 // self.VARIABLES_PER_RECORD return self._max_batch_size + @trace_method( + "SqlEmbeddingsQueue._prepare_vector_encoding_metadata", + OpenTelemetryGranularity.ALL, + ) def _prepare_vector_encoding_metadata( self, embedding: SubmitEmbeddingRecord ) -> Tuple[Optional[bytes], Optional[str], Optional[str]]: @@ -260,6 +277,7 @@ def _prepare_vector_encoding_metadata( metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None return embedding_bytes, encoding, metadata + @trace_method("SqlEmbeddingsQueue._backfill", OpenTelemetryGranularity.ALL) def _backfill(self, subscription: Subscription) -> None: """Backfill the given subscription with any currently matching records in the DB""" @@ -298,6 +316,7 @@ def _backfill(self, subscription: Subscription) -> None: ], ) + @trace_method("SqlEmbeddingsQueue._validate_range", OpenTelemetryGranularity.ALL) def _validate_range( self, start: Optional[SeqId], end: Optional[SeqId] ) -> Tuple[int, int]: @@ -311,6 +330,7 @@ def _validate_range( raise ValueError(f"Invalid SeqID range: {start} to {end}") return start, end + @trace_method("SqlEmbeddingsQueue._next_seq_id", OpenTelemetryGranularity.ALL) def _next_seq_id(self) -> int: """Get the next SeqID for this database.""" t = Table("embeddings_queue") @@ -319,12 +339,14 @@ def _next_seq_id(self) -> int: cur.execute(q.get_sql()) return int(cur.fetchone()[0]) + 1 + @trace_method("SqlEmbeddingsQueue._notify_all", OpenTelemetryGranularity.ALL) def _notify_all(self, topic: str, embeddings: Sequence[EmbeddingRecord]) -> None: """Send a notification to each subscriber of the given topic.""" if self._running: for sub in self._subscriptions[topic]: self._notify_one(sub, embeddings) + @trace_method("SqlEmbeddingsQueue._notify_one", OpenTelemetryGranularity.ALL) def _notify_one( self, sub: Subscription, embeddings: Sequence[EmbeddingRecord] ) -> None: diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index d105918e700..d9deb144f66 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -14,6 +14,12 @@ UniqueConstraintError, ) from chromadb.db.system import SysDB +from chromadb.telemetry.opentelemetry import ( + add_attributes_to_current_span, + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.ingest import CollectionAssignmentPolicy, Producer from chromadb.types import ( OptionalArgument, @@ -35,7 +41,9 @@ class SqlSysDB(SqlDB, SysDB): def __init__(self, system: System): self._assignment_policy = system.instance(CollectionAssignmentPolicy) super().__init__(system) + self._opentelemetry_client = system.require(OpenTelemetryClient) + @trace_method("SqlSysDB.create_segment", OpenTelemetryGranularity.ALL) @override def start(self) -> None: super().start() @@ -43,6 +51,15 @@ def start(self) -> None: @override def create_segment(self, segment: Segment) -> None: + add_attributes_to_current_span( + { + "segment_id": str(segment["id"]), + "segment_type": segment["type"], + "segment_scope": segment["scope"].value, + "segment_topic": str(segment["topic"]), + "collection": str(segment["collection"]), + } + ) with self.tx() as cur: segments = Table("segments") insert_segment = ( @@ -80,6 +97,7 @@ def create_segment(self, segment: Segment) -> None: segment["metadata"], ) + @trace_method("SqlSysDB.create_collection", OpenTelemetryGranularity.ALL) @override def create_collection( self, @@ -92,6 +110,13 @@ def create_collection( if id is None and not get_or_create: raise ValueError("id must be specified if get_or_create is False") + add_attributes_to_current_span( + { + "collection_id": str(id), + "collection_name": name, + } + ) + existing = self.get_collections(name=name) if existing: if get_or_create: @@ -146,6 +171,7 @@ def create_collection( ) return collection, True + @trace_method("SqlSysDB.get_segments", OpenTelemetryGranularity.ALL) @override def get_segments( self, @@ -155,6 +181,15 @@ def get_segments( topic: Optional[str] = None, collection: Optional[UUID] = None, ) -> Sequence[Segment]: + add_attributes_to_current_span( + { + "segment_id": str(id), + "segment_type": type if type else "", + "segment_scope": scope.value if scope else "", + "segment_topic": topic if topic else "", + "collection": str(collection), + } + ) segments_t = Table("segments") metadata_t = Table("segment_metadata") q = ( @@ -214,6 +249,7 @@ def get_segments( return segments + @trace_method("SqlSysDB.get_collections", OpenTelemetryGranularity.ALL) @override def get_collections( self, @@ -222,6 +258,13 @@ def get_collections( name: Optional[str] = None, ) -> Sequence[Collection]: """Get collections by name, embedding function and/or metadata""" + add_attributes_to_current_span( + { + "collection_id": str(id), + "collection_topic": topic if topic else "", + "collection_name": name if name else "", + } + ) collections_t = Table("collections") metadata_t = Table("collection_metadata") q = ( @@ -272,9 +315,15 @@ def get_collections( return collections + @trace_method("SqlSysDB.delete_segment", OpenTelemetryGranularity.ALL) @override def delete_segment(self, id: UUID) -> None: """Delete a segment from the SysDB""" + add_attributes_to_current_span( + { + "segment_id": str(id), + } + ) t = Table("segments") q = ( self.querybuilder() @@ -290,9 +339,15 @@ def delete_segment(self, id: UUID) -> None: if not result: raise NotFoundError(f"Segment {id} not found") + @trace_method("SqlSysDB.delete_collection", OpenTelemetryGranularity.ALL) @override def delete_collection(self, id: UUID) -> None: """Delete a topic and all associated segments from the SysDB""" + add_attributes_to_current_span( + { + "collection_id": str(id), + } + ) t = Table("collections") q = ( self.querybuilder() @@ -309,6 +364,7 @@ def delete_collection(self, id: UUID) -> None: raise NotFoundError(f"Collection {id} not found") self._producer.delete_topic(result[1]) + @trace_method("SqlSysDB.update_segment", OpenTelemetryGranularity.ALL) @override def update_segment( self, @@ -317,6 +373,12 @@ def update_segment( collection: OptionalArgument[Optional[UUID]] = Unspecified(), metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), ) -> None: + add_attributes_to_current_span( + { + "segment_id": str(id), + "collection": str(collection), + } + ) segments_t = Table("segments") metadata_t = Table("segment_metadata") @@ -361,6 +423,7 @@ def update_segment( set(metadata.keys()), ) + @trace_method("SqlSysDB.update_collection", OpenTelemetryGranularity.ALL) @override def update_collection( self, @@ -370,6 +433,11 @@ def update_collection( dimension: OptionalArgument[Optional[int]] = Unspecified(), metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), ) -> None: + add_attributes_to_current_span( + { + "collection_id": str(id), + } + ) collections_t = Table("collections") metadata_t = Table("collection_metadata") @@ -419,11 +487,17 @@ def update_collection( set(metadata.keys()), ) + @trace_method("SqlSysDB._metadata_from_rows", OpenTelemetryGranularity.ALL) def _metadata_from_rows( self, rows: Sequence[Tuple[Any, ...]] ) -> Optional[Metadata]: """Given SQL rows, return a metadata map (assuming that the last four columns are the key, str_value, int_value & float_value)""" + add_attributes_to_current_span( + { + "num_rows": len(rows), + } + ) metadata: Dict[str, Union[str, int, float]] = {} for row in rows: key = str(row[-4]) @@ -435,6 +509,7 @@ def _metadata_from_rows( metadata[key] = float(row[-1]) return metadata or None + @trace_method("SqlSysDB._insert_metadata", OpenTelemetryGranularity.ALL) def _insert_metadata( self, cur: Cursor, @@ -447,6 +522,11 @@ def _insert_metadata( # It would be cleaner to use something like ON CONFLICT UPDATE here But that is # very difficult to do in a portable way (e.g sqlite and postgres have # completely different sytnax) + add_attributes_to_current_span( + { + "num_keys": len(metadata), + } + ) if clear_keys: q = ( self.querybuilder() @@ -462,7 +542,11 @@ def _insert_metadata( self.querybuilder() .into(table) .columns( - id_col, table.key, table.str_value, table.int_value, table.float_value + id_col, + table.key, + table.str_value, + table.int_value, + table.float_value, ) ) sql_id = self.uuid_to_db(id) diff --git a/chromadb/ingest/impl/pulsar.py b/chromadb/ingest/impl/pulsar.py index 3f293c90580..3f71a1db36a 100644 --- a/chromadb/ingest/impl/pulsar.py +++ b/chromadb/ingest/impl/pulsar.py @@ -10,6 +10,11 @@ from chromadb.ingest.impl.utils import create_pulsar_connection_str from chromadb.proto.convert import from_proto_submit, to_proto_submit import chromadb.proto.chroma_pb2 as proto +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import SeqId, SubmitEmbeddingRecord import pulsar from concurrent.futures import wait, Future @@ -18,8 +23,10 @@ class PulsarProducer(Producer, EnforceOverrides): + # TODO: ensure trace context propagates _connection_str: str _topic_to_producer: Dict[str, pulsar.Producer] + _opentelemetry_client: OpenTelemetryClient _client: pulsar.Client _admin: PulsarAdmin _settings: Settings @@ -31,6 +38,7 @@ def __init__(self, system: System) -> None: self._topic_to_producer = {} self._settings = system.settings self._admin = PulsarAdmin(system) + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) @overrides @@ -51,6 +59,7 @@ def create_topic(self, topic_name: str) -> None: def delete_topic(self, topic_name: str) -> None: self._admin.delete_topic(topic_name) + @trace_method("PulsarProducer.submit_embedding", OpenTelemetryGranularity.ALL) @overrides def submit_embedding( self, topic_name: str, embedding: SubmitEmbeddingRecord @@ -62,6 +71,7 @@ def submit_embedding( msg_id: pulsar.MessageId = producer.send(proto_submit.SerializeToString()) return pulsar_to_int(msg_id) + @trace_method("PulsarProducer.submit_embeddings", OpenTelemetryGranularity.ALL) @overrides def submit_embeddings( self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord] @@ -75,10 +85,10 @@ def submit_embeddings( if len(embeddings) > self.max_batch_size: raise ValueError( f""" - Cannot submit more than {self.max_batch_size:,} embeddings at once. - Please submit your embeddings in batches of size - {self.max_batch_size:,} or less. - """ + Cannot submit more than {self.max_batch_size:,} embeddings at once. + Please submit your embeddings in batches of size + {self.max_batch_size:,} or less. + """ ) producer = self._get_or_create_producer(topic_name) @@ -171,6 +181,7 @@ def __init__( _connection_str: str _client: pulsar.Client + _opentelemetry_client: OpenTelemetryClient _subscriptions: Dict[str, Set[PulsarSubscription]] _settings: Settings @@ -180,6 +191,7 @@ def __init__(self, system: System) -> None: self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port) self._subscriptions = defaultdict(set) self._settings = system.settings + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) @overrides @@ -192,6 +204,7 @@ def stop(self) -> None: self._client.close() super().stop() + @trace_method("PulsarConsumer.subscribe", OpenTelemetryGranularity.ALL) @overrides def subscribe( self, diff --git a/chromadb/segment/impl/distributed/server.py b/chromadb/segment/impl/distributed/server.py index f7ea2f2ecaf..9b56ed4d18a 100644 --- a/chromadb/segment/impl/distributed/server.py +++ b/chromadb/segment/impl/distributed/server.py @@ -17,7 +17,11 @@ to_proto_vector_embedding_record, ) from chromadb.segment import SegmentImplementation, SegmentType, VectorReader -from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ScalarEncoding, Segment, SegmentScope import logging @@ -38,11 +42,16 @@ class SegmentServer(SegmentServerServicer, VectorReaderServicer): _segment_cache: Dict[UUID, SegmentImplementation] = {} _system: System + _opentelemetry_client: OpenTelemetryClient def __init__(self, system: System) -> None: super().__init__() self._system = system + self._opentelemetry_client = system.require(OpenTelemetryClient) + @trace_method( + "SegmentServer.LoadSegment", OpenTelemetryGranularity.OPERATION_AND_SEGMENT + ) def LoadSegment( self, request: proto.Segment, context: Any ) -> proto.SegmentServerResponse: @@ -85,6 +94,9 @@ def QueryVectors( context.set_details("Query segment not implemented yet") return proto.QueryVectorsResponse() + @trace_method( + "SegmentServer.GetVectors", OpenTelemetryGranularity.OPERATION_AND_SEGMENT + ) def GetVectors( self, request: proto.GetVectorsRequest, context: Any ) -> proto.GetVectorsResponse: diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index a7c673920a8..e03b58db224 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -15,6 +15,11 @@ from chromadb.db.system import SysDB from overrides import override from chromadb.segment.distributed import SegmentDirectory +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata from typing import Dict, Type, Sequence, Optional, cast from uuid import UUID, uuid4 @@ -35,6 +40,7 @@ class DistributedSegmentManager(SegmentManager): _sysdb: SysDB _system: System + _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] _segment_cache: Dict[ UUID, Dict[SegmentScope, Segment] @@ -48,11 +54,16 @@ def __init__(self, system: System): self._sysdb = self.require(SysDB) self._segment_directory = self.require(SegmentDirectory) self._system = system + self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} self._segment_cache = defaultdict(dict) self._segment_server_stubs = {} self._lock = Lock() + @trace_method( + "DistributedSegmentManager.create_segments", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def create_segments(self, collection: Collection) -> Sequence[Segment]: vector_segment = _segment( @@ -67,6 +78,10 @@ def create_segments(self, collection: Collection) -> Sequence[Segment]: def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: raise NotImplementedError() + @trace_method( + "DistributedSegmentManager.get_segment", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def get_segment(self, collection_id: UUID, type: type[S]) -> S: if type == MetadataReader: @@ -96,6 +111,10 @@ def get_segment(self, collection_id: UUID, type: type[S]) -> S: instance = self._instance(self._segment_cache[collection_id][scope]) return cast(S, instance) + @trace_method( + "DistributedSegmentManager.hint_use_collection", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: # TODO: this should call load/release on the target node, node should be stored in metadata @@ -114,6 +133,13 @@ def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None segment = next(filter(lambda s: s["type"] in known_types, segments)) grpc_url = self._segment_directory.get_segment_endpoint(segment) + if grpc_url not in self._segment_server_stubs: + channel = grpc.insecure_channel(grpc_url) + self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) # type: ignore + + self._segment_server_stubs[grpc_url].LoadSegment( + to_proto_segment(segment) + ) if grpc_url not in self._segment_server_stubs: channel = grpc.insecure_channel(grpc_url) self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index a5b797e31c6..5e7e8b53784 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -13,6 +13,11 @@ from chromadb.segment.impl.vector.local_persistent_hnsw import ( PersistentLocalHnswSegment, ) +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata from typing import Dict, Type, Sequence, Optional, cast from uuid import UUID, uuid4 @@ -37,6 +42,7 @@ class LocalSegmentManager(SegmentManager): _sysdb: SysDB _system: System + _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] _vector_instances_file_handle_cache: LRUCache[ UUID, PersistentLocalHnswSegment @@ -52,6 +58,7 @@ def __init__(self, system: System): super().__init__(system) self._sysdb = self.require(SysDB) self._system = system + self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} self._segment_cache = defaultdict(dict) self._lock = Lock() @@ -93,6 +100,10 @@ def reset_state(self) -> None: self._segment_cache = defaultdict(dict) super().reset_state() + @trace_method( + "LocalSegmentManager.create_segments", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def create_segments(self, collection: Collection) -> Sequence[Segment]: vector_segment = _segment( @@ -103,6 +114,10 @@ def create_segments(self, collection: Collection) -> Sequence[Segment]: ) return [vector_segment, metadata_segment] + @trace_method( + "LocalSegmentManager.delete_segments", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: segments = self._sysdb.get_segments(collection=collection_id) @@ -118,6 +133,10 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: del self._segment_cache[collection_id] return [s["id"] for s in segments] + @trace_method( + "LocalSegmentManager.get_segment", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def get_segment(self, collection_id: UUID, type: Type[S]) -> S: if type == MetadataReader: @@ -140,6 +159,10 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S: instance = self._instance(self._segment_cache[collection_id][scope]) return cast(S, instance) + @trace_method( + "LocalSegmentManager.hint_use_collection", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: # The local segment manager responds to hints by pre-loading both the metadata and vector diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index a7098d7808b..1bdb4eea63c 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -10,6 +10,11 @@ ParameterValue, get_sql, ) +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( Where, WhereDocument, @@ -39,6 +44,7 @@ class SqliteMetadataSegment(MetadataReader): _consumer: Consumer _db: SqliteDB _id: UUID + _opentelemetry_client: OpenTelemetryClient _topic: Optional[str] _subscription: Optional[UUID] @@ -46,8 +52,10 @@ def __init__(self, system: System, segment: Segment): self._db = system.instance(SqliteDB) self._consumer = system.instance(Consumer) self._id = segment["id"] + self._opentelemetry_client = system.require(OpenTelemetryClient) self._topic = segment["topic"] + @trace_method("SqliteMetadataSegment.start", OpenTelemetryGranularity.ALL) @override def start(self) -> None: if self._topic: @@ -56,11 +64,13 @@ def start(self) -> None: self._topic, self._write_metadata, start=seq_id ) + @trace_method("SqliteMetadataSegment.stop", OpenTelemetryGranularity.ALL) @override def stop(self) -> None: if self._subscription: self._consumer.unsubscribe(self._subscription) + @trace_method("SqliteMetadataSegment.max_seqid", OpenTelemetryGranularity.ALL) @override def max_seqid(self) -> SeqId: t = Table("max_seq_id") @@ -79,6 +89,7 @@ def max_seqid(self) -> SeqId: else: return _decode_seq_id(result[0]) + @trace_method("SqliteMetadataSegment.count", OpenTelemetryGranularity.ALL) @override def count(self) -> int: embeddings_t = Table("embeddings") @@ -95,6 +106,7 @@ def count(self) -> int: result = cur.execute(sql, params).fetchone()[0] return cast(int, result) + @trace_method("SqliteMetadataSegment.get_metadata", OpenTelemetryGranularity.ALL) @override def get_metadata( self, @@ -162,6 +174,7 @@ def _records( for _, group in group_iterator: yield self._record(list(group)) + @trace_method("SqliteMetadataSegment._record", OpenTelemetryGranularity.ALL) def _record(self, rows: Sequence[Tuple[Any, ...]]) -> MetadataEmbeddingRecord: """Given a list of DB rows with the same ID, construct a MetadataEmbeddingRecord""" @@ -187,6 +200,7 @@ def _record(self, rows: Sequence[Tuple[Any, ...]]) -> MetadataEmbeddingRecord: metadata=metadata or None, ) + @trace_method("SqliteMetadataSegment._insert_record", OpenTelemetryGranularity.ALL) def _insert_record( self, cur: Cursor, record: EmbeddingRecord, upsert: bool ) -> None: @@ -221,6 +235,9 @@ def _insert_record( if record["metadata"]: self._update_metadata(cur, id, record["metadata"]) + @trace_method( + "SqliteMetadataSegment._update_metadata", OpenTelemetryGranularity.ALL + ) def _update_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None: """Update the metadata for a single EmbeddingRecord""" t = Table("embedding_metadata") @@ -238,6 +255,9 @@ def _update_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> No self._insert_metadata(cur, id, metadata) + @trace_method( + "SqliteMetadataSegment._insert_metadata", OpenTelemetryGranularity.ALL + ) def _insert_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None: """Insert or update each metadata row for a single embedding record""" t = Table("embedding_metadata") @@ -245,7 +265,12 @@ def _insert_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> No self._db.querybuilder() .into(t) .columns( - t.id, t.key, t.string_value, t.int_value, t.float_value, t.bool_value + t.id, + t.key, + t.string_value, + t.int_value, + t.float_value, + t.bool_value, ) ) for key, value in metadata.items(): @@ -321,6 +346,7 @@ def insert_into_fulltext_search() -> None: cur.execute(sql, params) insert_into_fulltext_search() + @trace_method("SqliteMetadataSegment._delete_record", OpenTelemetryGranularity.ALL) def _delete_record(self, cur: Cursor, record: EmbeddingRecord) -> None: """Delete a single EmbeddingRecord from the DB""" t = Table("embeddings") @@ -351,6 +377,7 @@ def _delete_record(self, cur: Cursor, record: EmbeddingRecord) -> None: sql, params = get_sql(q) cur.execute(sql, params) + @trace_method("SqliteMetadataSegment._update_record", OpenTelemetryGranularity.ALL) def _update_record(self, cur: Cursor, record: EmbeddingRecord) -> None: """Update a single EmbeddingRecord in the DB""" t = Table("embeddings") @@ -371,6 +398,7 @@ def _update_record(self, cur: Cursor, record: EmbeddingRecord) -> None: if record["metadata"]: self._update_metadata(cur, id, record["metadata"]) + @trace_method("SqliteMetadataSegment._write_metadata", OpenTelemetryGranularity.ALL) def _write_metadata(self, records: Sequence[EmbeddingRecord]) -> None: """Write embedding metadata to the database. Care should be taken to ensure records are append-only (that is, that seq-ids should increase monotonically)""" @@ -398,6 +426,9 @@ def _write_metadata(self, records: Sequence[EmbeddingRecord]) -> None: elif record["operation"] == Operation.UPDATE: self._update_record(cur, record) + @trace_method( + "SqliteMetadataSegment._where_map_criterion", OpenTelemetryGranularity.ALL + ) def _where_map_criterion( self, q: QueryBuilder, where: Where, embeddings_t: Table, metadata_t: Table ) -> Criterion: @@ -427,6 +458,9 @@ def _where_map_criterion( clause.append(embeddings_t.id.isin(sq)) return reduce(lambda x, y: x & y, clause) + @trace_method( + "SqliteMetadataSegment._where_doc_criterion", OpenTelemetryGranularity.ALL + ) def _where_doc_criterion( self, q: QueryBuilder, diff --git a/chromadb/segment/impl/vector/grpc_segment.py b/chromadb/segment/impl/vector/grpc_segment.py index 0aac3baa253..89cc1b814f0 100644 --- a/chromadb/segment/impl/vector/grpc_segment.py +++ b/chromadb/segment/impl/vector/grpc_segment.py @@ -9,6 +9,11 @@ ) from chromadb.segment import MetadataReader, VectorReader from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( Metadata, ScalarEncoding, @@ -30,6 +35,7 @@ class GrpcVectorSegment(VectorReader, EnforceOverrides): _vector_reader_stub: VectorReaderStub _segment: Segment + _opentelemetry_client: OpenTelemetryClient def __init__(self, system: System, segment: Segment): # TODO: move to start() method @@ -40,7 +46,9 @@ def __init__(self, system: System, segment: Segment): channel = grpc.insecure_channel(segment["metadata"]["grpc_url"]) self._vector_reader_stub = VectorReaderStub(channel) # type: ignore self._segment = segment + self._opentelemetry_client = system.require(OpenTelemetryClient) + @trace_method("GrpcVectorSegment.get_vectors", OpenTelemetryGranularity.ALL) @override def get_vectors( self, ids: Optional[Sequence[str]] = None @@ -53,6 +61,7 @@ def get_vectors( results.append(result) return results + @trace_method("GrpcVectorSegment.query_vectors", OpenTelemetryGranularity.ALL) @override def query_vectors( self, query: VectorQuery diff --git a/chromadb/segment/impl/vector/local_hnsw.py b/chromadb/segment/impl/vector/local_hnsw.py index c45af628d2f..e4437881b2a 100644 --- a/chromadb/segment/impl/vector/local_hnsw.py +++ b/chromadb/segment/impl/vector/local_hnsw.py @@ -6,6 +6,11 @@ from chromadb.config import System, Settings from chromadb.segment.impl.vector.batch import Batch from chromadb.segment.impl.vector.hnsw_params import HnswParams +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( EmbeddingRecord, VectorEmbeddingRecord, @@ -46,6 +51,8 @@ class LocalHnswSegment(VectorReader): _label_to_id: Dict[int, str] _id_to_seq_id: Dict[str, SeqId] + _opentelemtry_client: OpenTelemetryClient + def __init__(self, system: System, segment: Segment): self._consumer = system.instance(Consumer) self._id = segment["id"] @@ -63,6 +70,7 @@ def __init__(self, system: System, segment: Segment): self._label_to_id = {} self._lock = ReadWriteLock() + self._opentelemtry_client = system.require(OpenTelemetryClient) super().__init__(system, segment) @staticmethod @@ -72,6 +80,7 @@ def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]: segment_metadata = HnswParams.extract(metadata) return segment_metadata + @trace_method("LocalHnswSegment.start", OpenTelemetryGranularity.ALL) @override def start(self) -> None: super().start() @@ -81,12 +90,14 @@ def start(self) -> None: self._topic, self._write_records, start=seq_id ) + @trace_method("LocalHnswSegment.stop", OpenTelemetryGranularity.ALL) @override def stop(self) -> None: super().stop() if self._subscription: self._consumer.unsubscribe(self._subscription) + @trace_method("LocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL) @override def get_vectors( self, ids: Optional[Sequence[str]] = None @@ -112,6 +123,7 @@ def get_vectors( return results + @trace_method("LocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL) @override def query_vectors( self, query: VectorQuery @@ -181,6 +193,7 @@ def max_seqid(self) -> SeqId: def count(self) -> int: return len(self._id_to_label) + @trace_method("LocalHnswSegment._init_index", OpenTelemetryGranularity.ALL) def _init_index(self, dimensionality: int) -> None: # more comments available at the source: https://github.com/nmslib/hnswlib @@ -198,6 +211,7 @@ def _init_index(self, dimensionality: int) -> None: self._index = index self._dimensionality = dimensionality + @trace_method("LocalHnswSegment._ensure_index", OpenTelemetryGranularity.ALL) def _ensure_index(self, n: int, dim: int) -> None: """Create or resize the index as necessary to accomodate N new records""" if not self._index: @@ -218,6 +232,7 @@ def _ensure_index(self, n: int, dim: int) -> None: ) index.resize_index(max(new_size, DEFAULT_CAPACITY)) + @trace_method("LocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL) def _apply_batch(self, batch: Batch) -> None: """Apply a batch of changes, as atomically as possible.""" deleted_ids = batch.get_deleted_ids() @@ -267,6 +282,7 @@ def _apply_batch(self, batch: Batch) -> None: # If that succeeds, finally the seq ID self._max_seq_id = batch.max_seq_id + @trace_method("LocalHnswSegment._write_records", OpenTelemetryGranularity.ALL) def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: """Add a batch of embeddings to the index""" if not self._running: diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index f8c74bd0fe7..4ab60a1725d 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -11,6 +11,11 @@ LocalHnswSegment, ) from chromadb.segment.impl.vector.brute_force_index import BruteForceIndex +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( EmbeddingRecord, Metadata, @@ -81,9 +86,13 @@ class PersistentLocalHnswSegment(LocalHnswSegment): _persist_directory: str _allow_reset: bool + _opentelemtry_client: OpenTelemetryClient + def __init__(self, system: System, segment: Segment): super().__init__(system, segment) + self._opentelemtry_client = system.require(OpenTelemetryClient) + self._params = PersistentHnswParams(segment["metadata"] or {}) self._batch_size = self._params.batch_size self._sync_threshold = self._params.sync_threshold @@ -138,6 +147,9 @@ def _get_storage_folder(self) -> str: folder = os.path.join(self._persist_directory, str(self._id)) return folder + @trace_method( + "PersistentLocalHnswSegment._init_index", OpenTelemetryGranularity.ALL + ) @override def _init_index(self, dimensionality: int) -> None: index = hnswlib.Index(space=self._params.space, dim=dimensionality) @@ -172,6 +184,7 @@ def _init_index(self, dimensionality: int) -> None: self._dimensionality = dimensionality self._index_initialized = True + @trace_method("PersistentLocalHnswSegment._persist", OpenTelemetryGranularity.ALL) def _persist(self) -> None: """Persist the index and data to disk""" index = cast(hnswlib.Index, self._index) @@ -193,6 +206,9 @@ def _persist(self) -> None: with open(self._get_metadata_file(), "wb") as metadata_file: pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL) + @trace_method( + "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL + ) @override def _apply_batch(self, batch: Batch) -> None: super()._apply_batch(batch) @@ -202,6 +218,9 @@ def _apply_batch(self, batch: Batch) -> None: ): self._persist() + @trace_method( + "PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL + ) @override def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: """Add a batch of embeddings to the index""" @@ -267,6 +286,9 @@ def count(self) -> int: - self._curr_batch.delete_count ) + @trace_method( + "PersistentLocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL + ) @override def get_vectors( self, ids: Optional[Sequence[str]] = None @@ -310,6 +332,9 @@ def get_vectors( return results # type: ignore ## Python can't cast List with Optional to List with VectorEmbeddingRecord + @trace_method( + "PersistentLocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL + ) @override def query_vectors( self, query: VectorQuery @@ -395,6 +420,9 @@ def query_vectors( results.append(curr_results) return results + @trace_method( + "PersistentLocalHnswSegment.reset_state", OpenTelemetryGranularity.ALL + ) @override def reset_state(self) -> None: if self._allow_reset: @@ -403,6 +431,7 @@ def reset_state(self) -> None: self.close_persistent_index() shutil.rmtree(data_path, ignore_errors=True) + @trace_method("PersistentLocalHnswSegment.delete", OpenTelemetryGranularity.ALL) @override def delete(self) -> None: data_path = self._get_storage_folder() diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index e92d16d63ba..4921392d3ee 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -35,7 +35,12 @@ from starlette.requests import Request import logging -from chromadb.telemetry import ServerContext, Telemetry +from chromadb.telemetry.product import ServerContext, ProductTelemetryClient +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) logger = logging.getLogger(__name__) @@ -102,9 +107,10 @@ def include_in_schema(path: str) -> bool: class FastAPI(chromadb.server.Server): def __init__(self, settings: Settings): super().__init__(settings) - Telemetry.SERVER_CONTEXT = ServerContext.FASTAPI + ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI self._app = fastapi.FastAPI(debug=True) self._api: chromadb.api.API = chromadb.Client(settings) + self._opentelemetry_client = self._api.require(OpenTelemetryClient) self._app.middleware("http")(catch_exceptions_middleware) self._app.add_middleware( @@ -221,9 +227,11 @@ def heartbeat(self) -> Dict[str, int]: def version(self) -> str: return self._api.get_version() + @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) def list_collections(self) -> Sequence[Collection]: return self._api.list_collections() + @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) def create_collection(self, collection: CreateCollection) -> Collection: return self._api.create_collection( name=collection.name, @@ -231,9 +239,11 @@ def create_collection(self, collection: CreateCollection) -> Collection: get_or_create=collection.get_or_create, ) + @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION) def get_collection(self, collection_name: str) -> Collection: return self._api.get_collection(collection_name) + @trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION) def update_collection( self, collection_id: str, collection: UpdateCollection ) -> None: @@ -243,9 +253,11 @@ def update_collection( new_metadata=collection.new_metadata, ) + @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) def delete_collection(self, collection_name: str) -> None: return self._api.delete_collection(collection_name) + @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION) def add(self, collection_id: str, add: AddEmbedding) -> None: try: result = self._api._add( @@ -259,6 +271,7 @@ def add(self, collection_id: str, add: AddEmbedding) -> None: raise HTTPException(status_code=500, detail=str(e)) return result + @trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION) def update(self, collection_id: str, add: UpdateEmbedding) -> None: return self._api._update( ids=add.ids, @@ -268,6 +281,7 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None: metadatas=add.metadatas, ) + @trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION) def upsert(self, collection_id: str, upsert: AddEmbedding) -> None: return self._api._upsert( collection_id=_uuid(collection_id), @@ -277,6 +291,7 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None: metadatas=upsert.metadatas, ) + @trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION) def get(self, collection_id: str, get: GetEmbedding) -> GetResult: return self._api._get( collection_id=_uuid(collection_id), @@ -289,6 +304,7 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult: include=get.include, ) + @trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION) def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]: return self._api._delete( where=delete.where, @@ -297,12 +313,14 @@ def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]: where_document=delete.where_document, ) + @trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION) def count(self, collection_id: str) -> int: return self._api._count(_uuid(collection_id)) def reset(self) -> bool: return self._api.reset() + @trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION) def get_nearest_neighbors( self, collection_id: str, query: QueryEmbedding ) -> QueryResult: diff --git a/chromadb/telemetry/README.md b/chromadb/telemetry/README.md new file mode 100644 index 00000000000..c406074e41e --- /dev/null +++ b/chromadb/telemetry/README.md @@ -0,0 +1,10 @@ +# Telemetry + +This directory holds all the telemetry for Chroma. + +- `product/` contains anonymized product telemetry which we, Chroma, collect so we can + understand usage patterns. For more information, see https://docs.trychroma.com/telemetry. +- `opentelemetry/` contains all of the config for Chroma's [OpenTelemetry](https://opentelemetry.io/docs/instrumentation/python/getting-started/) + setup. These metrics are *not* sent back to Chroma -- anyone operating a Chroma instance + can use the OpenTelemetry metrics and traces to understand how their instance of Chroma + is behaving. \ No newline at end of file diff --git a/chromadb/telemetry/__init__.py b/chromadb/telemetry/__init__.py index d20b8e5d71c..e69de29bb2d 100644 --- a/chromadb/telemetry/__init__.py +++ b/chromadb/telemetry/__init__.py @@ -1,122 +0,0 @@ -from abc import abstractmethod -import os -from typing import Callable, ClassVar, Dict, Any -import uuid -import time -from threading import Event, Thread -import chromadb -from chromadb.config import Component -from pathlib import Path -from enum import Enum - -TELEMETRY_WHITELISTED_SETTINGS = [ - "chroma_api_impl", - "is_persistent", - "chroma_server_ssl_enabled", -] - - -class ServerContext(Enum): - NONE = "None" - FASTAPI = "FastAPI" - - -class TelemetryEvent: - max_batch_size: ClassVar[int] = 1 - batch_size: int - - def __init__(self, batch_size: int = 1): - self.batch_size = batch_size - - @property - def properties(self) -> Dict[str, Any]: - return self.__dict__ - - @property - def name(self) -> str: - return self.__class__.__name__ - - # A batch key is used to determine whether two events can be batched together. - # If a TelemetryEvent's max_batch_size > 1, batch_key() and batch() MUST be implemented. - # Otherwise they are ignored. - @property - def batch_key(self) -> str: - return self.name - - def batch(self, other: "TelemetryEvent") -> "TelemetryEvent": - raise NotImplementedError - - -class RepeatedTelemetry: - def __init__(self, interval: int, function: Callable[[], None]): - self.interval = interval - self.function = function - self.start = time.time() - self.event = Event() - self.thread = Thread(target=self._target) - self.thread.daemon = True - self.thread.start() - - def _target(self) -> None: - while not self.event.wait(self._time): - self.function() - - @property - def _time(self) -> float: - return self.interval - ((time.time() - self.start) % self.interval) - - def stop(self) -> None: - self.event.set() - self.thread.join() - - -class Telemetry(Component): - USER_ID_PATH = str(Path.home() / ".cache" / "chroma" / "telemetry_user_id") - UNKNOWN_USER_ID = "UNKNOWN" - SERVER_CONTEXT: ServerContext = ServerContext.NONE - _curr_user_id = None - - @abstractmethod - def capture(self, event: TelemetryEvent) -> None: - pass - - # Schedule a function that creates a TelemetryEvent to be called every `every_seconds` seconds. - def schedule_event_function( - self, event_function: Callable[..., TelemetryEvent], every_seconds: int - ) -> None: - RepeatedTelemetry(every_seconds, lambda: self.capture(event_function())) - - @property - def context(self) -> Dict[str, Any]: - chroma_version = chromadb.__version__ - settings = chromadb.get_settings() - telemetry_settings = {} - for whitelisted in TELEMETRY_WHITELISTED_SETTINGS: - telemetry_settings[whitelisted] = settings[whitelisted] - - self._context = { - "chroma_version": chroma_version, - "server_context": self.SERVER_CONTEXT.value, - **telemetry_settings, - } - return self._context - - @property - def user_id(self) -> str: - if self._curr_user_id: - return self._curr_user_id - - # File access may fail due to permissions or other reasons. We don't want to crash so we catch all exceptions. - try: - if not os.path.exists(self.USER_ID_PATH): - os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True) - with open(self.USER_ID_PATH, "w") as f: - new_user_id = str(uuid.uuid4()) - f.write(new_user_id) - self._curr_user_id = new_user_id - else: - with open(self.USER_ID_PATH, "r") as f: - self._curr_user_id = f.read() - except Exception: - self._curr_user_id = self.UNKNOWN_USER_ID - return self._curr_user_id diff --git a/chromadb/telemetry/opentelemetry/__init__.py b/chromadb/telemetry/opentelemetry/__init__.py new file mode 100644 index 00000000000..0840871bcae --- /dev/null +++ b/chromadb/telemetry/opentelemetry/__init__.py @@ -0,0 +1,128 @@ +from functools import wraps +from enum import Enum +from typing import Any, Callable, Dict, Optional, Union + +from opentelemetry import trace +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, +) +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + +from chromadb.config import Component +from chromadb.config import System + + +class OpenTelemetryGranularity(Enum): + """The granularity of the OpenTelemetry spans.""" + + NONE = "none" + """No spans are emitted.""" + + OPERATION = "operation" + """Spans are emitted for each operation.""" + + OPERATION_AND_SEGMENT = "operation_and_segment" + """Spans are emitted for each operation and segment.""" + + ALL = "all" + """Spans are emitted for almost every method call.""" + + # Greater is more restrictive. So "all" < "operation" (and everything else), + # "none" > everything. + def __lt__(self, other: Any) -> bool: + """Compare two granularities.""" + order = [ + OpenTelemetryGranularity.ALL, + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + OpenTelemetryGranularity.OPERATION, + OpenTelemetryGranularity.NONE, + ] + return order.index(self) < order.index(other) + + +class OpenTelemetryClient(Component): + def __init__(self, system: System): + super().__init__(system) + otel_init( + system.settings.chroma_otel_service_name, + system.settings.chroma_otel_collection_endpoint, + system.settings.chroma_otel_collection_headers, + OpenTelemetryGranularity(system.settings.chroma_otel_granularity), + ) + + +tracer: Optional[trace.Tracer] = None +granularity: OpenTelemetryGranularity = OpenTelemetryGranularity("none") + + +def otel_init( + otel_service_name: Optional[str], + otel_collection_endpoint: Optional[str], + otel_collection_headers: Optional[Dict[str, str]], + otel_granularity: OpenTelemetryGranularity, +) -> None: + """Initializes module-level state for OpenTelemetry. + + Parameters match the environment variables which configure OTel as documented + at https://docs.trychroma.com/observability. + - otel_service_name: The name of the service for OTel tagging and aggregation. + - otel_collection_endpoint: The endpoint to which OTel spans are sent (e.g. api.honeycomb.com). + - otel_collection_headers: The headers to send with OTel spans (e.g. {"x-honeycomb-team": "abc123"}). + - otel_granularity: The granularity of the spans to emit. + """ + if otel_granularity == OpenTelemetryGranularity.NONE: + return + resource = Resource(attributes={SERVICE_NAME: str(otel_service_name)}) + provider = TracerProvider(resource=resource) + provider.add_span_processor( + BatchSpanProcessor( + # TODO: we may eventually want to make this configurable. + OTLPSpanExporter( + endpoint=str(otel_collection_endpoint), + headers=otel_collection_headers, + ) + ) + ) + trace.set_tracer_provider(provider) + + global tracer, granularity + tracer = trace.get_tracer(__name__) + granularity = otel_granularity + + +def trace_method( + trace_name: str, + trace_granularity: OpenTelemetryGranularity, + attributes: Dict[str, Union[str, bool, float, int]] = {}, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """A decorator that traces a method.""" + + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any: + global tracer, granularity, _transform_attributes + if trace_granularity < granularity: + return f(*args, **kwargs) + if not tracer: + return + with tracer.start_as_current_span(trace_name, attributes=attributes): + return f(*args, **kwargs) + + return wrapper + + return decorator + + +def add_attributes_to_current_span( + attributes: Dict[str, Union[str, bool, float, int]] +) -> None: + """Add attributes to the current span.""" + global tracer, granularity, _transform_attributes + if granularity == OpenTelemetryGranularity.NONE: + return + if not tracer: + return + span = trace.get_current_span() + span.set_attributes(_transform_attributes(attributes)) # type: ignore diff --git a/chromadb/telemetry/product/__init__.py b/chromadb/telemetry/product/__init__.py new file mode 100644 index 00000000000..a6fd0d7ad87 --- /dev/null +++ b/chromadb/telemetry/product/__init__.py @@ -0,0 +1,93 @@ +from abc import abstractmethod +import os +from typing import ClassVar, Dict, Any +import uuid +import chromadb +from chromadb.config import Component +from pathlib import Path +from enum import Enum + +TELEMETRY_WHITELISTED_SETTINGS = [ + "chroma_api_impl", + "is_persistent", + "chroma_server_ssl_enabled", +] + + +class ServerContext(Enum): + NONE = "None" + FASTAPI = "FastAPI" + + +class ProductTelemetryEvent: + max_batch_size: ClassVar[int] = 1 + batch_size: int + + def __init__(self, batch_size: int = 1): + self.batch_size = batch_size + + @property + def properties(self) -> Dict[str, Any]: + return self.__dict__ + + @property + def name(self) -> str: + return self.__class__.__name__ + + # A batch key is used to determine whether two events can be batched together. + # If a TelemetryEvent's max_batch_size > 1, batch_key() and batch() MUST be + # implemented. + # Otherwise they are ignored. + @property + def batch_key(self) -> str: + return self.name + + def batch(self, other: "ProductTelemetryEvent") -> "ProductTelemetryEvent": + raise NotImplementedError + + +class ProductTelemetryClient(Component): + USER_ID_PATH = str(Path.home() / ".cache" / "chroma" / "telemetry_user_id") + UNKNOWN_USER_ID = "UNKNOWN" + SERVER_CONTEXT: ServerContext = ServerContext.NONE + _curr_user_id = None + + @abstractmethod + def capture(self, event: ProductTelemetryEvent) -> None: + pass + + @property + def context(self) -> Dict[str, Any]: + chroma_version = chromadb.__version__ + settings = chromadb.get_settings() + telemetry_settings = {} + for whitelisted in TELEMETRY_WHITELISTED_SETTINGS: + telemetry_settings[whitelisted] = settings[whitelisted] + + self._context = { + "chroma_version": chroma_version, + "server_context": self.SERVER_CONTEXT.value, + **telemetry_settings, + } + return self._context + + @property + def user_id(self) -> str: + if self._curr_user_id: + return self._curr_user_id + + # File access may fail due to permissions or other reasons. We don't want to + # crash so we catch all exceptions. + try: + if not os.path.exists(self.USER_ID_PATH): + os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True) + with open(self.USER_ID_PATH, "w") as f: + new_user_id = str(uuid.uuid4()) + f.write(new_user_id) + self._curr_user_id = new_user_id + else: + with open(self.USER_ID_PATH, "r") as f: + self._curr_user_id = f.read() + except Exception: + self._curr_user_id = self.UNKNOWN_USER_ID + return self._curr_user_id diff --git a/chromadb/telemetry/events.py b/chromadb/telemetry/product/events.py similarity index 89% rename from chromadb/telemetry/events.py rename to chromadb/telemetry/product/events.py index e662cd85fa7..e5f6bc688c1 100644 --- a/chromadb/telemetry/events.py +++ b/chromadb/telemetry/product/events.py @@ -1,14 +1,14 @@ from typing import cast, ClassVar -from chromadb.telemetry import TelemetryEvent +from chromadb.telemetry.product import ProductTelemetryEvent from chromadb.utils.embedding_functions import get_builtins -class ClientStartEvent(TelemetryEvent): +class ClientStartEvent(ProductTelemetryEvent): def __init__(self) -> None: super().__init__() -class ClientCreateCollectionEvent(TelemetryEvent): +class ClientCreateCollectionEvent(ProductTelemetryEvent): collection_uuid: str embedding_function: str @@ -25,7 +25,7 @@ def __init__(self, collection_uuid: str, embedding_function: str): ) -class CollectionAddEvent(TelemetryEvent): +class CollectionAddEvent(ProductTelemetryEvent): max_batch_size: ClassVar[int] = 100 batch_size: int collection_uuid: str @@ -52,7 +52,7 @@ def __init__( def batch_key(self) -> str: return self.collection_uuid + self.name - def batch(self, other: "TelemetryEvent") -> "CollectionAddEvent": + def batch(self, other: "ProductTelemetryEvent") -> "CollectionAddEvent": if not self.batch_key == other.batch_key: raise ValueError("Cannot batch events") other = cast(CollectionAddEvent, other) @@ -66,7 +66,7 @@ def batch(self, other: "TelemetryEvent") -> "CollectionAddEvent": ) -class CollectionUpdateEvent(TelemetryEvent): +class CollectionUpdateEvent(ProductTelemetryEvent): collection_uuid: str update_amount: int with_embeddings: int @@ -89,7 +89,7 @@ def __init__( self.with_documents = with_documents -class CollectionQueryEvent(TelemetryEvent): +class CollectionQueryEvent(ProductTelemetryEvent): max_batch_size: ClassVar[int] = 20 batch_size: int collection_uuid: str @@ -128,7 +128,7 @@ def __init__( def batch_key(self) -> str: return self.collection_uuid + self.name - def batch(self, other: "TelemetryEvent") -> "CollectionQueryEvent": + def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent": if not self.batch_key == other.batch_key: raise ValueError("Cannot batch events") other = cast(CollectionQueryEvent, other) @@ -146,7 +146,7 @@ def batch(self, other: "TelemetryEvent") -> "CollectionQueryEvent": ) -class CollectionGetEvent(TelemetryEvent): +class CollectionGetEvent(ProductTelemetryEvent): collection_uuid: str ids_count: int limit: int @@ -169,7 +169,7 @@ def __init__( self.include_documents = include_documents -class CollectionDeleteEvent(TelemetryEvent): +class CollectionDeleteEvent(ProductTelemetryEvent): collection_uuid: str delete_amount: int diff --git a/chromadb/telemetry/posthog.py b/chromadb/telemetry/product/posthog.py similarity index 77% rename from chromadb/telemetry/posthog.py rename to chromadb/telemetry/product/posthog.py index 21676b9fbe7..05c46b07256 100644 --- a/chromadb/telemetry/posthog.py +++ b/chromadb/telemetry/product/posthog.py @@ -3,19 +3,23 @@ import sys from typing import Any, Dict, Set from chromadb.config import System -from chromadb.telemetry import Telemetry, TelemetryEvent +from chromadb.telemetry.product import ( + ProductTelemetryClient, + ProductTelemetryEvent, +) from overrides import override logger = logging.getLogger(__name__) -class Posthog(Telemetry): +class Posthog(ProductTelemetryClient): def __init__(self, system: System): if not system.settings.anonymized_telemetry or "pytest" in sys.modules: posthog.disabled = True else: logger.info( - "Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information." + "Anonymized telemetry enabled. See \ + https://docs.trychroma.com/telemetry for more information." ) posthog.project_api_key = "phc_YeUxaojbKk5KPi8hNlx1bBKHzuZ4FDtl67kH1blv8Bh" @@ -23,13 +27,13 @@ def __init__(self, system: System): # Silence posthog's logging posthog_logger.disabled = True - self.batched_events: Dict[str, TelemetryEvent] = {} + self.batched_events: Dict[str, ProductTelemetryEvent] = {} self.seen_event_types: Set[Any] = set() super().__init__(system) @override - def capture(self, event: TelemetryEvent) -> None: + def capture(self, event: ProductTelemetryEvent) -> None: if event.max_batch_size == 1 or event.batch_key not in self.seen_event_types: self.seen_event_types.add(event.batch_key) self._direct_capture(event) @@ -44,7 +48,7 @@ def capture(self, event: TelemetryEvent) -> None: self._direct_capture(batched_event) del self.batched_events[batch_key] - def _direct_capture(self, event: TelemetryEvent) -> None: + def _direct_capture(self, event: ProductTelemetryEvent) -> None: try: posthog.capture( self.user_id, diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 529fe02dda7..3bd83231b32 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -46,7 +46,9 @@ def _bool_to_int(metadata: Dict[str, Any]) -> Dict[str, Any]: def _patch_boolean_metadata( - collection: strategies.Collection, embeddings: strategies.RecordSet + collection: strategies.Collection, + embeddings: strategies.RecordSet, + settings: Settings, ) -> None: # Since the old version does not support boolean value metadata, we will convert # boolean value metadata to int @@ -64,15 +66,29 @@ def _patch_boolean_metadata( _bool_to_int(metadata) +def _patch_telemetry_client( + collection: strategies.Collection, + embeddings: strategies.RecordSet, + settings: Settings, +) -> None: + # chroma 0.4.14 added OpenTelemetry, distinct from ProductTelemetry. Before 0.4.14 + # ProductTelemetry was simply called Telemetry. + settings.chroma_telemetry_impl = "chromadb.telemetry.posthog.Posthog" + + version_patches: List[ - Tuple[str, Callable[[strategies.Collection, strategies.RecordSet], None]] + Tuple[str, Callable[[strategies.Collection, strategies.RecordSet, Settings], None]] ] = [ ("0.4.3", _patch_boolean_metadata), + ("0.4.14", _patch_telemetry_client), ] def patch_for_version( - version: str, collection: strategies.Collection, embeddings: strategies.RecordSet + version: str, + collection: strategies.Collection, + embeddings: strategies.RecordSet, + settings: Settings, ) -> None: """Override aspects of the collection and embeddings, before testing, to account for breaking changes in old versions.""" @@ -81,7 +97,7 @@ def patch_for_version( if packaging_version.Version(version) <= packaging_version.Version( patch_version ): - patch(collection, embeddings) + patch(collection, embeddings, settings) def configurations(versions: List[str]) -> List[Tuple[str, Settings]]: @@ -261,7 +277,7 @@ def test_cycle_versions( for m in embeddings_strategy["metadatas"] ] - patch_for_version(version, collection_strategy, embeddings_strategy) + patch_for_version(version, collection_strategy, embeddings_strategy, settings) # Can't pickle a function, and we won't need them collection_strategy.embedding_function = None diff --git a/docker-compose.yml b/docker-compose.yml index 93581dd23c7..3bc5d5a9404 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,6 +22,10 @@ services: - CHROMA_SERVER_AUTH_CREDENTIALS=${CHROMA_SERVER_AUTH_CREDENTIALS} - CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER=${CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER} - PERSIST_DIRECTORY=${PERSIST_DIRECTORY:-/chroma/chroma} + - CHROMA_OTEL_EXPORTER_ENDPOINT=${CHROMA_OTEL_EXPORTER_ENDPOINT} + - CHROMA_OTEL_EXPORTER_HEADERS=${CHROMA_OTEL_EXPORTER_HEADERS} + - CHROMA_OTEL_SERVICE_NAME=${CHROMA_OTEL_SERVICE_NAME} + - CHROMA_OTEL_GRANULARITY=${CHROMA_OTEL_GRANULARITY} ports: - 8000:8000 networks: diff --git a/requirements.txt b/requirements.txt index 7b60e6101bb..f3093341f14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,9 @@ kubernetes>=28.1.0 numpy==1.21.6; python_version < '3.8' numpy>=1.22.4; python_version >= '3.8' onnxruntime>=1.14.1 +opentelemetry-api>=1.2.0 +opentelemetry-exporter-otlp-proto-grpc>=1.2.0 +opentelemetry-sdk>=1.2.0 overrides==7.3.1 posthog==2.4.0 pulsar-client==3.1.0 diff --git a/server.htpasswd b/server.htpasswd new file mode 100644 index 00000000000..77f277a399b --- /dev/null +++ b/server.htpasswd @@ -0,0 +1 @@ +admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS