diff --git a/nesis/api/core/document_loaders/loader_helper.py b/nesis/api/core/document_loaders/loader_helper.py index bfc3127..c15ef54 100644 --- a/nesis/api/core/document_loaders/loader_helper.py +++ b/nesis/api/core/document_loaders/loader_helper.py @@ -1,20 +1,31 @@ -import uuid +import datetime import json -from nesis.api.core.models.entities import Document -from nesis.api.core.util.dateutil import strptime import logging -from nesis.api.core.services.util import get_document, delete_document +import uuid +from typing import Optional, Dict, Any, Callable + +import nesis.api.core.util.http as http +from nesis.api.core.document_loaders.runners import ( + IngestRunner, + ExtractRunner, + RagRunner, +) +from nesis.api.core.models.entities import Document, Datasource +from nesis.api.core.services.util import delete_document +from nesis.api.core.services.util import ( + get_document, +) +from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT +from nesis.api.core.util.dateutil import strptime _LOG = logging.getLogger(__name__) def upload_document_to_llm(upload_document, file_metadata, rag_endpoint, http_client): - return _upload_document_to_pgpt( - upload_document, file_metadata, rag_endpoint, http_client - ) + return _upload_document(upload_document, file_metadata, rag_endpoint, http_client) -def _upload_document_to_pgpt(upload_document, file_metadata, rag_endpoint, http_client): +def _upload_document(upload_document, file_metadata, rag_endpoint, http_client): document_id = file_metadata["unique_id"] file_name = file_metadata["name"] @@ -51,3 +62,101 @@ def _upload_document_to_pgpt(upload_document, file_metadata, rag_endpoint, http_ request_object = {"file_name": file_name, "text": upload_document.page_content} response = http_client.post(url=f"{rag_endpoint}/v1/ingest", payload=request_object) return json.loads(response) + + +class DocumentProcessor(object): + def __init__( + self, + config, + http_client: http.HttpClient, + datasource: Datasource, + ): + self._datasource = datasource + + # This is left package public for testing + self._extract_runner: ExtractRunner = Optional[None] + _ingest_runner = IngestRunner(config=config, http_client=http_client) + + if self._datasource.connection.get("destination") is not None: + self._extract_runner = ExtractRunner( + config=config, + http_client=http_client, + destination=self._datasource.connection.get("destination"), + ) + + self._mode = self._datasource.connection.get("mode") or "ingest" + + match self._mode: + case "ingest": + self._ingest_runners: list[RagRunner] = [_ingest_runner] + case "extract": + self._ingest_runners: list[RagRunner] = [self._extract_runner] + case _: + raise ValueError( + f"Invalid mode {self._mode}. Expected 'ingest' or 'extract'" + ) + + def sync( + self, + endpoint: str, + file_path: str, + last_modified: datetime.datetime, + metadata: Dict[str, Any], + store_metadata: Dict[str, Any], + ) -> None: + """ + Here we check if this file has been updated. + If the file has been updated, we delete it from the vector store and re-ingest the new updated file + """ + document_id = str( + uuid.uuid5( + uuid.NAMESPACE_DNS, f"{self._datasource.uuid}:{metadata['self_link']}" + ) + ) + document: Document = get_document(document_id=document_id) + for _ingest_runner in self._ingest_runners: + try: + response_json = _ingest_runner.run( + file_path=file_path, + metadata=metadata, + document_id=None if document is None else document.uuid, + last_modified=last_modified.replace(tzinfo=None).replace( + microsecond=0 + ), + datasource=self._datasource, + ) + except ValueError: + _LOG.warning(f"File {file_path} ingestion failed", exc_info=True) + response_json = None + except UserWarning: + _LOG.warning(f"File {file_path} is already processing") + continue + + if response_json is None: + _LOG.warning("No response from ingest runner received") + continue + + _ingest_runner.save( + document_id=document_id, + datasource_id=self._datasource.uuid, + filename=store_metadata["filename"], + base_uri=endpoint, + rag_metadata=response_json, + store_metadata=store_metadata, + last_modified=last_modified, + ) + + def unsync(self, clean: Callable) -> None: + endpoint = self._datasource.connection.get("endpoint") + + for _ingest_runner in self._ingest_runners: + documents = _ingest_runner.get(base_uri=endpoint) + for document in documents: + store_metadata = document.store_metadata + try: + rag_metadata = document.rag_metadata + except AttributeError: + rag_metadata = document.extract_metadata + + if clean(store_metadata=store_metadata): + _ingest_runner.delete(document=document, rag_metadata=rag_metadata) diff --git a/nesis/api/core/document_loaders/minio.py b/nesis/api/core/document_loaders/minio.py index 30b1b43..aee6318 100644 --- a/nesis/api/core/document_loaders/minio.py +++ b/nesis/api/core/document_loaders/minio.py @@ -1,39 +1,26 @@ -import concurrent -import concurrent.futures import logging -import multiprocessing +import logging import os -import queue import tempfile -from typing import Dict, Any, Optional +from typing import Dict, Any import memcache -import minio from minio import Minio import nesis.api.core.util.http as http -from nesis.api.core.document_loaders.runners import ( - IngestRunner, - ExtractRunner, - RagRunner, -) -from nesis.api.core.models.entities import Document, Datasource -from nesis.api.core.services.util import ( - get_document, - get_documents, -) +from nesis.api.core.document_loaders.loader_helper import DocumentProcessor +from nesis.api.core.models.entities import Datasource from nesis.api.core.util import clean_control, isblank from nesis.api.core.util.concurrency import ( IOBoundPool, as_completed, - BlockingThreadPoolExecutor, ) from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT _LOG = logging.getLogger(__name__) -class MinioProcessor(object): +class MinioProcessor(DocumentProcessor): def __init__( self, config, @@ -41,36 +28,12 @@ def __init__( cache_client: memcache.Client, datasource: Datasource, ): + super().__init__(config, http_client, datasource) self._config = config self._http_client = http_client self._cache_client = cache_client self._datasource = datasource - # This is left public for testing - self._extract_runner: ExtractRunner = Optional[None] - _ingest_runner = IngestRunner(config=config, http_client=http_client) - if self._datasource.connection.get("destination") is not None: - self._extract_runner = ExtractRunner( - config=config, - http_client=http_client, - destination=self._datasource.connection.get("destination"), - ) - self._ingest_runners = [] - - self._ingest_runners = [IngestRunner(config=config, http_client=http_client)] - - self._mode = self._datasource.connection.get("mode") or "ingest" - - match self._mode: - case "ingest": - self._ingest_runners: list[RagRunner] = [_ingest_runner] - case "extract": - self._ingest_runners: list[RagRunner] = [self._extract_runner] - case _: - raise ValueError( - f"Invalid mode {self._mode}. Expected 'ingest' or 'extract'" - ) - def run(self, metadata: Dict[str, Any]): connection: Dict[str, str] = self._datasource.connection try: @@ -93,7 +56,6 @@ def run(self, metadata: Dict[str, Any]): ) self._unsync_documents( client=_minio_client, - connection=connection, ) except: _LOG.exception("Error fetching sharepoint documents") @@ -192,52 +154,22 @@ def _sync_document( file_path=file_path, ) - """ - Here we check if this file has been updated. - If the file has been updated, we delete it from the vector store and re-ingest the new updated file - """ - document: Document = get_document(document_id=item.etag) - document_id = None if document is None else document.uuid - - for _ingest_runner in self._ingest_runners: - try: - response_json = _ingest_runner.run( - file_path=file_path, - metadata=metadata, - document_id=document_id, - last_modified=item.last_modified.replace(tzinfo=None).replace( - microsecond=0 - ), - datasource=datasource, - ) - except ValueError: - _LOG.warning(f"File {file_path} ingestion failed", exc_info=True) - response_json = None - except UserWarning: - _LOG.debug(f"File {file_path} is already processing") - return - - if response_json is None: - return - - _ingest_runner.save( - document_id=item.etag, - datasource_id=datasource.uuid, - filename=item.object_name, - base_uri=endpoint, - rag_metadata=response_json, - store_metadata={ - "bucket_name": item.bucket_name, - "object_name": item.object_name, - "etag": item.etag, - "size": item.size, - "last_modified": item.last_modified.strftime( - DEFAULT_DATETIME_FORMAT - ), - "version_id": item.version_id, - }, - last_modified=item.last_modified, - ) + self.sync( + endpoint, + file_path, + item.last_modified, + metadata, + store_metadata={ + "bucket_name": item.bucket_name, + "object_name": item.object_name, + "filename": item.object_name, + "size": item.size, + "last_modified": item.last_modified.strftime( + DEFAULT_DATETIME_FORMAT + ), + "version_id": item.version_id, + }, + ) _LOG.info( f"Done {self._mode}ing object {item.object_name} in bucket {bucket_name}" @@ -254,37 +186,27 @@ def _sync_document( def _unsync_documents( self, client: Minio, - connection: dict, ) -> None: - try: - endpoint = connection.get("endpoint") - - for _ingest_runner in self._ingest_runners: - documents = _ingest_runner.get(base_uri=endpoint) - for document in documents: - store_metadata = document.store_metadata - try: - rag_metadata = document.rag_metadata - except AttributeError: - rag_metadata = document.extract_metadata - bucket_name = store_metadata["bucket_name"] - object_name = store_metadata["object_name"] - try: - client.stat_object( - bucket_name=bucket_name, object_name=object_name - ) - except Exception as ex: - str_ex = str(ex) - if "NoSuchKey" in str_ex and "does not exist" in str_ex: - _ingest_runner.delete( - document=document, rag_metadata=rag_metadata - ) - else: - raise + def clean(**kwargs): + store_metadata = kwargs["store_metadata"] + try: + client.stat_object( + bucket_name=store_metadata["bucket_name"], + object_name=store_metadata["object_name"], + ) + return False + except Exception as ex: + str_ex = str(ex) + if "NoSuchKey" in str_ex and "does not exist" in str_ex: + return True + else: + raise + try: + self.unsync(clean=clean) except: - _LOG.warn("Error fetching and updating documents", exc_info=True) + _LOG.warning("Error fetching and updating documents", exc_info=True) def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: diff --git a/nesis/api/core/document_loaders/runners.py b/nesis/api/core/document_loaders/runners.py index baf408c..279cb01 100644 --- a/nesis/api/core/document_loaders/runners.py +++ b/nesis/api/core/document_loaders/runners.py @@ -74,14 +74,16 @@ def run( ) -> Union[Dict[str, Any], None]: if document_id is not None: + _LOG.debug(f"Checking if document {document_id} is modified") _is_modified = self._is_modified( document_id=document_id, last_modified=last_modified ) if _is_modified is None or not _is_modified: + _LOG.debug(f"Document {document_id} is not modified") return url = f"{self._rag_endpoint}/v1/extractions/text" - + _LOG.debug(f"Document {document_id} is modified, performing extraction") response = self._http_client.upload( url=url, filepath=file_path, diff --git a/nesis/api/core/document_loaders/s3.py b/nesis/api/core/document_loaders/s3.py index 7d6f274..e46ce6a 100644 --- a/nesis/api/core/document_loaders/s3.py +++ b/nesis/api/core/document_loaders/s3.py @@ -2,12 +2,14 @@ import logging import pathlib import tempfile +from concurrent.futures import as_completed from typing import Dict, Any import boto3 import memcache import nesis.api.core.util.http as http +from nesis.api.core.document_loaders.loader_helper import DocumentProcessor from nesis.api.core.models.entities import Document, Datasource from nesis.api.core.services import util from nesis.api.core.services.util import ( @@ -18,282 +20,223 @@ ingest_file, ) from nesis.api.core.util import clean_control, isblank +from nesis.api.core.util.concurrency import IOBoundPool +from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT from nesis.api.core.util.dateutil import strptime _LOG = logging.getLogger(__name__) -def fetch_documents( - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - cache_client: memcache.Client, - metadata: Dict[str, Any], -) -> None: - try: - connection = datasource.connection - endpoint = connection.get("endpoint") - access_key = connection.get("user") - secret_key = connection.get("password") - region = connection.get("region") - if all([access_key, secret_key]): - if endpoint: - s3_client = boto3.client( - "s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=region, - endpoint_url=endpoint, - ) - else: - s3_client = boto3.client( - "s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=region, - ) - else: - if endpoint: - s3_client = boto3.client( - "s3", region_name=region, endpoint_url=endpoint - ) +class Processor(DocumentProcessor): + def __init__( + self, + config, + http_client: http.HttpClient, + cache_client: memcache.Client, + datasource: Datasource, + ): + super().__init__(config, http_client, datasource) + self._config = config + self._http_client = http_client + self._cache_client = cache_client + self._datasource = datasource + + def run(self, metadata: Dict[str, Any]): + connection: Dict[str, str] = self._datasource.connection + try: + endpoint = connection.get("endpoint") + access_key = connection.get("user") + secret_key = connection.get("password") + region = connection.get("region") + if all([access_key, secret_key]): + if endpoint: + s3_client = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=region, + endpoint_url=endpoint, + ) + else: + s3_client = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=region, + ) else: - s3_client = boto3.client("s3", region_name=region) - - _sync_documents( - client=s3_client, - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - cache_client=cache_client, - metadata=metadata, - ) - _unsync_documents( - client=s3_client, - connection=connection, - rag_endpoint=rag_endpoint, - http_client=http_client, - ) - except Exception as ex: - _LOG.exception(f"Error fetching s3 documents - {ex}") - - -def _sync_documents( - client, - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - cache_client: memcache.Client, - metadata: dict, -) -> None: + if endpoint: + s3_client = boto3.client( + "s3", region_name=region, endpoint_url=endpoint + ) + else: + s3_client = boto3.client("s3", region_name=region) - try: + self._sync_documents( + client=s3_client, + datasource=self._datasource, + metadata=metadata, + ) + self._unsync_documents( + client=s3_client, + ) + except: + _LOG.exception("Error fetching sharepoint documents") - # Data objects allow us to specify bucket names - connection = datasource.connection - bucket_paths = connection.get("dataobjects") - if bucket_paths is None: - _LOG.warning("No bucket names supplied, so I can't do much") + def _sync_documents( + self, + client, + datasource: Datasource, + metadata: dict, + ) -> None: - bucket_paths_parts = bucket_paths.split(",") + try: - _LOG.info(f"Initializing syncing to endpoint {rag_endpoint}") + # Data objects allow us to specify bucket names + connection = datasource.connection + bucket_paths = connection.get("dataobjects") + if bucket_paths is None: + _LOG.warning("No bucket names supplied, so I can't do much") - for bucket_path in bucket_paths_parts: + bucket_paths_parts = bucket_paths.split(",") + futures = [] + for bucket_path in bucket_paths_parts: - # a/b/c/// should only give [a,b,c] - bucket_path_parts = [ - part for part in bucket_path.split("/") if len(part) != 0 - ] + # a/b/c/// should only give [a,b,c] + bucket_path_parts = [ + part for part in bucket_path.split("/") if len(part) != 0 + ] - path = "/".join(bucket_path_parts[1:]) - bucket_name = bucket_path_parts[0] + path = "/".join(bucket_path_parts[1:]) + bucket_name = bucket_path_parts[0] - paginator = client.get_paginator("list_objects_v2") - page_iterator = paginator.paginate( - Bucket=bucket_name, - Prefix="" if path == "" else f"{path}/", - ) - for result in page_iterator: - if result["KeyCount"] == 0: - continue - # iterate through files - for item in result["Contents"]: - # Paths ending in / are folders so we skip them - if item["Key"].endswith("/"): + paginator = client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate( + Bucket=bucket_name, + Prefix="" if path == "" else f"{path}/", + ) + for result in page_iterator: + if result["KeyCount"] == 0: continue - - endpoint = connection["endpoint"] - self_link = f"{endpoint}/{bucket_name}/{item['Key']}" - _metadata = { - **(metadata or {}), - "file_name": f"{bucket_name}/{item['Key']}", - "self_link": self_link, - } - - """ - We use memcache's add functionality to implement a shared lock to allow for multiple instances - operating - """ - _lock_key = clean_control(f"{__name__}/locks/{self_link}") - if cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): - try: - _sync_document( - client=client, - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=_metadata, - bucket_name=bucket_name, - item=item, + # iterate through files + for item in result["Contents"]: + # Paths ending in / are folders so we skip them + if item["Key"].endswith("/"): + continue + futures.append( + IOBoundPool.submit( + self._process_object, + bucket_name, + client, + datasource, + item, + metadata, ) - finally: - cache_client.delete(_lock_key) - else: - _LOG.info(f"Document {self_link} is already processing") - - _LOG.info(f"Completed syncing to endpoint {rag_endpoint}") - - except: - _LOG.warning("Error fetching and updating documents", exc_info=True) - - -def _sync_document( - client, - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - metadata: dict, - bucket_name: str, - item, -): - connection = datasource.connection - endpoint = connection["endpoint"] - _metadata = metadata - - with tempfile.NamedTemporaryFile( - dir=tempfile.gettempdir(), - ) as tmp: - key_parts = item["Key"].split("/") - - path_to_tmp = f"{str(pathlib.Path(tmp.name).absolute())}-{key_parts[-1]}" - - try: - _LOG.info(f"Starting syncing object {item['Key']} in bucket {bucket_name}") - # Write item to file - client.download_file(bucket_name, item["Key"], path_to_tmp) - - document: Document = get_document(document_id=item["ETag"]) - if document and document.base_uri == endpoint: - store_metadata = document.store_metadata - if store_metadata and store_metadata.get("last_modified"): - last_modified = store_metadata["last_modified"] - if not strptime(date_string=last_modified).replace( - tzinfo=None - ) < item["LastModified"].replace(tzinfo=None).replace( - microsecond=0 - ): - _LOG.debug( - f"Skipping document {item['Key']} already up to date" ) - return - rag_metadata: dict = document.rag_metadata - if rag_metadata is None: - return - for document_data in rag_metadata.get("data") or []: - try: - util.un_ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - doc_id=document_data["doc_id"], - ) - except: - _LOG.warning( - f"Failed to delete document {document_data['doc_id']}" - ) + for future in as_completed(futures): + try: + future.result() + except: + _LOG.warning(future.exception()) - try: - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {item.object_name}'s record. Continuing anyway..." - ) + except: + _LOG.warning("Error fetching and updating documents", exc_info=True) + def _process_object(self, bucket_name, client, datasource, item, metadata): + connection = datasource.connection + endpoint = connection["endpoint"] + self_link = f"{endpoint}/{bucket_name}/{item['Key']}" + _metadata = { + **(metadata or {}), + "file_name": f"{bucket_name}/{item['Key']}", + "self_link": self_link, + } + """ + We use memcache's add functionality to implement a shared lock to allow for multiple instances + operating + """ + _lock_key = clean_control(f"{__name__}/locks/{self_link}") + if self._cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): try: - response = ingest_file( - http_client=http_client, - endpoint=rag_endpoint, + self._sync_document( + client=client, + datasource=datasource, metadata=_metadata, - file_path=path_to_tmp, + bucket_name=bucket_name, + item=item, ) - response_json = json.loads(response) - - except ValueError: - _LOG.warning(f"File {path_to_tmp} ingestion failed", exc_info=True) - response_json = {} - except UserWarning: - _LOG.debug(f"File {path_to_tmp} is already processing") - return - - save_document( - document_id=item["ETag"], - filename=item["Key"], - datasource_id=datasource.uuid, - base_uri=endpoint, - rag_metadata=response_json, - store_metadata={ - "bucket_name": bucket_name, - "object_name": item["Key"], - "etag": item["ETag"], - "size": item["Size"], - "last_modified": str(item["LastModified"]), - }, - last_modified=item["LastModified"], - ) - - _LOG.info(f"Done syncing object {item['Key']} in bucket {bucket_name}") - except Exception as ex: - _LOG.warning( - f"Error when getting and ingesting document {item['Key']} - {ex}" - ) - + finally: + self._cache_client.delete(_lock_key) + else: + _LOG.info(f"Document {self_link} is already processing") + + def _sync_document( + self, + client, + datasource: Datasource, + metadata: dict, + bucket_name: str, + item, + ): + endpoint = datasource.connection["endpoint"] + _metadata = metadata + + with tempfile.NamedTemporaryFile( + dir=tempfile.gettempdir(), + ) as tmp: + key_parts = item["Key"].split("/") + + path_to_tmp = f"{str(pathlib.Path(tmp.name).absolute())}-{key_parts[-1]}" -def _unsync_documents( - client, connection: dict, rag_endpoint: str, http_client: http.HttpClient -) -> None: + try: + _LOG.info( + f"Starting syncing object {item['Key']} in bucket {bucket_name}" + ) + # Write item to file + client.download_file(bucket_name, item["Key"], path_to_tmp) + self.sync( + endpoint, + path_to_tmp, + last_modified=item["LastModified"], + metadata=metadata, + store_metadata={ + "bucket_name": bucket_name, + "object_name": item["Key"], + "filename": item["Key"], + "size": item["Size"], + "last_modified": item["LastModified"].strftime( + DEFAULT_DATETIME_FORMAT + ), + }, + ) - try: - endpoint = connection.get("endpoint") + _LOG.info(f"Done syncing object {item['Key']} in bucket {bucket_name}") + except: + _LOG.warning( + f"Error when getting and ingesting document {item['Key']}", + exc_info=True, + ) - documents = get_documents(base_uri=endpoint) - for document in documents: - store_metadata = document.store_metadata - rag_metadata = document.rag_metadata - bucket_name = store_metadata["bucket_name"] - object_name = store_metadata["object_name"] + def _unsync_documents(self, client) -> None: + def clean(**kwargs): + store_metadata = kwargs["store_metadata"] try: - client.head_object(Bucket=bucket_name, Key=object_name) + client.head_object( + Bucket=store_metadata["bucket_name"], + Key=store_metadata["object_name"], + ) + return False except Exception as ex: - str_ex = str(ex).lower() + str_ex = str(ex) if not ("object" in str_ex and "not found" in str_ex): + return True + else: raise - try: - http_client.deletes( - urls=[ - f"{rag_endpoint}/v1/ingest/documents/{document_data['doc_id']}" - for document_data in rag_metadata.get("data") or [] - ] - ) - _LOG.info(f"Deleting document {document.filename}") - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {document.filename}", - exc_info=True, - ) - except: - _LOG.warn("Error fetching and updating documents", exc_info=True) + try: + self.unsync(clean=clean) + except: + _LOG.warning("Error fetching and updating documents", exc_info=True) def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: diff --git a/nesis/api/core/document_loaders/samba.py b/nesis/api/core/document_loaders/samba.py index dd368a6..8cd92da 100644 --- a/nesis/api/core/document_loaders/samba.py +++ b/nesis/api/core/document_loaders/samba.py @@ -1,316 +1,254 @@ -import uuid +import logging import pathlib -import json -import memcache +import uuid +from concurrent.futures import as_completed from datetime import datetime from typing import Dict, Any +import memcache import smbprotocol from smbclient import scandir, stat, shutil -import logging -from nesis.api.core.models.entities import Document -from nesis.api.core.services import util -from nesis.api.core.services.util import ( - save_document, - get_document, - delete_document, - get_documents, - ValidationException, - ingest_file, -) +from nesis.api.core.document_loaders.loader_helper import DocumentProcessor +from nesis.api.core.models.entities import Datasource from nesis.api.core.util import http, clean_control, isblank +from nesis.api.core.util.concurrency import IOBoundPool from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT, DEFAULT_SAMBA_PORT -from nesis.api.core.util.dateutil import strptime _LOG = logging.getLogger(__name__) -def fetch_documents( - connection: Dict[str, str], - rag_endpoint: str, - http_client: http.HttpClient, - cache_client: memcache.Client, - metadata: Dict[str, Any], -) -> None: - try: - _sync_samba_documents( - connection=connection, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=metadata, - cache_client=cache_client, - ) - except: - _LOG.exception(f"Error syncing documents") - - try: - _unsync_samba_documents( - connection=connection, rag_endpoint=rag_endpoint, http_client=http_client - ) - except Exception as ex: - _LOG.exception(f"Error unsyncing documents") - - -def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: - port = connection.get("port") or DEFAULT_SAMBA_PORT - _valid_keys = ["port", "endpoint", "user", "password", "dataobjects"] - if not str(port).isnumeric(): - raise ValueError("Port value cannot be non numeric") +class Processor(DocumentProcessor): + def __init__( + self, + config, + http_client: http.HttpClient, + cache_client: memcache.Client, + datasource: Datasource, + ): + super().__init__(config, http_client, datasource) + self._config = config + self._http_client = http_client + self._cache_client = cache_client + self._datasource = datasource + self._futures = [] + + def run(self, metadata: Dict[str, Any]): + connection = self._datasource.connection + try: + self._sync_samba_documents( + metadata=metadata, + ) + except: + _LOG.exception(f"Error syncing documents") - assert not isblank( - connection.get("endpoint") - ), "A valid share address must be supplied" + try: + self._unsync_samba_documents( + connection=connection, + ) + except: + _LOG.exception(f"Error unsyncing documents") - try: - _connect_samba_server(connection) - except Exception as ex: - _LOG.exception( - f"Failed to connect to samba server at {connection['endpoint']}", - ) - raise ValueError(ex) - connection["port"] = port - return { - key: val - for key, val in connection.items() - if key in _valid_keys and not isblank(connection[key]) - } + for future in as_completed(self._futures): + try: + future.result() + except: + _LOG.warning(future.exception()) + def _sync_samba_documents(self, metadata): -def _connect_samba_server(connection): - username = connection.get("user") - password = connection.get("password") - endpoint = connection.get("endpoint") - port = connection.get("port") - next(scandir(endpoint, username=username, password=password, port=port)) + connection = self._datasource.connection + username = connection["user"] + password = connection["password"] + endpoint = connection["endpoint"] + port = connection["port"] + # These are any folder specified to scope the sync to + dataobjects = connection.get("dataobjects") or "" -def _sync_samba_documents( - connection, rag_endpoint, http_client, metadata, cache_client -): + dataobjects_parts = [do.strip() for do in dataobjects.split(",")] - username = connection["user"] - password = connection["password"] - endpoint = connection["endpoint"] - port = connection["port"] - # These are any folder specified to scope the sync to - dataobjects = connection.get("dataobjects") or "" + try: + file_shares = scandir( + endpoint, username=username, password=password, port=port + ) + except Exception as ex: + _LOG.exception( + f"Error while scanning share on samba server {endpoint} - {ex}" + ) + raise - dataobjects_parts = [do.strip() for do in dataobjects.split(",")] + work_dir = f"/tmp/{uuid.uuid4()}" + pathlib.Path(work_dir).mkdir(parents=True) - try: - file_shares = scandir(endpoint, username=username, password=password, port=port) - except Exception as ex: - _LOG.exception(f"Error while scanning share on samba server {endpoint} - {ex}") - raise - - work_dir = f"/tmp/{uuid.uuid4()}" - pathlib.Path(work_dir).mkdir(parents=True) - - for file_share in file_shares: - if ( - len(dataobjects_parts) > 0 - and file_share.is_dir() - and file_share.name not in dataobjects_parts - ): - continue - try: - self_link = file_share.path - _lock_key = clean_control(f"{__name__}/locks/{self_link}") + for file_share in file_shares: + if ( + len(dataobjects_parts) > 0 + and file_share.is_dir() + and file_share.name not in dataobjects_parts + ): + continue + try: - if cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): _metadata = { **(metadata or {}), "file_name": file_share.path, - "self_link": self_link, } - try: - _process_file( + + self._futures.append( + IOBoundPool.submit( + self._process_file, connection=connection, file_share=file_share, work_dir=work_dir, - http_client=http_client, - rag_endpoint=rag_endpoint, metadata=_metadata, ) - finally: - cache_client.delete(_lock_key) - else: - _LOG.info(f"Document {self_link} is already processing") - except: - _LOG.warn( - f"Error fetching and updating documents from shared_file share {file_share.path} - ", - exc_info=True, - ) + ) - _LOG.info( - f"Completed syncing files from samba server {endpoint} " - f"to endpoint {rag_endpoint}" - ) + except: + _LOG.warning( + f"Error fetching and updating documents from shared_file share {file_share.path} - ", + exc_info=True, + ) + def _process_file(self, connection, file_share, work_dir, metadata): + username = connection["user"] + password = connection["password"] + endpoint = connection["endpoint"] + port = connection["port"] -def _process_file( - connection, file_share, work_dir, http_client, rag_endpoint, metadata -): - username = connection["user"] - password = connection["password"] - endpoint = connection["endpoint"] - port = connection["port"] + if file_share.is_dir(): + if not file_share.name.startswith("."): + dir_files = scandir( + file_share.path, username=username, password=password, port=port + ) + for dir_file in dir_files: + self._process_file( + connection=connection, + file_share=dir_file, + work_dir=work_dir, + metadata=metadata, + ) + return + + file_name = file_share.name + file_stats = stat( + file_share.path, username=username, password=password, port=port + ) + last_change_datetime = datetime.fromtimestamp(file_stats.st_chgtime) - if file_share.is_dir(): - if not file_share.name.startswith("."): - dir_files = scandir( - file_share.path, username=username, password=password, port=port + try: + file_path = f"{work_dir}/{file_share.name}" + file_unique_id = f"{uuid.uuid5(uuid.NAMESPACE_DNS, file_share.path)}" + + _LOG.info( + f"Starting syncing shared_file {file_name} in shared directory share {file_share.path}" ) - for dir_file in dir_files: - _process_file( - connection=connection, - file_share=dir_file, - work_dir=work_dir, - http_client=http_client, - rag_endpoint=rag_endpoint, - metadata=metadata, + + try: + shutil.copyfile( + file_share.path, + file_path, + username=username, + password=password, + port=port, ) - return + self_link = file_share.path + _lock_key = clean_control(f"{__name__}/locks/{self_link}") - file_name = file_share.name - file_stats = stat(file_share.path, username=username, password=password, port=port) - last_change_datetime = datetime.fromtimestamp(file_stats.st_chgtime) + metadata["self_link"] = self_link - try: - file_path = f"{work_dir}/{file_share.name}" - file_unique_id = f"{uuid.uuid5(uuid.NAMESPACE_DNS, file_share.path)}" - - _LOG.info( - f"Starting syncing shared_file {file_name} in shared directory share {file_share.path}" - ) + if self._cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): + try: + self.sync( + endpoint, + file_path, + last_modified=last_change_datetime, + metadata=metadata, + store_metadata={ + "shared_folder": file_share.name, + "file_path": file_share.path, + "filename": file_share.path, + "file_id": file_unique_id, + "size": file_stats.st_size, + "name": file_name, + "last_modified": last_change_datetime.strftime( + DEFAULT_DATETIME_FORMAT + ), + }, + ) + finally: + self._cache_client.delete(_lock_key) + else: + _LOG.info(f"Document {self_link} is already processing") + + except: + _LOG.warning( + f"Failed to copy contents of shared_file {file_name} from shared location {file_share.path}", + exc_info=True, + ) + return - try: - shutil.copyfile( - file_share.path, - file_path, - username=username, - password=password, - port=port, + _LOG.info( + f"Done syncing shared_file {file_name} in location {file_share.path}" ) except Exception as ex: _LOG.warn( - f"Failed to copy contents of shared_file {file_name} from shared location {file_share.path}", + f"Error when getting and ingesting shared_file {file_name} - {ex}", exc_info=True, ) - return - """ - Here we check if this file has been updated. - If the file has been updated, we delete it from the vector store and re-ingest the new updated file - """ - document: Document = get_document(document_id=file_unique_id) - if document and document.base_uri == endpoint: - store_metadata = document.store_metadata - if store_metadata and store_metadata.get("last_modified"): - if not strptime(date_string=store_metadata["last_modified"]).replace( - tzinfo=None - ) < last_change_datetime.replace(tzinfo=None).replace(microsecond=0): - _LOG.debug(f"Skipping shared_file {file_name} already up to date") - return - rag_metadata: dict = document.rag_metadata - if rag_metadata is None: - return - for document_data in rag_metadata.get("data") or []: - try: - util.un_ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - doc_id=document_data["doc_id"], - ) - except: - _LOG.warn( - f"Failed to delete document {document_data['doc_id']}", - exc_info=True, - ) - try: - delete_document(document_id=file_unique_id) - except: - _LOG.warn( - f"Failed to delete shared_file {file_name}'s record. Continuing anyway...", - exc_info=True, - ) + def _unsync_samba_documents(self, connection): + username = connection["user"] + password = connection["password"] + port = connection["port"] - file_metadata = { - "shared_folder": file_share.name, - "file_path": file_share.path, - "file_id": file_unique_id, - "size": file_stats.st_size, - "name": file_name, - "last_modified": last_change_datetime.strftime(DEFAULT_DATETIME_FORMAT), - } + def clean(**kwargs): + store_metadata = kwargs["store_metadata"] + file_path = store_metadata["file_path"] + try: + stat(file_path, username=username, password=password, port=port) + return False + except Exception as error: + if "No such file" in str(error): + return True + else: + raise try: - response = ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - metadata=metadata, - file_path=file_path, - ) - except UserWarning: - _LOG.debug(f"File {file_path} is already processing") - return - response_json = json.loads(response) - - save_document( - document_id=file_unique_id, - filename=file_share.path, - base_uri=endpoint, - rag_metadata=response_json, - store_metadata=file_metadata, - ) + self.unsync(clean=clean) + except: + _LOG.warning("Error fetching and updating documents", exc_info=True) - _LOG.info(f"Done syncing shared_file {file_name} in location {file_share.path}") - except Exception as ex: - _LOG.warn( - f"Error when getting and ingesting shared_file {file_name} - {ex}", - exc_info=True, - ) - _LOG.info( - f"Completed syncing files from shared_file share {file_share.path} to endpoint {rag_endpoint}" - ) +def _connect_samba_server(connection): + username = connection.get("user") + password = connection.get("password") + endpoint = connection.get("endpoint") + port = connection.get("port") + next(scandir(endpoint, username=username, password=password, port=port)) -def _unsync_samba_documents(connection, rag_endpoint, http_client): - try: - username = connection["user"] - password = connection["password"] - endpoint = connection["endpoint"] - port = connection["port"] - work_dir = f"/tmp/{uuid.uuid4()}" - pathlib.Path(work_dir).mkdir(parents=True) +def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: + port = connection.get("port") or DEFAULT_SAMBA_PORT + _valid_keys = ["port", "endpoint", "user", "password", "dataobjects"] + if not str(port).isnumeric(): + raise ValueError("Port value cannot be non numeric") - documents = get_documents(base_uri=endpoint) - for document in documents: - store_metadata = document.store_metadata - rag_metadata = document.rag_metadata + assert not isblank( + connection.get("endpoint") + ), "A valid share address must be supplied" - file_path = store_metadata["file_path"] - try: - stat(file_path, username=username, password=password, port=port) - except smbprotocol.exceptions.SMBOSError as error: - if "No such file" not in str(error): - raise - try: - http_client.deletes( - [ - f"{rag_endpoint}/v1/ingest/documents/{document_data['doc_id']}" - for document_data in rag_metadata.get("data") or [] - ] - ) - _LOG.info(f"Deleting document {document.filename}") - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {document.filename}", - exc_info=True, - ) - _LOG.info(f"Completed unsyncing files from endpoint {rag_endpoint}") - except: - _LOG.warn("Error fetching and updating documents", exc_info=True) + try: + _connect_samba_server(connection) + except Exception as ex: + _LOG.exception( + f"Failed to connect to samba server at {connection['endpoint']}", + ) + raise ValueError(ex) + connection["port"] = port + return { + key: val + for key, val in connection.items() + if key in _valid_keys and not isblank(connection[key]) + } diff --git a/nesis/api/core/document_loaders/sharepoint.py b/nesis/api/core/document_loaders/sharepoint.py index 1ab496a..94f768e 100644 --- a/nesis/api/core/document_loaders/sharepoint.py +++ b/nesis/api/core/document_loaders/sharepoint.py @@ -10,6 +10,7 @@ from office365.sharepoint.client_context import ClientContext from office365.runtime.client_request_exception import ClientRequestException +from nesis.api.core.document_loaders.loader_helper import DocumentProcessor from nesis.api.core.util import http, clean_control, isblank import logging from nesis.api.core.models.entities import Document, Datasource @@ -28,278 +29,195 @@ _LOG = logging.getLogger(__name__) -def fetch_documents( - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - metadata: Dict[str, Any], - cache_client: memcache.Client, -) -> None: - try: +class Processor(DocumentProcessor): + def __init__( + self, + config, + http_client: http.HttpClient, + cache_client: memcache.Client, + datasource: Datasource, + ): + super().__init__(config, http_client, datasource) + self._config = config + self._http_client = http_client + self._cache_client = cache_client + self._datasource = datasource + self._futures = [] - connection = datasource.connection - site_url = connection.get("endpoint") - client_id = connection.get("client_id") - tenant = connection.get("tenant_id") - thumbprint = connection.get("thumbprint") - - with tempfile.NamedTemporaryFile(dir=tempfile.gettempdir()) as tmp: - cert_path = f"{str(pathlib.Path(tmp.name).absolute())}-{uuid.uuid4()}.key" - pathlib.Path(cert_path).write_text(connection["certificate"]) - - _sharepoint_context = ClientContext(site_url).with_client_certificate( - tenant=tenant, - client_id=client_id, - thumbprint=thumbprint, - cert_path=cert_path, - ) + def run(self, metadata: Dict[str, Any]): - _sync_sharepoint_documents( - sp_context=_sharepoint_context, - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=metadata, - cache_client=cache_client, - ) - _unsync_sharepoint_documents( - sp_context=_sharepoint_context, - connection=connection, - rag_endpoint=rag_endpoint, - http_client=http_client, - ) - except Exception as ex: - _LOG.exception(f"Error fetching sharepoint documents - {ex}") + try: + connection = self._datasource.connection + site_url = connection.get("endpoint") + client_id = connection.get("client_id") + tenant = connection.get("tenant_id") + thumbprint = connection.get("thumbprint") -def _sync_sharepoint_documents( - sp_context, datasource, rag_endpoint, http_client, metadata, cache_client -): - try: - _LOG.info(f"Initializing sharepoint syncing to endpoint {rag_endpoint}") + with tempfile.NamedTemporaryFile(dir=tempfile.gettempdir()) as tmp: + cert_path = ( + f"{str(pathlib.Path(tmp.name).absolute())}-{uuid.uuid4()}.key" + ) + pathlib.Path(cert_path).write_text(connection["certificate"]) - if sp_context is None: - raise Exception( - "Sharepoint context is null, cannot proceed with document processing." - ) + _sharepoint_context = ClientContext(site_url).with_client_certificate( + tenant=tenant, + client_id=client_id, + thumbprint=thumbprint, + cert_path=cert_path, + ) + + self._sync_sharepoint_documents( + sp_context=_sharepoint_context, + metadata=metadata, + ) + self._unsync_sharepoint_documents( + sp_context=_sharepoint_context, + ) + except Exception as ex: + _LOG.exception(f"Error fetching sharepoint documents - {ex}") - # Data objects allow us to specify folder names - connection = datasource.connection - sharepoint_folders = connection.get("dataobjects") - if sharepoint_folders is None: - _LOG.warning("Sharepoint folders are specified, so I can't do much") + def _sync_sharepoint_documents(self, sp_context, metadata): + try: - sp_folders = sharepoint_folders.split(",") + if sp_context is None: + raise Exception( + "Sharepoint context is null, cannot proceed with document processing." + ) - root_folder = sp_context.web.default_document_library().root_folder + # Data objects allow us to specify folder names + connection = self._datasource.connection + sharepoint_folders = connection.get("dataobjects") + if sharepoint_folders is None: + _LOG.warning("Sharepoint folders are specified, so I can't do much") - for folder_name in sp_folders: - sharepoint_folder = root_folder.folders.get_by_path(folder_name) + sp_folders = sharepoint_folders.split(",") - if sharepoint_folder is None: - _LOG.warning( - f"Cannot retrieve Sharepoint folder {sharepoint_folder} proceeding to process other folders" - ) - continue + root_folder = sp_context.web.default_document_library().root_folder - _process_folder_files( - sharepoint_folder, - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=metadata, - cache_client=cache_client, - ) + for folder_name in sp_folders: + sharepoint_folder = root_folder.folders.get_by_path(folder_name) + + if sharepoint_folder is None: + _LOG.warning( + f"Cannot retrieve Sharepoint folder {sharepoint_folder} proceeding to process other folders" + ) + continue - # Recursively get all the child folders - _child_folders_recursive = sharepoint_folder.get_folders( - True - ).execute_query() - for _child_folder in _child_folders_recursive: - _process_folder_files( - _child_folder, - connection=connection, - rag_endpoint=rag_endpoint, - http_client=http_client, + self._process_folder_files( + sharepoint_folder, metadata=metadata, - cache_client=cache_client, ) - _LOG.info(f"Completed syncing to endpoint {rag_endpoint}") - - except Exception as file_ex: - _LOG.exception( - f"Error fetching and updating documents - Error: {file_ex}", exc_info=True - ) - - -def _process_file( - file, datasource: Datasource, rag_endpoint, http_client, metadata, cache_client -): - connection = datasource.connection - site_url = connection.get("endpoint") - parsed_site_url = urlparse(site_url) - site_root_url = "{uri.scheme}://{uri.netloc}".format(uri=parsed_site_url) - self_link = f"{site_root_url}{file.serverRelativeUrl}" - _metadata = { - **(metadata or {}), - "file_name": file.name, - "self_link": self_link, - } - """ - We use memcache's add functionality to implement a shared lock to allow for multiple instances - operating - """ - _lock_key = clean_control(f"{__name__}/locks/{self_link}") - if cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): - try: - _sync_document( - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=_metadata, - file=file, + # Recursively get all the child folders + _child_folders_recursive = sharepoint_folder.get_folders( + True + ).execute_query() + for _child_folder in _child_folders_recursive: + self._process_folder_files( + _child_folder, + metadata=metadata, + ) + + except Exception as file_ex: + _LOG.exception( + f"Error fetching and updating documents - Error: {file_ex}", + exc_info=True, ) - finally: - cache_client.delete(_lock_key) - else: - _LOG.info(f"Document {self_link} is already processing") - - -def _process_folder_files( - folder, datasource, rag_endpoint, http_client, metadata, cache_client -): - # process files in folder - _files = folder.get_files(False).execute_query() - for file in _files: - _process_file( - file=file, - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=metadata, - cache_client=cache_client, - ) - - -def _sync_document( - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - metadata: dict, - file, -): - connection = datasource.connection - site_url = connection["endpoint"] - _metadata = metadata - - with tempfile.NamedTemporaryFile( - dir=tempfile.gettempdir(), - ) as tmp: - key_parts = file.serverRelativeUrl.split("/") - - path_to_tmp = f"{str(pathlib.Path(tmp.name).absolute())}-{key_parts[-1]}" - try: - _LOG.info( - f"Starting syncing file {file.name} from {file.serverRelativeUrl}" - ) - # Write item to file - downloaded_file_name = path_to_tmp # os.path.join(path_to_tmp, file.name) - # How can we refine this for efficiency - with open(downloaded_file_name, "wb") as local_file: - file.download(local_file).execute_query() - - document: Document = get_document(document_id=file.unique_id) - if document and document.base_uri == site_url: - store_metadata = document.store_metadata - if store_metadata and store_metadata.get("last_modified"): - last_modified = store_metadata["last_modified"] - if ( - not strptime(date_string=last_modified).replace(tzinfo=None) - < file.time_last_modified - ): - _LOG.debug( - f"Skipping sharepoint document {file.name} already up to date" - ) - return - rag_metadata: dict = document.rag_metadata - if rag_metadata is None: - return - for document_data in rag_metadata.get("data") or []: - try: - un_ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - doc_id=document_data["doc_id"], - ) - except: - _LOG.warning( - f"Failed to delete document {document_data['doc_id']}" - ) - try: - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {file.name}'s record. Continuing anyway..." - ) + def _process_file( + self, + file, + metadata, + ): + connection = self._datasource.connection + site_url = connection.get("endpoint") + parsed_site_url = urlparse(site_url) + site_root_url = "{uri.scheme}://{uri.netloc}".format(uri=parsed_site_url) + self_link = f"{site_root_url}{file.serverRelativeUrl}" + _metadata = { + **(metadata or {}), + "file_name": file.name, + "self_link": self_link, + } + + """ + We use memcache's add functionality to implement a shared lock to allow for multiple instances + operating + """ + _lock_key = clean_control(f"{__name__}/locks/{self_link}") + if self._cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): try: - response = ingest_file( - http_client=http_client, - endpoint=rag_endpoint, + self._sync_document( metadata=_metadata, - file_path=downloaded_file_name, + file=file, ) - response_json = json.loads(response) + finally: + self._cache_client.delete(_lock_key) + else: + _LOG.info(f"Document {self_link} is already processing") + + def _process_folder_files(self, folder, metadata): + # process files in folder + _files = folder.get_files(False).execute_query() + for file in _files: + self._process_file( + file=file, + metadata=metadata, + ) - except ValueError: - _LOG.warning( - f"File {downloaded_file_name} ingestion failed", exc_info=True + def _sync_document( + self, + metadata: dict, + file, + ): + connection = self._datasource.connection + site_url = connection["endpoint"] + _metadata = metadata + + with tempfile.NamedTemporaryFile( + dir=tempfile.gettempdir(), + ) as tmp: + key_parts = file.serverRelativeUrl.split("/") + + path_to_tmp = f"{str(pathlib.Path(tmp.name).absolute())}-{key_parts[-1]}" + + try: + _LOG.info( + f"Starting syncing file {file.name} from {file.serverRelativeUrl}" ) - response_json = {} - except UserWarning: - _LOG.debug(f"File {downloaded_file_name} is already processing") - return - - save_document( - document_id=file.unique_id, - filename=file.serverRelativeUrl, - base_uri=site_url, - rag_metadata=response_json, - datasource_id=datasource.uuid, - store_metadata={ - "file_name": file.name, - "file_url": file.serverRelativeUrl, - "etag": file.unique_id, - "size": file.length, - "author": file.author, - "last_modified": file.time_last_modified.strftime( - DEFAULT_DATETIME_FORMAT - ), - }, - last_modified=file.time_last_modified, - ) - _LOG.info(f"Done syncing object {file.name} in at {file.serverRelativeUrl}") - except Exception as ex: - _LOG.warning( - f"Error when getting and ingesting file {file.name}", exc_info=True - ) + # Write item to file + # How can we refine this for efficiency + with open(path_to_tmp, "wb") as local_file: + file.download(local_file).execute_query() -def _unsync_sharepoint_documents(sp_context, http_client, rag_endpoint, connection): + self.sync( + site_url, + path_to_tmp, + last_modified=file.time_last_modified, + metadata=metadata, + store_metadata={ + "filename": file.name, + "file_url": file.serverRelativeUrl, + "etag": file.unique_id, + "size": file.length, + "author": file.author, + "last_modified": file.time_last_modified.strftime( + DEFAULT_DATETIME_FORMAT + ), + }, + ) + except: + _LOG.warning( + f"Error when getting and ingesting file {file.name}", exc_info=True + ) - try: - site_url = connection.get("endpoint") + def _unsync_sharepoint_documents(self, sp_context): - if sp_context is None: - raise Exception( - "Sharepoint context is null, cannot proceed with document processing." - ) + def clean(**kwargs): + store_metadata = kwargs["store_metadata"] - documents = get_documents(base_uri=site_url) - for document in documents: - store_metadata = document.store_metadata - rag_metadata = document.rag_metadata file_url = store_metadata["file_url"] try: # Check that the file still exists on the sharepoint server @@ -308,27 +226,14 @@ def _unsync_sharepoint_documents(sp_context, http_client, rag_endpoint, connecti ).get().execute_query() except ClientRequestException as e: if e.response.status_code == 404: - # File no longer exists on sharepoint server so we need to delete from model - try: - http_client.deletes( - urls=[ - f"{rag_endpoint}/v1/ingest/documents/{document_data['doc_id']}" - for document_data in rag_metadata.get("data") or [] - ] - ) - _LOG.info(f"Deleting document {document.filename}") - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {document.filename}", - exc_info=True, - ) - except Exception as ex: - _LOG.warning( - f"Failed to retrieve file {file_url} from sharepoint - {ex}" - ) - except: - _LOG.warning("Error fetching and updating documents", exc_info=True) + return True + else: + raise + + try: + self.unsync(clean=clean) + except: + _LOG.warning("Error fetching and updating documents", exc_info=True) def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: diff --git a/nesis/api/core/tasks/document_management.py b/nesis/api/core/tasks/document_management.py index 890babd..c6e06bd 100644 --- a/nesis/api/core/tasks/document_management.py +++ b/nesis/api/core/tasks/document_management.py @@ -51,28 +51,35 @@ def ingest_datasource(**kwargs) -> None: minio_ingestor.run(metadata=metadata) case DatasourceType.SHAREPOINT: - sharepoint.fetch_documents( - datasource=datasource, - rag_endpoint=rag_endpoint, + + ingestor = sharepoint.Processor( + config=config, http_client=http_client, cache_client=cache_client, - metadata={"datasource": datasource.name}, + datasource=datasource, ) + + ingestor.run(metadata=metadata) case DatasourceType.WINDOWS_SHARE: - samba.fetch_documents( - connection=datasource.connection, - rag_endpoint=rag_endpoint, + + ingestor = samba.Processor( + config=config, http_client=http_client, - metadata={"datasource": datasource.name}, cache_client=cache_client, + datasource=datasource, ) + + ingestor.run(metadata=metadata) + case DatasourceType.S3: - s3.fetch_documents( - datasource=datasource, - rag_endpoint=rag_endpoint, + minio_ingestor = s3.Processor( + config=config, http_client=http_client, - metadata={"datasource": datasource.name}, cache_client=cache_client, + datasource=datasource, ) + + minio_ingestor.run(metadata=metadata) + case _: raise ValueError("Invalid datasource type") diff --git a/nesis/api/tests/core/document_loaders/test_minio.py b/nesis/api/tests/core/document_loaders/test_minio.py index 8148d71..9e51366 100644 --- a/nesis/api/tests/core/document_loaders/test_minio.py +++ b/nesis/api/tests/core/document_loaders/test_minio.py @@ -400,11 +400,17 @@ def test_update_ingest_documents( session.add(datasource) session.commit() - # The document record + self_link = "http://localhost:4566/my-test-bucket/SomeName" + # The document record document = Document( base_uri="http://localhost:4566", - document_id="d41d8cd98f00b204e9800998ecf8427e", + document_id=str( + uuid.uuid5( + uuid.NAMESPACE_DNS, + f"{datasource.uuid}:{self_link}", + ) + ), filename="invalid.pdf", rag_metadata={"data": [{"doc_id": str(uuid.uuid4())}]}, store_metadata={ @@ -449,8 +455,8 @@ def test_update_ingest_documents( ) # The document would be deleted from the rag engine - _, upload_kwargs = http_client.deletes.call_args_list[0] - urls = upload_kwargs["urls"] + _, deletes_kwargs = http_client.deletes.call_args_list[0] + urls = deletes_kwargs["urls"] assert ( urls[0] @@ -474,7 +480,7 @@ def test_update_ingest_documents( { "datasource": "documents", "file_name": "my-test-bucket/SomeName", - "self_link": "http://localhost:4566/my-test-bucket/SomeName", + "self_link": self_link, }, ) diff --git a/nesis/api/tests/core/document_loaders/test_s3.py b/nesis/api/tests/core/document_loaders/test_s3.py index d95f50d..f1c2028 100644 --- a/nesis/api/tests/core/document_loaders/test_s3.py +++ b/nesis/api/tests/core/document_loaders/test_s3.py @@ -13,6 +13,7 @@ import nesis.api.core.services as services import nesis.api.core.document_loaders.s3 as s3 from nesis.api import tests +from nesis.api.core.document_loaders.stores import SqlDocumentStore from nesis.api.core.models import DBSession from nesis.api.core.models import initialize_engine from nesis.api.core.models.entities import ( @@ -107,12 +108,14 @@ def test_sync_documents( ] s3_client.get_paginator.return_value = paginator - s3.fetch_documents( - datasource=datasource, + ingestor = s3.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.upload.call_args_list[0] @@ -150,7 +153,7 @@ def test_update_sync_documents( "connection": { "endpoint": "http://localhost:4566", "region": "us-east-1", - "dataobjects": "my-test-bucket", + "dataobjects": "some-bucket", }, } @@ -164,17 +167,26 @@ def test_update_sync_documents( session.add(datasource) session.commit() + self_link = "http://localhost:4566/some-bucket/invalid.pdf" + + # The document record document = Document( base_uri="http://localhost:4566", - document_id="d41d8cd98f00b204e9800998ecf8427e", + document_id=str( + uuid.uuid5( + uuid.NAMESPACE_DNS, + f"{datasource.uuid}:{self_link}", + ) + ), filename="invalid.pdf", rag_metadata={"data": [{"doc_id": str(uuid.uuid4())}]}, store_metadata={ "bucket_name": "some-bucket", - "object_name": "file/path.pdf", + "object_name": "invalid.pdf", "last_modified": "2023-07-18 06:40:07", }, - last_modified=datetime.datetime.utcnow(), + last_modified=strptime("2023-07-19 06:40:07"), + datasource_id=datasource.uuid, ) session.add(document) @@ -191,8 +203,8 @@ def test_update_sync_documents( "KeyCount": 1, "Contents": [ { - "Key": "image.jpg", - "LastModified": strptime("2023-07-19 06:40:07"), + "Key": "invalid.pdf", + "LastModified": strptime("2023-07-20 06:40:07"), "ETag": "d41d8cd98f00b204e9800998ecf8427e", "Size": 0, "StorageClass": "STANDARD", @@ -206,22 +218,24 @@ def test_update_sync_documents( ] s3_client.get_paginator.return_value = paginator - s3.fetch_documents( - datasource=datasource, + ingestor = s3.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + + ingestor.run( + metadata={"datasource": "documents"}, ) # The document would be deleted from the rag engine - _, upload_kwargs = http_client.delete.call_args_list[0] - url = upload_kwargs["url"] + _, deletes_kwargs = http_client.deletes.call_args_list[0] + url = deletes_kwargs["urls"] - assert ( - url - == f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" - ) + assert url == [ + f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" + ] # And then re-ingested _, upload_kwargs = http_client.upload.call_args_list[0] @@ -231,21 +245,22 @@ def test_update_sync_documents( field = upload_kwargs["field"] assert url == f"http://localhost:8080/v1/ingest/files" - assert file_path.endswith("image.jpg") + assert file_path.endswith("invalid.pdf") assert field == "file" ut.TestCase().assertDictEqual( metadata, { "datasource": "documents", - "file_name": "my-test-bucket/image.jpg", - "self_link": "http://localhost:4566/my-test-bucket/image.jpg", + "file_name": "some-bucket/invalid.pdf", + "self_link": self_link, }, ) # The document has now been updated documents = session.query(Document).all() assert len(documents) == 1 - assert documents[0].store_metadata["last_modified"] == "2023-07-19 06:40:07" + assert documents[0].store_metadata["last_modified"] == "2023-07-20 06:40:07" + assert str(documents[0].last_modified) == "2023-07-20 06:40:07" @mock.patch("nesis.api.core.document_loaders.s3.boto3.client") @@ -260,8 +275,6 @@ def test_unsync_s3_documents( "engine": "s3", "connection": { "endpoint": "http://localhost:4566", - # "user": "test", - # "password": "test", "region": "us-east-1", "dataobjects": "some-non-existing-bucket", }, @@ -299,12 +312,15 @@ def test_unsync_s3_documents( documents = session.query(Document).all() assert len(documents) == 1 - s3.fetch_documents( - datasource=datasource, + ingestor = s3.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.deletes.call_args_list[0] @@ -316,3 +332,189 @@ def test_unsync_s3_documents( ) documents = session.query(Document).all() assert len(documents) == 0 + + +@mock.patch("nesis.api.core.document_loaders.s3.boto3.client") +def test_extract_documents( + client: mock.MagicMock, cache: mock.MagicMock, session: Session +) -> None: + destination_sql_url = tests.config["database"]["url"] + # destination_sql_url = "mssql+pymssql://sa:Pa55woR.d12345@localhost:11433/master" + data = { + "name": "s3 documents", + "engine": "s3", + "connection": { + "endpoint": "https://s3.endpoint", + "access_key": "", + "secret_key": "", + "dataobjects": "buckets", + "mode": "extract", + "destination": { + "sql": {"url": destination_sql_url}, + }, + }, + } + + datasource = Datasource( + name=data["name"], + connection=data["connection"], + source_type=DatasourceType.S3, + status=DatasourceStatus.ONLINE, + ) + + session.add(datasource) + session.commit() + + http_client = mock.MagicMock() + http_client.upload.return_value = json.dumps({}) + s3_client = mock.MagicMock() + + client.return_value = s3_client + paginator = mock.MagicMock() + paginator.paginate.return_value = [ + { + "KeyCount": 1, + "Contents": [ + { + "Key": "image.jpg", + "LastModified": strptime("2023-07-18 06:40:07"), + "ETag": '"d41d8cd98f00b204e9800998ecf8427e"', + "Size": 0, + "StorageClass": "STANDARD", + "Owner": { + "DisplayName": "webfile", + "ID": "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a", + }, + } + ], + } + ] + s3_client.get_paginator.return_value = paginator + + ingestor = s3.Processor( + config=tests.config, + http_client=http_client, + cache_client=cache, + datasource=datasource, + ) + + extract_store = SqlDocumentStore( + url=data["connection"]["destination"]["sql"]["url"] + ) + + with Session(extract_store._engine) as session: + initial_count = len( + session.query(ingestor._extract_runner._extraction_store.Store) + .filter() + .all() + ) + + ingestor.run( + metadata={"datasource": "documents"}, + ) + + _, upload_kwargs = http_client.upload.call_args_list[0] + url = upload_kwargs["url"] + metadata = upload_kwargs["metadata"] + field = upload_kwargs["field"] + + assert url == "http://localhost:8080/v1/extractions/text" + assert field == "file" + ut.TestCase().assertDictEqual( + metadata, + { + "datasource": "documents", + "file_name": "buckets/image.jpg", + "self_link": "https://s3.endpoint/buckets/image.jpg", + }, + ) + + with Session(extract_store._engine) as session: + all_documents = ( + session.query(ingestor._extract_runner._extraction_store.Store) + .filter() + .all() + ) + assert len(all_documents) == initial_count + 1 + + +@mock.patch("nesis.api.core.document_loaders.s3.boto3.client") +def test_unextract_documents( + client: mock.MagicMock, cache: mock.MagicMock, session: Session +) -> None: + """ + Test deleting of s3 documents from the rag engine if they have been deleted from the s3 bucket + """ + destination_sql_url = tests.config["database"]["url"] + data = { + "name": "s3 documents", + "engine": "s3", + "connection": { + "endpoint": "https://s3.endpoint", + "access_key": "", + "secret_key": "", + "dataobjects": "buckets", + "mode": "extract", + "destination": { + "sql": {"url": destination_sql_url}, + }, + }, + } + datasource = Datasource( + name=data["name"], + connection=data["connection"], + source_type=DatasourceType.MINIO, + status=DatasourceStatus.ONLINE, + ) + + session.add(datasource) + session.commit() + + http_client = mock.MagicMock() + s3_client = mock.MagicMock() + + client.return_value = s3_client + s3_client.head_object.side_effect = Exception("HeadObject Not Found") + + minio_ingestor = s3.Processor( + config=tests.config, + http_client=http_client, + cache_client=cache, + datasource=datasource, + ) + + extract_store = SqlDocumentStore( + url=data["connection"]["destination"]["sql"]["url"] + ) + + with Session(extract_store._engine) as session: + session.query(minio_ingestor._extract_runner._extraction_store.Store).delete() + document = minio_ingestor._extract_runner._extraction_store.Store() + document.base_uri = data["connection"]["endpoint"] + document.uuid = str(uuid.uuid4()) + document.filename = "invalid.pdf" + document.extract_metadata = {"data": [{"doc_id": str(uuid.uuid4())}]} + document.store_metadata = { + "bucket_name": "some-bucket", + "object_name": "file/path.pdf", + } + document.last_modified = datetime.datetime.utcnow() + + session.add(document) + session.commit() + + initial_count = len( + session.query(minio_ingestor._extract_runner._extraction_store.Store) + .filter() + .all() + ) + + minio_ingestor.run( + metadata={"datasource": "documents"}, + ) + + with Session(extract_store._engine) as session: + documents = session.query( + minio_ingestor._extract_runner._extraction_store.Store + ).all() + assert len(documents) == initial_count - 1 diff --git a/nesis/api/tests/core/document_loaders/test_samba.py b/nesis/api/tests/core/document_loaders/test_samba.py index ca65db8..30d2eda 100644 --- a/nesis/api/tests/core/document_loaders/test_samba.py +++ b/nesis/api/tests/core/document_loaders/test_samba.py @@ -1,8 +1,10 @@ +import datetime import json import os import time import unittest as ut import unittest.mock as mock +import uuid import pytest from sqlalchemy.orm.session import Session @@ -14,12 +16,14 @@ from nesis.api.core.models import initialize_engine from nesis.api.core.models.entities import ( Datasource, + Document, ) from nesis.api.core.models.objects import ( DatasourceType, DatasourceStatus, ) +from nesis.api.core.util.dateutil import strptime @pytest.fixture @@ -48,14 +52,12 @@ def configure() -> None: @mock.patch("nesis.api.core.document_loaders.samba.scandir") @mock.patch("nesis.api.core.document_loaders.samba.stat") @mock.patch("nesis.api.core.document_loaders.samba.shutil") -def test_fetch_documents( - shutil, stat, scandir, cache: mock.MagicMock, session: Session -) -> None: +def test_ingest(shutil, stat, scandir, cache: mock.MagicMock, session: Session) -> None: data = { "name": "s3 documents", - "engine": "s3", + "engine": "samba", "connection": { - "endpoint": "https://s3.endpoint", + "endpoint": r"\\Share", "user": "user", "port": "445", "password": "password", @@ -87,12 +89,14 @@ def test_fetch_documents( http_client = mock.MagicMock() http_client.upload.return_value = json.dumps({}) - samba.fetch_documents( - connection=data["connection"], + ingestor = samba.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.upload.call_args_list[0] @@ -112,3 +116,180 @@ def test_fetch_documents( "self_link": r"\\Share\SomeName", }, ) + + +@mock.patch("nesis.api.core.document_loaders.samba.scandir") +@mock.patch("nesis.api.core.document_loaders.samba.stat") +@mock.patch("nesis.api.core.document_loaders.samba.shutil") +def test_uningest( + shutil, stat, scandir, cache: mock.MagicMock, session: Session +) -> None: + """ + Test deleting of s3 documents from the rag engine if they have been deleted from the s3 bucket + """ + data = { + "name": "s3 documents", + "engine": "windows_share", + "connection": { + "endpoint": r"\\Share", + "user": "user", + "port": "445", + "password": "password", + "dataobjects": "buckets", + }, + } + + datasource = Datasource( + name=data["name"], + connection=data["connection"], + source_type=DatasourceType.SHAREPOINT, + status=DatasourceStatus.ONLINE, + ) + + session.add(datasource) + session.commit() + + document = Document( + base_uri=datasource.connection["endpoint"], + datasource_id=datasource.uuid, + document_id=str(uuid.uuid4()), + filename="invalid.pdf", + rag_metadata={"data": [{"doc_id": str(uuid.uuid4())}]}, + store_metadata={ + "bucket_name": "some-bucket", + "object_name": "file/path.pdf", + "file_path": r"\\Share\file\path.pdf", + }, + last_modified=datetime.datetime.utcnow(), + ) + + session.add(document) + session.commit() + + http_client = mock.MagicMock() + + stat.side_effect = Exception("No such file") + + documents = session.query(Document).all() + assert len(documents) == 1 + + ingestor = samba.Processor( + config=tests.config, + http_client=http_client, + cache_client=cache, + datasource=datasource, + ) + + ingestor.run( + metadata={"datasource": "documents"}, + ) + + _, upload_kwargs = http_client.deletes.call_args_list[0] + urls = upload_kwargs["urls"] + + assert ( + urls[0] + == f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" + ) + documents = session.query(Document).all() + assert len(documents) == 0 + + +@mock.patch("nesis.api.core.document_loaders.samba.scandir") +@mock.patch("nesis.api.core.document_loaders.samba.stat") +@mock.patch("nesis.api.core.document_loaders.samba.shutil") +def test_update(shutil, stat, scandir, cache: mock.MagicMock, session: Session) -> None: + """ + Test updating documents if they have been updated at the s3 bucket end. + """ + + data = { + "name": "s3 documents", + "engine": "windows_share", + "connection": { + "endpoint": r"\\Share", + "user": "user", + "port": "445", + "password": "password", + "dataobjects": "buckets", + }, + } + + datasource = Datasource( + name=data["name"], + connection=data["connection"], + source_type=DatasourceType.SHAREPOINT, + status=DatasourceStatus.ONLINE, + ) + + session.add(datasource) + session.commit() + + self_link = "http://localhost:4566/some-bucket/invalid.pdf" + + # The document record + document = Document( + base_uri=r"\\Share", + document_id=str( + uuid.uuid5( + uuid.NAMESPACE_DNS, + rf"{datasource.uuid}:\\Share\file\path.pdf", + ) + ), + filename="invalid.pdf", + rag_metadata={"data": [{"doc_id": str(uuid.uuid4())}]}, + store_metadata={ + "bucket_name": "some-bucket", + "object_name": "invalid.pdf", + "last_modified": "2023-07-18 06:40:07", + "file_path": r"\\Share\file\path.pdf", + }, + last_modified=strptime("2023-07-19 06:40:07"), + datasource_id=datasource.uuid, + ) + + session.add(document) + session.commit() + + share = mock.MagicMock() + share.is_dir.return_value = False + type(share).name = mock.PropertyMock(return_value="SomeName") + type(share).path = mock.PropertyMock(return_value=r"\\Share\file\path.pdf") + scandir.return_value = [share] + + file_stat = mock.MagicMock() + stat.return_value = file_stat + type(file_stat).st_size = mock.PropertyMock(return_value=1) + type(file_stat).st_chgtime = mock.PropertyMock( + return_value=strptime("2023-07-20 06:40:07").timestamp() + ) + + http_client = mock.MagicMock() + http_client.upload.return_value = json.dumps({}) + + ingestor = samba.Processor( + config=tests.config, + http_client=http_client, + cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, + ) + + # The document would be deleted from the rag engine + _, deletes_kwargs = http_client.deletes.call_args_list[0] + url = deletes_kwargs["urls"] + + assert url == [ + f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" + ] + + # And then re-ingested + _, upload_kwargs = http_client.upload.call_args_list[0] + + # The document has now been updated + documents = session.query(Document).all() + assert len(documents) == 1 + assert documents[0].store_metadata["last_modified"] == "2023-07-20 06:40:07" + assert str(documents[0].last_modified) == "2023-07-20 06:40:07" diff --git a/nesis/api/tests/core/document_loaders/test_sharepoint.py b/nesis/api/tests/core/document_loaders/test_sharepoint.py index 5ad3d91..fcc48ec 100644 --- a/nesis/api/tests/core/document_loaders/test_sharepoint.py +++ b/nesis/api/tests/core/document_loaders/test_sharepoint.py @@ -22,10 +22,11 @@ DatasourceType, DatasourceStatus, ) +from nesis.api.core.util.dateutil import strptime @mock.patch("nesis.api.core.document_loaders.sharepoint.ClientContext") -def test_sync_sharepoint_documents( +def test_ingest( client_context: mock.MagicMock, cache: mock.MagicMock, session: Session ) -> None: data = { @@ -81,12 +82,14 @@ def test_sync_sharepoint_documents( http_client = mock.MagicMock() http_client.upload.return_value = json.dumps({}) - sharepoint.fetch_documents( - datasource=datasource, + ingestor = sharepoint.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.upload.call_args_list[0] @@ -109,7 +112,7 @@ def test_sync_sharepoint_documents( @mock.patch("nesis.api.core.document_loaders.sharepoint.ClientContext") -def test_sync_updated_sharepoint_documents( +def test_updated( sharepoint_context: mock.MagicMock, cache: mock.MagicMock, session: Session ) -> None: """ @@ -139,10 +142,13 @@ def test_sync_updated_sharepoint_documents( session.add(datasource) session.commit() + self_link = "https://ametnes.sharepoint.com/sites/nesit-test/Shared Documents/sharepoint_file.pdf" document = Document( datasource_id=datasource.uuid, base_uri="https://ametnes.sharepoint.com/sites/nesis-test/", - document_id="edu323-23423-23frs-234232", + document_id=str( + uuid.uuid5(uuid.NAMESPACE_DNS, f"{datasource.uuid}:{self_link}") + ), filename="sharepoint_file.pdf", rag_metadata={"data": [{"doc_id": str(uuid.uuid4())}]}, store_metadata={ @@ -153,7 +159,7 @@ def test_sync_updated_sharepoint_documents( "author": "author_name", "last_modified": "2024-01-10 06:40:07", }, - last_modified=datetime.datetime.utcnow(), + last_modified=strptime("2024-01-10 06:40:07"), ) session.add(document) @@ -180,7 +186,7 @@ def test_sync_updated_sharepoint_documents( ) type(file_mock).time_last_modified = mock.PropertyMock( return_value=datetime.datetime.strptime( - "2024-04-10 06:40:07", "%Y-%m-%d %H:%M:%S" + "2024-04-11 06:40:07", "%Y-%m-%d %H:%M:%S" ) ) type(file_mock).length = mock.PropertyMock(return_value=2023) @@ -195,22 +201,23 @@ def test_sync_updated_sharepoint_documents( http_client = mock.MagicMock() http_client.upload.return_value = json.dumps({}) - sharepoint.fetch_documents( - datasource=datasource, + ingestor = sharepoint.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, ) # The document would be deleted from the rag engine - _, upload_kwargs = http_client.delete.call_args_list[0] - url = upload_kwargs["url"] + _, upload_kwargs = http_client.deletes.call_args_list[0] + urls = upload_kwargs["urls"] - assert ( - url - == f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" - ) + assert urls == [ + f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" + ] # And then re-ingested _, upload_kwargs = http_client.upload.call_args_list[0] @@ -234,11 +241,11 @@ def test_sync_updated_sharepoint_documents( # The document has now been updated documents = session.query(Document).all() assert len(documents) == 1 - assert documents[0].store_metadata["last_modified"] == "2024-04-10 06:40:07" + assert documents[0].store_metadata["last_modified"] == "2024-04-11 06:40:07" @mock.patch("nesis.api.core.document_loaders.sharepoint.ClientContext") -def test_unsync_sharepoint_documents( +def test_uningest( sharepoint_context: mock.MagicMock, cache: mock.MagicMock, session: Session ) -> None: """ @@ -318,12 +325,14 @@ def test_unsync_sharepoint_documents( documents = session.query(Document).all() assert len(documents) == 1 - sharepoint.fetch_documents( - datasource=datasource, + ingestor = sharepoint.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.deletes.call_args_list[0] diff --git a/nesis/api/tests/tasks/test_document_management.py b/nesis/api/tests/tasks/test_document_management.py index c8bf7a6..55c9546 100644 --- a/nesis/api/tests/tasks/test_document_management.py +++ b/nesis/api/tests/tasks/test_document_management.py @@ -97,15 +97,9 @@ def test_ingest_datasource_minio( @mock.patch("nesis.api.core.document_loaders.samba.scandir") -@mock.patch("nesis.api.core.tasks.document_management.samba._unsync_samba_documents") -@mock.patch("nesis.api.core.tasks.document_management.samba._sync_samba_documents") +@mock.patch("nesis.api.core.tasks.document_management.samba.Processor") def test_ingest_datasource_samba( - _sync_samba_documents: mock.MagicMock, - _unsync_samba_documents: mock.MagicMock, - scandir, - tc, - cache_client, - http_client, + ingestor: mock.MagicMock, scandir, tc, cache_client, http_client ): """ Test the ingestion happy path @@ -117,7 +111,7 @@ def test_ingest_datasource_samba( password=tests.admin_password, ) datasource: Datasource = create_datasource( - token=admin_user.token, datasource_type="windows_share" + token=admin_user.token, datasource_type="WINDOWS_SHARE" ) ingest_datasource( @@ -127,21 +121,9 @@ def test_ingest_datasource_samba( params={"datasource": {"id": datasource.uuid}}, ) - _, kwargs_sync_samba_documents = _sync_samba_documents.call_args_list[0] - assert ( - kwargs_sync_samba_documents["rag_endpoint"] == tests.config["rag"]["endpoint"] - ) - tc.assertDictEqual(kwargs_sync_samba_documents["connection"], datasource.connection) - tc.assertDictEqual( - kwargs_sync_samba_documents["metadata"], {"datasource": datasource.name} - ) - - _, kwargs_unsync_samba_documents = _unsync_samba_documents.call_args_list[0] - assert ( - kwargs_unsync_samba_documents["rag_endpoint"] == tests.config["rag"]["endpoint"] - ) + _, kwargs_fetch_documents = ingestor.return_value.run.call_args_list[0] tc.assertDictEqual( - kwargs_unsync_samba_documents["connection"], datasource.connection + kwargs_fetch_documents["metadata"], {"datasource": datasource.name} ) @@ -162,17 +144,8 @@ def test_ingest_datasource_invalid_datasource( assert "Invalid datasource" in str(ex_info) -@mock.patch("nesis.api.core.document_loaders.s3.boto3.client") -@mock.patch("nesis.api.core.tasks.document_management.s3._unsync_documents") -@mock.patch("nesis.api.core.tasks.document_management.s3._sync_documents") -def test_ingest_datasource_s3( - _sync_documents: mock.MagicMock, - _unsync_documents: mock.MagicMock, - client: mock.MagicMock(), - tc, - cache_client, - http_client, -): +@mock.patch("nesis.api.core.tasks.document_management.s3.Processor") +def test_ingest_datasource_s3(ingestor: mock.MagicMock, tc, cache_client, http_client): """ Test the ingestion happy path """ @@ -193,37 +166,15 @@ def test_ingest_datasource_s3( params={"datasource": {"id": datasource.uuid}}, ) - _, kwargs_sync_samba_documents = _sync_documents.call_args_list[0] - assert ( - kwargs_sync_samba_documents["rag_endpoint"] == tests.config["rag"]["endpoint"] - ) - assert kwargs_sync_samba_documents.get("datasource") is not None - - tc.assertDictEqual( - kwargs_sync_samba_documents["metadata"], {"datasource": datasource.name} - ) - - _, kwargs_unsync_samba_documents = _unsync_documents.call_args_list[0] - assert ( - kwargs_unsync_samba_documents["rag_endpoint"] == tests.config["rag"]["endpoint"] - ) + _, kwargs_fetch_documents = ingestor.return_value.run.call_args_list[0] tc.assertDictEqual( - kwargs_unsync_samba_documents["connection"], datasource.connection + kwargs_fetch_documents["metadata"], {"datasource": datasource.name} ) -@mock.patch( - "nesis.api.core.tasks.document_management.sharepoint._unsync_sharepoint_documents" -) -@mock.patch( - "nesis.api.core.tasks.document_management.sharepoint._sync_sharepoint_documents" -) +@mock.patch("nesis.api.core.tasks.document_management.sharepoint.Processor") def test_ingest_datasource_sharepoint( - _sync_sharepoint_documents: mock.MagicMock, - _unsync_sharepoint_documents: mock.MagicMock, - tc, - cache_client, - http_client, + ingestor: mock.MagicMock, tc, cache_client, http_client ): """ Test the ingestion happy path @@ -245,24 +196,7 @@ def test_ingest_datasource_sharepoint( params={"datasource": {"id": datasource.uuid}}, ) - _, kwargs_sync_sharepoint_documents = _sync_sharepoint_documents.call_args_list[0] - assert ( - kwargs_sync_sharepoint_documents["rag_endpoint"] - == tests.config["rag"]["endpoint"] - ) - assert kwargs_sync_sharepoint_documents.get("datasource") is not None - - tc.assertDictEqual( - kwargs_sync_sharepoint_documents["metadata"], {"datasource": datasource.name} - ) - - _, kwargs_unsync_sharepoint_documents = _unsync_sharepoint_documents.call_args_list[ - 0 - ] - assert ( - kwargs_unsync_sharepoint_documents["rag_endpoint"] - == tests.config["rag"]["endpoint"] - ) + _, kwargs_fetch_documents = ingestor.return_value.run.call_args_list[0] tc.assertDictEqual( - kwargs_unsync_sharepoint_documents["connection"], datasource.connection + kwargs_fetch_documents["metadata"], {"datasource": datasource.name} )