From a56c3e0b78adf224f084e3f8e6b43408e8281c0c Mon Sep 17 00:00:00 2001
From: Simon Suo
Date: Thu, 14 Nov 2024 08:15:47 -0800
Subject: [PATCH] modernize the llama-cloud index and retriever for the latest
 versions of the llama-cloud api

---
 .../indices/managed/llama_cloud/api_utils.py  | 140 +++++-----
 .../indices/managed/llama_cloud/base.py       | 229 ++++++++--------
 .../indices/managed/llama_cloud/retriever.py  | 109 ++++----
 .../pyproject.toml                            |   2 +-
 .../tests/test_indices_managed_llama_cloud.py | 254 +++++++++++++-----
 5 files changed, 425 insertions(+), 309 deletions(-)

diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/api_utils.py b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/api_utils.py
index 7d61ce9d27ae9..31e5b354d2250 100644
--- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/api_utils.py
+++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/api_utils.py
@@ -1,80 +1,98 @@
-from typing import List, Optional
+from typing import Optional, Tuple

 from llama_cloud import (
-    ConfigurableTransformationNames,
-    ConfiguredTransformationItem,
-    PipelineCreate,
+    AutoTransformConfig,
+    Pipeline,
+    PipelineCreateEmbeddingConfig,
+    PipelineCreateEmbeddingConfig_OpenaiEmbedding,
+    PipelineCreateTransformConfig,
     PipelineType,
-    ProjectCreate,
+    Project,
 )
 from llama_cloud.client import LlamaCloud
-from llama_index.core.constants import (
-    DEFAULT_PROJECT_NAME,
-)
-from llama_index.core.ingestion.transformations import (
-    ConfiguredTransformation,
-)
-from llama_index.core.node_parser import SentenceSplitter
-from llama_index.core.readers.base import ReaderConfig
-from llama_index.core.schema import BaseNode, TransformComponent


-def default_transformations() -> List[TransformComponent]:
-    """Default transformations."""
+def default_embedding_config() -> PipelineCreateEmbeddingConfig:
     from llama_index.embeddings.openai import OpenAIEmbedding  # pants: no-infer-dep

-    return [
-        SentenceSplitter(),
-        OpenAIEmbedding(),
-    ]
+    return PipelineCreateEmbeddingConfig_OpenaiEmbedding(
+        type="OPENAI_EMBEDDING",
+        component=OpenAIEmbedding(),
+    )


-def get_pipeline_create(
-    pipeline_name: str,
-    client: LlamaCloud,
-    pipeline_type: PipelineType,
-    project_name: str = DEFAULT_PROJECT_NAME,
-    transformations: Optional[List[TransformComponent]] = None,
-    readers: Optional[List[ReaderConfig]] = None,
-    input_nodes: Optional[List[BaseNode]] = None,
-) -> PipelineCreate:
-    """Get a pipeline create object."""
-    transformations = transformations or []
+def default_transform_config() -> PipelineCreateTransformConfig:
+    return AutoTransformConfig()

-    configured_transformations: List[ConfiguredTransformation] = []
-    for transformation in transformations:
-        try:
-            configured_transformations.append(
-                ConfiguredTransformation.from_component(transformation)
+
+def resolve_project(
+    client: LlamaCloud,
+    project_name: Optional[str],
+    project_id: Optional[str],
+    organization_id: Optional[str],
+) -> Project:
+    if project_id is not None:
+        return client.projects.get_project(project_id=project_id)
+    else:
+        projects = client.projects.list_projects(
+            project_name=project_name, organization_id=organization_id
+        )
+        if len(projects) == 0:
+            raise ValueError(f"No project found with name {project_name}")
+        elif len(projects) > 1:
+            raise ValueError(
f"Multiple projects found with name {project_name}. Please specify organization_id." ) - except ValueError: - raise ValueError(f"Unsupported transformation: {type(transformation)}") + return projects[0] + - configured_transformation_items: List[ConfiguredTransformationItem] = [] - for item in configured_transformations: - name = ConfigurableTransformationNames[ - item.configurable_transformation_type.name - ] - configured_transformation_items.append( - ConfiguredTransformationItem( - transformation_name=name, - component=item.component, - configurable_transformation_type=item.configurable_transformation_type.name, +def resolve_pipeline( + client: LlamaCloud, + pipeline_id: Optional[str], + project: Optional[Project], + pipeline_name: Optional[str], +) -> Pipeline: + if pipeline_id is not None: + return client.pipelines.get_pipeline(pipeline_id=pipeline_id) + else: + pipelines = client.pipelines.search_pipelines( + project_id=project.id, + pipeline_name=pipeline_name, + pipeline_type=PipelineType.MANAGED.value, + ) + if len(pipelines) == 0: + raise ValueError( + f"Unknown index name {pipeline_name}. Please confirm an index with this name exists." ) + elif len(pipelines) > 1: + raise ValueError( + f"Multiple pipelines found with name {pipeline_name} in project {project.name}" + ) + return pipelines[0] + + +def resolve_project_and_pipeline( + client: LlamaCloud, + pipeline_name: Optional[str], + pipeline_id: Optional[str], + project_name: Optional[str], + project_id: Optional[str], + organization_id: Optional[str], +) -> Tuple[Project, Pipeline]: + # resolve pipeline by ID + if pipeline_id is not None: + pipeline = resolve_pipeline( + client, pipeline_id=pipeline_id, project=None, pipeline_name=None ) + project_id = pipeline.project_id - # remove callback manager - configured_transformation_items[-1].component.pop("callback_manager", None) # type: ignore + # resolve project + project = resolve_project(client, project_name, project_id, organization_id) - project = client.projects.upsert_project(request=ProjectCreate(name=project_name)) - assert project.id is not None, "Project ID should not be None" + # resolve pipeline by name + if pipeline_id is None: + pipeline = resolve_pipeline( + client, pipeline_id=None, project=project, pipeline_name=pipeline_name + ) - # upload - return PipelineCreate( - name=pipeline_name, - configured_transformations=configured_transformation_items, - pipeline_type=pipeline_type, - # we are uploading document dicrectly, so we don't need llama parse - llama_parse_enabled=False, - ) + return project, pipeline diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/base.py b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/base.py index 49497582d9f78..139e164b6c478 100644 --- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/base.py +++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/base.py @@ -12,6 +12,9 @@ from urllib.parse import quote_plus from llama_cloud import ( + PipelineCreate, + PipelineCreateEmbeddingConfig, + PipelineCreateTransformConfig, PipelineType, ProjectCreate, ManagedIngestionStatus, @@ -31,10 +34,6 @@ get_aclient, get_client, ) -from llama_index.indices.managed.llama_cloud.api_utils import ( - default_transformations, - get_pipeline_create, -) from llama_index.core.schema import 
 from llama_index.core.settings import Settings
 from typing import Any, Dict, List, Optional, Sequence, Type
@@ -47,41 +46,92 @@
     Settings,
 )
 from llama_index.core.storage.docstore.types import RefDocInfo
+from llama_index.indices.managed.llama_cloud.api_utils import (
+    default_embedding_config,
+    default_transform_config,
+    resolve_project_and_pipeline,
+)
 import logging

 logger = logging.getLogger(__name__)


 class LlamaCloudIndex(BaseManagedIndex):
-    """LlamaIndex Platform Index."""
+    """
+    A managed index that stores documents in LlamaCloud.
+
+    There are two main ways to use this index:
+
+    1. Connect to an existing LlamaCloud index:
+        ```python
+        # Connect using index ID (same as pipeline ID)
+        index = LlamaCloudIndex(id="<index_id>")
+
+        # Or connect using index name
+        index = LlamaCloudIndex(
+            name="my_index",
+            project_name="my_project",
+            organization_id="my_org_id"
+        )
+        ```
+
+    2. Create a new index with documents:
+        ```python
+        documents = [Document(...), Document(...)]
+        index = LlamaCloudIndex.from_documents(
+            documents,
+            name="my_new_index",
+            project_name="my_project",
+            organization_id="my_org_id"
+        )
+        ```
+
+    The index supports standard operations like retrieval and querying
+    through the as_query_engine() and as_retriever() methods.
+    """

     def __init__(
         self,
-        name: str,
-        nodes: Optional[List[BaseNode]] = None,
-        transformations: Optional[List[TransformComponent]] = None,
-        timeout: int = 60,
+        # index identifier
+        name: Optional[str] = None,
+        pipeline_id: Optional[str] = None,
+        index_id: Optional[str] = None,  # alias for pipeline_id
+        id: Optional[str] = None,  # alias for pipeline_id
+        # project identifier
+        project_id: Optional[str] = None,
         project_name: str = DEFAULT_PROJECT_NAME,
         organization_id: Optional[str] = None,
+        # connection params
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
         app_url: Optional[str] = None,
-        show_progress: bool = False,
-        callback_manager: Optional[CallbackManager] = None,
+        timeout: int = 60,
         httpx_client: Optional[httpx.Client] = None,
         async_httpx_client: Optional[httpx.AsyncClient] = None,
+        # misc
+        show_progress: bool = False,
+        callback_manager: Optional[CallbackManager] = None,
+        # deprecated
+        nodes: Optional[List[BaseNode]] = None,
+        transformations: Optional[List[TransformComponent]] = None,
         **kwargs: Any,
     ) -> None:
         """Initialize the Platform Index."""
-        self.name = name
-        self.project_name = project_name
-        self.organization_id = organization_id
-        self.transformations = transformations or []
+        if sum([bool(id), bool(index_id), bool(pipeline_id), bool(name)]) != 1:
+            raise ValueError(
+                "Exactly one of `name`, `id`, `pipeline_id` or `index_id` must be provided to identify the index."
+            )

         if nodes is not None:
             # TODO: How to handle uploading nodes without running transforms on them?
             raise ValueError("LlamaCloudIndex does not support nodes on initialization")

+        if transformations is not None:
+            raise ValueError(
+                "Setting transformations is deprecated for LlamaCloudIndex, please use the `transform_config` and `embedding_config` parameters instead."
+            )
+
+        # initialize clients
         self._httpx_client = httpx_client
         self._async_httpx_client = async_httpx_client
         self._client = get_client(api_key, base_url, app_url, timeout, httpx_client)
@@ -89,6 +139,15 @@ def __init__(
             api_key, base_url, app_url, timeout, async_httpx_client
         )

+        self.organization_id = organization_id
+        pipeline_id = id or index_id or pipeline_id
+
+        self.project, self.pipeline = resolve_project_and_pipeline(
+            self._client, name, pipeline_id, project_name, project_id, organization_id
+        )
+        self.name = self.pipeline.name
+        self.project_name = self.project.name
+
         self._api_key = api_key
         self._base_url = base_url
         self._app_url = app_url
@@ -97,27 +156,24 @@ def __init__(
         self._service_context = None
         self._callback_manager = callback_manager or Settings.callback_manager

-    def _wait_for_pipeline_ingestion(
+    def wait_for_completion(
         self,
         verbose: bool = False,
         raise_on_partial_success: bool = False,
     ) -> None:
-        pipeline_id = self._get_pipeline_id()
-        client = self._client
-
         if verbose:
             print("Syncing pipeline: ", end="")

         is_done = False
         while not is_done:
-            status = client.pipelines.get_pipeline_status(
-                pipeline_id=pipeline_id
+            status = self._client.pipelines.get_pipeline_status(
+                pipeline_id=self.pipeline.id
             ).status
             if status == ManagedIngestionStatus.ERROR or (
                 raise_on_partial_success
                 and status == ManagedIngestionStatus.PARTIAL_SUCCESS
             ):
-                raise ValueError(f"Pipeline ingestion failed for {pipeline_id}")
+                raise ValueError(f"Pipeline ingestion failed for {self.pipeline.id}")
             elif status in [
                 ManagedIngestionStatus.NOT_STARTED,
                 ManagedIngestionStatus.IN_PROGRESS,
@@ -136,16 +192,14 @@ def _wait_for_file_ingestion(
         verbose: bool = False,
         raise_on_error: bool = False,
     ) -> None:
-        pipeline_id = self._get_pipeline_id()
-        client = self._client
         if verbose:
             print("Loading file: ", end="")

         # wait until the file is loaded
         is_done = False
         while not is_done:
-            status = client.pipelines.get_pipeline_file_status(
-                pipeline_id=pipeline_id, file_id=file_id
+            status = self._client.pipelines.get_pipeline_file_status(
+                pipeline_id=self.pipeline.id, file_id=file_id
             ).status
             if status == ManagedIngestionStatus.ERROR:
                 if verbose:
@@ -170,8 +224,6 @@ def _wait_for_documents_ingestion(
         verbose: bool = False,
         raise_on_error: bool = False,
     ) -> None:
-        pipeline_id = self._get_pipeline_id()
-        client = self._client
         if verbose:
             print("Loading data: ", end="")

@@ -181,8 +233,9 @@
             docs_to_remove = set()
             for doc in pending_docs:
                 # we have to quote the doc id twice because it is used as a path parameter
-                status = client.pipelines.get_pipeline_document_status(
-                    pipeline_id=pipeline_id, document_id=quote_plus(quote_plus(doc))
+                status = self._client.pipelines.get_pipeline_document_status(
+                    pipeline_id=self.pipeline.id,
+                    document_id=quote_plus(quote_plus(doc)),
                 )
                 if status in [
                     ManagedIngestionStatus.NOT_STARTED,
@@ -210,60 +263,13 @@
         # we have to wait for pipeline ingestion because retrieval only works when
         # the pipeline status is success
-        self._wait_for_pipeline_ingestion(verbose, raise_on_error)
-
-    def _get_project_id(self) -> str:
-        projects = self._client.projects.list_projects(
-            organization_id=self.organization_id,
-            project_name=self.project_name,
-        )
-        if len(projects) == 0:
-            raise ValueError(
-                f"Unknown project name {self.project_name}. Please confirm a "
-                "managed project with this name exists."
-            )
-        elif len(projects) > 1:
-            raise ValueError(
-                f"Multiple projects found with name {self.project_name}. "
Please specify organization_id." - ) - project = projects[0] - - if project.id is None: - raise ValueError(f"No project found with name {self.project_name}") - - return project.id - - def _get_pipeline_id(self) -> str: - project_id = self._get_project_id() - pipelines = self._client.pipelines.search_pipelines( - project_id=project_id, - pipeline_name=self.name, - pipeline_type=PipelineType.MANAGED.value, - ) - if len(pipelines) == 0: - raise ValueError( - f"Unknown index name {self.name}. Please confirm a " - "managed index with this name exists." - ) - elif len(pipelines) > 1: - raise ValueError( - f"Multiple pipelines found with name {self.name} in project {self.project_name}" - ) - pipeline = pipelines[0] - - if pipeline.id is None: - raise ValueError( - f"No pipeline found with name {self.name} in project {self.project_name}" - ) - - return pipeline.id + self.wait_for_completion(verbose, raise_on_error) @classmethod def from_documents( # type: ignore cls: Type["LlamaCloudIndex"], documents: List[Document], name: str, - transformations: Optional[List[TransformComponent]] = None, project_name: str = DEFAULT_PROJECT_NAME, organization_id: Optional[str] = None, api_key: Optional[str] = None, @@ -272,21 +278,23 @@ def from_documents( # type: ignore timeout: int = 60, verbose: bool = False, raise_on_error: bool = False, + # ingestion configs + embedding_config: Optional[PipelineCreateEmbeddingConfig] = None, + transform_config: Optional[PipelineCreateTransformConfig] = None, + # deprecated + transformations: Optional[List[TransformComponent]] = None, **kwargs: Any, ) -> "LlamaCloudIndex": """Build a LlamaCloud managed index from a sequence of documents.""" app_url = app_url or os.environ.get("LLAMA_CLOUD_APP_URL", DEFAULT_APP_URL) client = get_client(api_key, base_url, app_url, timeout) - pipeline_create = get_pipeline_create( - name, - client, - PipelineType.MANAGED, - project_name=project_name, - transformations=transformations or default_transformations(), - input_nodes=documents, - ) + if transformations is not None: + raise ValueError( + "Setting transformations is deprecated for LlamaCloudIndex" + ) + # create project if it doesn't exist project = client.projects.upsert_project( organization_id=organization_id, request=ProjectCreate(name=project_name) ) @@ -295,6 +303,15 @@ def from_documents( # type: ignore if verbose: print(f"Created project {project.id} with name {project.name}") + # create pipeline + pipeline_create = PipelineCreate( + name=name, + pipeline_type=PipelineType.MANAGED, + embedding_config=embedding_config or default_embedding_config(), + transform_config=transform_config or default_transform_config(), + # we are uploading document directly, so we don't need llama parse + llama_parse_enabled=False, + ) pipeline = client.pipelines.upsert_pipeline( project_id=project.id, request=pipeline_create ) @@ -305,8 +322,7 @@ def from_documents( # type: ignore index = cls( name, - transformations=transformations, - project_name=project_name, + project_name=project.name, organization_id=project.organization_id, api_key=api_key, base_url=base_url, @@ -329,6 +345,7 @@ def from_documents( # type: ignore for doc in documents ], ) + doc_ids = [doc.id for doc in upserted_documents] index._wait_for_documents_ingestion( doc_ids, verbose=verbose, raise_on_error=raise_on_error @@ -350,9 +367,9 @@ def as_retriever(self, **kwargs: Any) -> BaseRetriever: dense_similarity_top_k = similarity_top_k return LlamaCloudRetriever( - self.name, - project_name=self.project_name, - api_key=self._api_key, 
+            project_id=self.project.id,
+            pipeline_id=self.pipeline.id,
+            api_key=self._api_key,
             base_url=self._base_url,
             app_url=self._app_url,
             timeout=self._timeout,
@@ -374,7 +391,7 @@ def as_query_engine(self, **kwargs: Any) -> BaseQueryEngine:
     @property
     def ref_doc_info(self, batch_size: int = 100) -> Dict[str, RefDocInfo]:
         """Retrieve a dict mapping of ingested documents and their metadata. The nodes list is empty."""
-        pipeline_id = self._get_pipeline_id()
+        pipeline_id = self.pipeline.id
         pipeline_documents: List[CloudDocument] = []
         skip = 0
         limit = batch_size
@@ -398,9 +415,8 @@ def insert(
     ) -> None:
         """Insert a document."""
         with self._callback_manager.as_trace("insert"):
-            pipeline_id = self._get_pipeline_id()
             upserted_documents = self._client.pipelines.create_batch_pipeline_documents(
-                pipeline_id=pipeline_id,
+                pipeline_id=self.pipeline.id,
                 request=[
                     CloudDocumentCreate(
                         text=document.text,
@@ -421,9 +437,8 @@ def update_ref_doc(
     ) -> None:
         """Upserts a document and its corresponding nodes."""
         with self._callback_manager.as_trace("update"):
-            pipeline_id = self._get_pipeline_id()
             upserted_documents = self._client.pipelines.upsert_batch_pipeline_documents(
-                pipeline_id=pipeline_id,
+                pipeline_id=self.pipeline.id,
                 request=[
                     CloudDocumentCreate(
                         text=document.text,
@@ -444,9 +459,8 @@ def refresh_ref_docs(
     ) -> List[bool]:
         """Refresh an index with documents that have changed."""
         with self._callback_manager.as_trace("refresh"):
-            pipeline_id = self._get_pipeline_id()
             upserted_documents = self._client.pipelines.upsert_batch_pipeline_documents(
-                pipeline_id=pipeline_id,
+                pipeline_id=self.pipeline.id,
                 request=[
                     CloudDocumentCreate(
                         text=doc.text,
@@ -473,11 +487,11 @@ def delete_ref_doc(
         **delete_kwargs: Any,
     ) -> None:
         """Delete a document and its nodes by using ref_doc_id."""
-        pipeline_id = self._get_pipeline_id()
         try:
             # we have to quote the ref_doc_id twice because it is used as a path parameter
             self._client.pipelines.delete_pipeline_document(
-                pipeline_id=pipeline_id, document_id=quote_plus(quote_plus(ref_doc_id))
+                pipeline_id=self.pipeline.id,
+                document_id=quote_plus(quote_plus(ref_doc_id)),
             )
         except ApiError as e:
             if e.status_code == 404 and not raise_if_not_found:
@@ -486,9 +500,7 @@
                 raise

         # we have to wait for the pipeline instead of the document, because the document is already deleted
-        self._wait_for_pipeline_ingestion(
-            verbose=verbose, raise_on_partial_success=False
-        )
+        self.wait_for_completion(verbose=verbose, raise_on_partial_success=False)

     def upload_file(
         self,
@@ -501,17 +513,16 @@
         """Upload a file to the index."""
         with open(file_path, "rb") as f:
             file = self._client.files.upload_file(
-                project_id=self._get_project_id(), upload_file=f
+                project_id=self.project.id, upload_file=f
             )
             if verbose:
                 print(f"Uploaded file {file.id} with name {file.name}")
             if resource_info:
                 self._client.files.update(file_id=file.id, request=resource_info)

         # Add file to pipeline
-        pipeline_id = self._get_pipeline_id()
         pipeline_file_create = PipelineFileCreate(file_id=file.id)
         self._client.pipelines.add_files_to_pipeline(
-            pipeline_id=pipeline_id, request=[pipeline_file_create]
+            pipeline_id=self.pipeline.id, request=[pipeline_file_create]
         )

         if wait_for_ingestion:
@@ -534,7 +545,7 @@ def upload_file_from_url(
     ) -> str:
         """Upload a file from a URL to the index."""
         file = self._client.files.upload_file_from_url(
-            project_id=self._get_project_id(),
+            project_id=self.project.id,
             name=file_name,
             url=url,
             proxy_url=proxy_url,
@@ -544,11 +555,11 @@ def upload_file_from_url(
         )
         if verbose:
             print(f"Uploaded file {file.id} with ID {file.id}")
+
         # Add file to pipeline
-        pipeline_id = self._get_pipeline_id()
         pipeline_file_create = PipelineFileCreate(file_id=file.id)
         self._client.pipelines.add_files_to_pipeline(
-            pipeline_id=pipeline_id, request=[pipeline_file_create]
+            pipeline_id=self.pipeline.id, request=[pipeline_file_create]
         )

         if wait_for_ingestion:
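For reviewers, a minimal usage sketch of the reworked `base.py` surface above (index and project names, the organization ID, and the API key are placeholders, not values from this patch):

```python
# Sketch only: assumes LLAMA_CLOUD_API_KEY is configured; all names/IDs are placeholders.
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex

# Connect to an existing index by ID (same as pipeline ID)...
index = LlamaCloudIndex(id="<index_id>")

# ...or by name, scoped to a project and organization.
index = LlamaCloudIndex(
    name="my_index",
    project_name="my_project",
    organization_id="<org_id>",
)

# Block until ingestion reaches a terminal state, then query.
index.wait_for_completion(verbose=True)
nodes = index.as_retriever().retrieve("What does this index contain?")
response = index.as_query_engine().query("Summarize the indexed documents.")
```

Note that `wait_for_completion` is the renamed, now-public form of the old `_wait_for_pipeline_ingestion`, so callers no longer need to reach into a private helper.
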
diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py
index b2efc73ded101..164680403ead7 100644
--- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py
+++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py
@@ -1,7 +1,11 @@
 from typing import Any, List, Optional

-from llama_cloud import TextNodeWithScore, PageScreenshotNodeWithScore
-from llama_cloud.resources.pipelines.client import OMIT, PipelineType
+import httpx
+from llama_cloud import (
+    TextNodeWithScore,
+    PageScreenshotNodeWithScore,
+)
+from llama_cloud.resources.pipelines.client import OMIT
 from llama_cloud.client import LlamaCloud, AsyncLlamaCloud
 from llama_cloud.core import remove_none_from_dict
 from llama_cloud.core.api_error import ApiError
@@ -14,6 +18,9 @@
 import asyncio
 import urllib.parse
 import base64
+from llama_index.indices.managed.llama_cloud.api_utils import (
+    resolve_project_and_pipeline,
+)


 def _get_page_screenshot(
@@ -21,7 +28,7 @@
 ) -> str:
     """Get the page screenshot."""
     # TODO: this currently uses requests, should be replaced with the client
     _response = client._client_wrapper.httpx_client.request(
         "GET",
         urllib.parse.urljoin(
             f"{client._client_wrapper.get_base_url()}/",
@@ -61,38 +68,56 @@ async def _aget_page_screenshot(
 class LlamaCloudRetriever(BaseRetriever):
     def __init__(
         self,
-        name: str,
-        project_name: str = DEFAULT_PROJECT_NAME,
+        # index identifier
+        name: Optional[str] = None,
+        index_id: Optional[str] = None,  # alias for pipeline_id
+        id: Optional[str] = None,  # alias for pipeline_id
+        pipeline_id: Optional[str] = None,
+        # project identifier
+        project_name: Optional[str] = DEFAULT_PROJECT_NAME,
+        project_id: Optional[str] = None,
         organization_id: Optional[str] = None,
+        # connection params
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        app_url: Optional[str] = None,
+        timeout: int = 60,
+        httpx_client: Optional[httpx.Client] = None,
+        async_httpx_client: Optional[httpx.AsyncClient] = None,
+        # retrieval params
         dense_similarity_top_k: Optional[int] = None,
         sparse_similarity_top_k: Optional[int] = None,
         enable_reranking: Optional[bool] = None,
         rerank_top_n: Optional[int] = None,
         alpha: Optional[float] = None,
         filters: Optional[MetadataFilters] = None,
-        api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
-        app_url: Optional[str] = None,
-        timeout: int = 60,
         retrieval_mode: Optional[str] = None,
         files_top_k: Optional[int] = None,
         retrieve_image_nodes: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         """Initialize the Platform Retriever."""
-        self.name = name
-        self.project_name = project_name
-        self._client = get_client(api_key, base_url, app_url, timeout)
-        self._aclient = get_aclient(api_key, base_url, app_url, timeout)
+        if sum([bool(id), bool(index_id), bool(pipeline_id), bool(name)]) != 1:
+            raise ValueError(
+                "Exactly one of `name`, `id`, `pipeline_id` or `index_id` must be provided to identify the index."
+            )

-        projects = self._client.projects.list_projects(
-            project_name=project_name, organization_id=organization_id
+        # initialize clients
+        self._httpx_client = httpx_client
+        self._async_httpx_client = async_httpx_client
+        self._client = get_client(api_key, base_url, app_url, timeout, httpx_client)
+        self._aclient = get_aclient(
+            api_key, base_url, app_url, timeout, async_httpx_client
         )
-        if len(projects) == 0:
-            raise ValueError(f"No project found with name {project_name}")
-        self.project_id = projects[0].id

+        pipeline_id = id or index_id or pipeline_id
+        self.project, self.pipeline = resolve_project_and_pipeline(
+            self._client, name, pipeline_id, project_name, project_id, organization_id
+        )
+        self.name = self.pipeline.name
+        self.project_name = self.project.name
+
+        # retrieval params
         self._dense_similarity_top_k = (
             dense_similarity_top_k if dense_similarity_top_k is not None else OMIT
         )
@@ -184,31 +209,9 @@ async def _aimage_nodes_to_node_with_score(
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Retrieve from the platform."""
-        pipelines = self._client.pipelines.search_pipelines(
-            project_name=self.project_name,
-            project_id=self.project_id,
-            pipeline_name=self.name,
-            pipeline_type=PipelineType.MANAGED.value,
-        )
-        if len(pipelines) == 0:
-            raise ValueError(
-                f"Unknown index name {self.name}. Please confirm a "
-                "managed index with this name exists."
-            )
-        elif len(pipelines) > 1:
-            raise ValueError(
-                f"Multiple pipelines found with name {self.name} in project {self.project_name}"
-            )
-        pipeline = pipelines[0]
-
-        if pipeline.id is None:
-            raise ValueError(
-                f"No pipeline found with name {self.name} in project {self.project_name}"
-            )
-
         results = self._client.pipelines.run_search(
             query=query_bundle.query_str,
-            pipeline_id=pipeline.id,
+            pipeline_id=self.pipeline.id,
             dense_similarity_top_k=self._dense_similarity_top_k,
             sparse_similarity_top_k=self._sparse_similarity_top_k,
             enable_reranking=self._enable_reranking,
@@ -227,31 +230,9 @@ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
     async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Asynchronously retrieve from the platform."""
-        pipelines = await self._aclient.pipelines.search_pipelines(
-            project_name=self.project_name,
-            pipeline_name=self.name,
-            pipeline_type=PipelineType.MANAGED.value,
-            project_id=self.project_id,
-        )
-        if len(pipelines) == 0:
-            raise ValueError(
-                f"Unknown index name {self.name}. Please confirm a "
-                "managed index with this name exists."
-            )
-        elif len(pipelines) > 1:
-            raise ValueError(
-                f"Multiple pipelines found with name {self.name} in project {self.project_name}"
-            )
-        pipeline = pipelines[0]
-
-        if pipeline.id is None:
-            raise ValueError(
-                f"No pipeline found with name {self.name} in project {self.project_name}"
-            )
-
         results = await self._aclient.pipelines.run_search(
             query=query_bundle.query_str,
-            pipeline_id=pipeline.id,
+            pipeline_id=self.pipeline.id,
             dense_similarity_top_k=self._dense_similarity_top_k,
             sparse_similarity_top_k=self._sparse_similarity_top_k,
             enable_reranking=self._enable_reranking,
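The retriever can also be constructed standalone. A sketch under the same placeholder assumptions, exercising the constructor-time retrieval parameters, which are now forwarded straight to `pipelines.run_search` instead of triggering a pipeline lookup on every query (this assumes the package exports `LlamaCloudRetriever` alongside `LlamaCloudIndex`):

```python
# Sketch only: placeholder IDs and key; retrieval params are resolved once
# at construction time by resolve_project_and_pipeline.
from llama_index.indices.managed.llama_cloud import LlamaCloudRetriever

retriever = LlamaCloudRetriever(
    pipeline_id="<pipeline_id>",  # or name= / index_id= / id=
    api_key="<llama_cloud_api_key>",
    dense_similarity_top_k=5,
    sparse_similarity_top_k=5,
    enable_reranking=True,
    rerank_top_n=3,
)
nodes = retriever.retrieve("What changed in this release?")
```
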
diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml
index 18f6bc81df588..3719ed5e79195 100644
--- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml
+++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml
@@ -34,7 +34,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-indices-managed-llama-cloud"
 readme = "README.md"
-version = "0.4.2"
+version = "0.5.0"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/tests/test_indices_managed_llama_cloud.py b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/tests/test_indices_managed_llama_cloud.py
index d1f3887e59e1a..7b3c5d958b4d7 100644
--- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/tests/test_indices_managed_llama_cloud.py
+++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/tests/test_indices_managed_llama_cloud.py
@@ -1,16 +1,86 @@
-from typing import Optional
-import tempfile
-from llama_index.core.indices.managed.base import BaseManagedIndex
+from typing import Tuple
+from llama_cloud import (
+    AutoTransformConfig,
+    Pipeline,
+    PipelineCreate,
+    PipelineFileCreate,
+    ProjectCreate,
+)
 from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
+from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.core.schema import Document
 import os
 import pytest
 from uuid import uuid4
+from llama_cloud.client import LlamaCloud
+from llama_index.core.indices.managed.base import BaseManagedIndex
+import tempfile

 base_url = os.environ.get("LLAMA_CLOUD_BASE_URL", None)
 api_key = os.environ.get("LLAMA_CLOUD_API_KEY", None)
 openai_api_key = os.environ.get("OPENAI_API_KEY", None)
 organization_id = os.environ.get("LLAMA_CLOUD_ORGANIZATION_ID", None)
+project_name = os.environ.get("LLAMA_CLOUD_PROJECT_NAME", "framework_integration_test")
+
+
+@pytest.fixture()
+def remote_file() -> Tuple[str, str]:
+    test_file_url = "https://www.google.com/robots.txt"
+    test_file_name = "google_robots.txt"
+    return test_file_url, test_file_name
+
+
+def _setup_empty_index(
+    client: LlamaCloud,
+) -> Pipeline:
+    # create project if it doesn't exist
+    project_create = ProjectCreate(name=project_name)
+    project = client.projects.upsert_project(
+        organization_id=organization_id, request=project_create
+    )
+
+    # create pipeline
+    pipeline_create = PipelineCreate(
+        name="test_empty_index_" + str(uuid4()),
+        embedding_config={"type": "OPENAI_EMBEDDING", "component": OpenAIEmbedding()},
+        transform_config=AutoTransformConfig(),
+    )
+    return client.pipelines.upsert_pipeline(
+        project_id=project.id, request=pipeline_create
+    )
+
+
+def _setup_index_with_file(
+    client: LlamaCloud, remote_file: Tuple[str, str]
+) -> Pipeline:
+    # create project if it doesn't exist
+    project_create = ProjectCreate(name=project_name)
+    project = client.projects.upsert_project(
+        organization_id=organization_id, request=project_create
+    )
+
+    # create pipeline
+    pipeline_create = PipelineCreate(
+        name="test_index_with_file_" + str(uuid4()),
+        embedding_config={"type": "OPENAI_EMBEDDING", "component": OpenAIEmbedding()},
+        transform_config=AutoTransformConfig(),
+    )
+    pipeline = client.pipelines.upsert_pipeline(
+        project_id=project.id, request=pipeline_create
+    )
+
+    # upload file to pipeline
+    test_file_url, test_file_name = remote_file
+    file = client.files.upload_file_from_url(
+        project_id=project.id, url=test_file_url, name=test_file_name
+    )
+
+    # add file to pipeline
+    pipeline_file_create = PipelineFileCreate(file_id=file.id)
+    client.pipelines.add_files_to_pipeline(
+        pipeline_id=pipeline.id, request=[pipeline_file_create]
+    )
+
+    return pipeline


 def test_class():
@@ -18,90 +88,59 @@ def test_class():
     assert BaseManagedIndex.__name__ in names_of_base_classes


+def test_conflicting_index_identifiers():
+    with pytest.raises(ValueError):
+        LlamaCloudIndex(name="test", pipeline_id="test", index_id="test")
+
+
 @pytest.mark.skipif(
     not base_url or not api_key, reason="No platform base url or api key set"
 )
 @pytest.mark.skipif(not openai_api_key, reason="No openai api key set")
 @pytest.mark.integration()
-def test_retrieve():
-    os.environ["OPENAI_API_KEY"] = openai_api_key
+def test_resolve_index_with_id(remote_file):
+    """Test that we can instantiate an index with a given id."""
+    client = LlamaCloud(token=api_key, base_url=base_url)
+    pipeline = _setup_index_with_file(client, remote_file)
+
     index = LlamaCloudIndex(
-        name="test",  # assumes this pipeline exists
-        project_name="Default",
+        pipeline_id=pipeline.id,
         api_key=api_key,
         base_url=base_url,
     )
-    query = "test"
-    nodes = index.as_retriever().retrieve(query)
-    assert nodes is not None and len(nodes) > 0
+    assert index is not None

-    response = index.as_query_engine().query(query)
-    assert response is not None and len(response.response) > 0
+    index.wait_for_completion()
+    retriever = index.as_retriever()
+
+    nodes = retriever.retrieve("Hello world.")
+    assert len(nodes) > 0


-@pytest.mark.parametrize("organization_id", [None, organization_id])
 @pytest.mark.skipif(
     not base_url or not api_key, reason="No platform base url or api key set"
 )
 @pytest.mark.skipif(not openai_api_key, reason="No openai api key set")
 @pytest.mark.integration()
-def test_documents_crud(organization_id: Optional[str]):
-    os.environ["OPENAI_API_KEY"] = openai_api_key
-    documents = [
-        Document(text="Hello world.", doc_id="1", metadata={"source": "test"}),
-    ]
-    index = LlamaCloudIndex.from_documents(
-        documents=documents,
-        name=f"test pipeline {uuid4()}",
+def test_resolve_index_with_name(remote_file):
+    """Test that we can instantiate an index with a given name."""
+    client = LlamaCloud(token=api_key, base_url=base_url)
+    pipeline = _setup_index_with_file(client, remote_file)
+
+    index = LlamaCloudIndex(
+        name=pipeline.name,
+        project_name=project_name,
+        organization_id=organization_id,
         api_key=api_key,
         base_url=base_url,
-        organization_id=organization_id,
-        verbose=True,
     )
-    docs = index.ref_doc_info
-    assert len(docs) == 1
-    assert docs["1"].metadata["source"] == "test"
-    nodes = index.as_retriever().retrieve("Hello world.")
-    assert len(nodes) > 0
-    assert all(n.node.ref_doc_id == "1" for n in nodes)
-    assert all(n.node.metadata["source"] == "test" for n in nodes)
-
-    index.insert(
-        Document(text="Hello world.", doc_id="2", metadata={"source": "inserted"}),
-        verbose=True,
-    )
-    docs = index.ref_doc_info
-    assert len(docs) == 2
-    assert docs["2"].metadata["source"] == "inserted"
-    nodes = index.as_retriever().retrieve("Hello world.")
-    assert len(nodes) > 0
-    assert all(n.node.ref_doc_id in ["1", "2"] for n in nodes)
-    assert any(n.node.ref_doc_id == "1" for n in nodes)
-    assert any(n.node.ref_doc_id == "2" for n in nodes)
+    assert index is not None

-    index.update_ref_doc(
-        Document(text="Hello world.", doc_id="2", metadata={"source": "updated"}),
-        verbose=True,
-    )
-    docs = index.ref_doc_info
-    assert len(docs) == 2
-    assert docs["2"].metadata["source"] == "updated"
+    index.wait_for_completion()
+    retriever = index.as_retriever()

-    index.refresh_ref_docs(
-        [
-            Document(text="Hello world.", doc_id="1", metadata={"source": "refreshed"}),
-            Document(text="Hello world.", doc_id="3", metadata={"source": "refreshed"}),
-        ]
-    )
-    docs = index.ref_doc_info
-    assert len(docs) == 3
-    assert docs["3"].metadata["source"] == "refreshed"
-    assert docs["1"].metadata["source"] == "refreshed"
+    nodes = retriever.retrieve("Hello world.")
+    assert len(nodes) > 0

-    index.delete_ref_doc("3", verbose=True)
-    docs = index.ref_doc_info
-    assert len(docs) == 2
-    assert "3" not in docs


 @pytest.mark.skipif(
@@ -110,10 +149,12 @@
 )
 @pytest.mark.skipif(not openai_api_key, reason="No openai api key set")
 @pytest.mark.integration()
 def test_upload_file():
-    os.environ["OPENAI_API_KEY"] = openai_api_key
+    pipeline = _setup_empty_index(LlamaCloud(token=api_key, base_url=base_url))
+
     index = LlamaCloudIndex(
-        name="test",  # assumes this pipeline exists
-        project_name="Default",
+        name=pipeline.name,
+        project_name=project_name,
+        organization_id=organization_id,
         api_key=api_key,
         base_url=base_url,
     )
@@ -145,18 +186,19 @@
 )
 @pytest.mark.skipif(not openai_api_key, reason="No openai api key set")
 @pytest.mark.integration()
-def test_upload_file_from_url():
-    os.environ["OPENAI_API_KEY"] = openai_api_key
+def test_upload_file_from_url(remote_file):
+    pipeline = _setup_empty_index(LlamaCloud(token=api_key, base_url=base_url))
+
     index = LlamaCloudIndex(
-        name="test",  # assumes this pipeline exists
-        project_name="Default",
+        name=pipeline.name,
+        project_name=project_name,
+        organization_id=organization_id,
         api_key=api_key,
         base_url=base_url,
     )

     # Define a URL to a file for testing
-    test_file_url = "https://www.google.com/robots.txt"
-    test_file_name = "google_robots.txt"
+    test_file_url, test_file_name = remote_file

     # Upload the file from the URL
     file_id = index.upload_file_from_url(
@@ -167,3 +209,67 @@
     # Verify the file is part of the index
     docs = index.ref_doc_info
     assert any(test_file_name == doc.metadata.get("file_name") for doc in docs.values())
+
+
+@pytest.mark.skipif(
+    not base_url or not api_key, reason="No platform base url or api key set"
+)
+@pytest.mark.skipif(not openai_api_key, reason="No openai api key set")
+@pytest.mark.integration()
+def test_index_from_documents():
+    documents = [
+        Document(text="Hello world.", doc_id="1", metadata={"source": "test"}),
+    ]
+    index = LlamaCloudIndex.from_documents(
+        documents=documents,
+        name=f"test pipeline {uuid4()}",
+        project_name=project_name,
+        api_key=api_key,
+        base_url=base_url,
+        organization_id=organization_id,
+        verbose=True,
+    )
+    docs = index.ref_doc_info
+    assert len(docs) == 1
+    assert docs["1"].metadata["source"] == "test"
docs["1"].metadata["source"] == "test" + nodes = index.as_retriever().retrieve("Hello world.") + assert len(nodes) > 0 + assert all(n.node.ref_doc_id == "1" for n in nodes) + assert all(n.node.metadata["source"] == "test" for n in nodes) + + index.insert( + Document(text="Hello world.", doc_id="2", metadata={"source": "inserted"}), + verbose=True, + ) + docs = index.ref_doc_info + assert len(docs) == 2 + assert docs["2"].metadata["source"] == "inserted" + nodes = index.as_retriever().retrieve("Hello world.") + assert len(nodes) > 0 + assert all(n.node.ref_doc_id in ["1", "2"] for n in nodes) + assert any(n.node.ref_doc_id == "1" for n in nodes) + assert any(n.node.ref_doc_id == "2" for n in nodes) + + index.update_ref_doc( + Document(text="Hello world.", doc_id="2", metadata={"source": "updated"}), + verbose=True, + ) + docs = index.ref_doc_info + assert len(docs) == 2 + assert docs["2"].metadata["source"] == "updated" + + index.refresh_ref_docs( + [ + Document(text="Hello world.", doc_id="1", metadata={"source": "refreshed"}), + Document(text="Hello world.", doc_id="3", metadata={"source": "refreshed"}), + ] + ) + docs = index.ref_doc_info + assert len(docs) == 3 + assert docs["3"].metadata["source"] == "refreshed" + assert docs["1"].metadata["source"] == "refreshed" + + index.delete_ref_doc("3", verbose=True) + docs = index.ref_doc_info + assert len(docs) == 2 + assert "3" not in docs