diff --git a/unstructured/ingest/v2/cli/base/cmd.py b/unstructured/ingest/v2/cli/base/cmd.py index 76badac6c2..d2212e475b 100644 --- a/unstructured/ingest/v2/cli/base/cmd.py +++ b/unstructured/ingest/v2/cli/base/cmd.py @@ -74,7 +74,7 @@ def get_pipline( f"setting destination on pipeline {dest} with options: {destination_options}" ) if uploader_stager := self.get_upload_stager(dest=dest, options=destination_options): - pipeline_kwargs["upload_stager"] = uploader_stager + pipeline_kwargs["stager"] = uploader_stager pipeline_kwargs["uploader"] = self.get_uploader(dest=dest, options=destination_options) else: # Default to local uploader @@ -148,7 +148,7 @@ def get_upload_stager(dest: str, options: dict[str, Any]) -> Optional[UploadStag dest_entry = destination_registry[dest] upload_stager_kwargs: dict[str, Any] = {} if upload_stager_config_cls := dest_entry.upload_stager_config: - upload_stager_kwargs["config"] = extract_config( + upload_stager_kwargs["upload_stager_config"] = extract_config( flat_data=options, config=upload_stager_config_cls ) if upload_stager_cls := dest_entry.upload_stager: diff --git a/unstructured/ingest/v2/cli/cmds/__init__.py b/unstructured/ingest/v2/cli/cmds/__init__.py index 6ce3ece147..93711190b4 100644 --- a/unstructured/ingest/v2/cli/cmds/__init__.py +++ b/unstructured/ingest/v2/cli/cmds/__init__.py @@ -9,6 +9,7 @@ from .fsspec.s3 import s3_dest_cmd, s3_src_cmd from .fsspec.sftp import sftp_dest_cmd, sftp_src_cmd from .local import local_dest_cmd, local_src_cmd +from .weaviate import weaviate_dest_cmd src_cmds = [ azure_src_cmd, @@ -37,6 +38,7 @@ local_dest_cmd, s3_dest_cmd, sftp_dest_cmd, + weaviate_dest_cmd, ] duplicate_dest_names = [ diff --git a/unstructured/ingest/v2/cli/cmds/weaviate.py b/unstructured/ingest/v2/cli/cmds/weaviate.py new file mode 100644 index 0000000000..aaa051d050 --- /dev/null +++ b/unstructured/ingest/v2/cli/cmds/weaviate.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass + +import click + +from unstructured.ingest.v2.cli.base import DestCmd +from unstructured.ingest.v2.cli.interfaces import CliConfig +from unstructured.ingest.v2.cli.utils import DelimitedString +from unstructured.ingest.v2.processes.connectors.weaviate import CONNECTOR_TYPE + + +@dataclass +class WeaviateCliConnectionConfig(CliConfig): + @staticmethod + def get_cli_options() -> list[click.Option]: + options = [ + click.Option( + ["--host-url"], + required=True, + help="Weaviate instance url", + ), + click.Option( + ["--class-name"], + default=None, + type=str, + help="Name of the class to push the records into, e.g: Pdf-elements", + ), + click.Option( + ["--access-token"], default=None, type=str, help="Used to create the bearer token." + ), + click.Option( + ["--refresh-token"], + default=None, + type=str, + help="Will tie this value to the bearer token. If not provided, " + "the authentication will expire once the lifetime of the access token is up.", + ), + click.Option( + ["--api-key"], + default=None, + type=str, + ), + click.Option( + ["--client-secret"], + default=None, + type=str, + ), + click.Option( + ["--scope"], + default=None, + type=DelimitedString(), + ), + click.Option( + ["--username"], + default=None, + type=str, + ), + click.Option( + ["--password"], + default=None, + type=str, + ), + click.Option( + ["--anonymous"], + is_flag=True, + default=False, + type=bool, + help="if set, all auth values will be ignored", + ), + ] + return options + + +@dataclass +class WeaviateCliUploaderConfig(CliConfig): + @staticmethod + def get_cli_options() -> list[click.Option]: + options = [ + click.Option( + ["--batch-size"], + default=100, + type=int, + help="Number of records per batch", + ) + ] + return options + + +@dataclass +class WeaviateCliUploadStagerConfig(CliConfig): + @staticmethod + def get_cli_options() -> list[click.Option]: + return [] + + +weaviate_dest_cmd = DestCmd( + cmd_name=CONNECTOR_TYPE, + connection_config=WeaviateCliConnectionConfig, + uploader_config=WeaviateCliUploaderConfig, + upload_stager_config=WeaviateCliUploadStagerConfig, +) diff --git a/unstructured/ingest/v2/cli/interfaces.py b/unstructured/ingest/v2/cli/interfaces.py index 559590e11d..2a8a0e18ba 100644 --- a/unstructured/ingest/v2/cli/interfaces.py +++ b/unstructured/ingest/v2/cli/interfaces.py @@ -19,7 +19,6 @@ def add_params(cmd: click.Command, params: list[click.Parameter]): existing_opts = [] for param in cmd.params: existing_opts.extend(param.opts) - for param in params: for opt in param.opts: if opt in existing_opts: diff --git a/unstructured/ingest/v2/interfaces/connector.py b/unstructured/ingest/v2/interfaces/connector.py index f71f0ca2a2..dc700fc946 100644 --- a/unstructured/ingest/v2/interfaces/connector.py +++ b/unstructured/ingest/v2/interfaces/connector.py @@ -1,8 +1,8 @@ from abc import ABC from dataclasses import dataclass -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar -from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, enhanced_field +from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin @dataclass @@ -16,7 +16,7 @@ class AccessConfig(EnhancedDataClassJsonMixin): @dataclass class ConnectionConfig(EnhancedDataClassJsonMixin): - access_config: Optional[AccessConfigT] = enhanced_field(sensitive=True, default=None) + access_config: AccessConfigT def get_access_config(self) -> dict[str, Any]: if not self.access_config: @@ -29,4 +29,4 @@ def get_access_config(self) -> dict[str, Any]: @dataclass class BaseConnector(ABC): - connection_config: Optional[ConnectionConfigT] = None + connection_config: ConnectionConfigT diff --git a/unstructured/ingest/v2/interfaces/downloader.py b/unstructured/ingest/v2/interfaces/downloader.py index aee4bc47e3..a2c1ce8054 100644 --- a/unstructured/ingest/v2/interfaces/downloader.py +++ b/unstructured/ingest/v2/interfaces/downloader.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any, Optional, TypeVar @@ -19,7 +19,7 @@ class DownloaderConfig(EnhancedDataClassJsonMixin): class Downloader(BaseProcess, BaseConnector, ABC): connector_type: str - download_config: Optional[DownloaderConfigT] = field(default_factory=DownloaderConfig) + download_config: DownloaderConfigT @property def download_dir(self) -> Path: diff --git a/unstructured/ingest/v2/interfaces/upload_stager.py b/unstructured/ingest/v2/interfaces/upload_stager.py index e89ba331d3..39e28355ac 100644 --- a/unstructured/ingest/v2/interfaces/upload_stager.py +++ b/unstructured/ingest/v2/interfaces/upload_stager.py @@ -21,8 +21,28 @@ class UploadStager(BaseProcess, ABC): upload_stager_config: Optional[UploadStagerConfigT] = None @abstractmethod - def run(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path: + def run( + self, + elements_filepath: Path, + file_data: FileData, + output_dir: Path, + output_filename: str, + **kwargs: Any + ) -> Path: pass - async def run_async(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path: - return self.run(elements_filepath=elements_filepath, file_data=file_data, **kwargs) + async def run_async( + self, + elements_filepath: Path, + file_data: FileData, + output_dir: Path, + output_filename: str, + **kwargs: Any + ) -> Path: + return self.run( + elements_filepath=elements_filepath, + output_dir=output_dir, + output_filename=output_filename, + file_data=file_data, + **kwargs + ) diff --git a/unstructured/ingest/v2/interfaces/uploader.py b/unstructured/ingest/v2/interfaces/uploader.py index 03763e299b..520628e5a2 100644 --- a/unstructured/ingest/v2/interfaces/uploader.py +++ b/unstructured/ingest/v2/interfaces/uploader.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any, TypeVar @@ -25,7 +25,7 @@ class UploadContent: @dataclass class Uploader(BaseProcess, BaseConnector, ABC): - upload_config: UploaderConfigT = field(default_factory=UploaderConfig) + upload_config: UploaderConfigT def is_async(self) -> bool: return False diff --git a/unstructured/ingest/v2/pipeline/steps/stage.py b/unstructured/ingest/v2/pipeline/steps/stage.py index e7a3644de5..59bbe90c16 100644 --- a/unstructured/ingest/v2/pipeline/steps/stage.py +++ b/unstructured/ingest/v2/pipeline/steps/stage.py @@ -1,6 +1,8 @@ +import hashlib +import json from dataclasses import dataclass from pathlib import Path -from typing import TypedDict +from typing import Optional, TypedDict from unstructured.ingest.v2.interfaces.file_data import FileData from unstructured.ingest.v2.interfaces.upload_stager import UploadStager @@ -30,12 +32,16 @@ def __post_init__(self): if self.process.upload_stager_config else None ) + self.cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Created {self.identifier} with configs: {config}") def _run(self, path: str, file_data_path: str) -> UploadStageStepResponse: path = Path(path) staged_output_path = self.process.run( - elements_filepath=path, file_data=FileData.from_file(path=file_data_path) + elements_filepath=path, + file_data=FileData.from_file(path=file_data_path), + output_dir=self.cache_dir, + output_filename=self.get_hash(extras=[path.name]), ) return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path)) @@ -44,10 +50,24 @@ async def _run_async(self, path: str, file_data_path: str) -> UploadStageStepRes if semaphore := self.context.semaphore: async with semaphore: staged_output_path = await self.process.run_async( - elements_filepath=path, file_data=FileData.from_file(path=file_data_path) + elements_filepath=path, + file_data=FileData.from_file(path=file_data_path), + output_dir=self.cache_dir, + output_filename=self.get_hash(extras=[path.name]), ) else: staged_output_path = await self.process.run_async( - elements_filepath=path, file_data=FileData.from_file(path=file_data_path) + elements_filepath=path, + file_data=FileData.from_file(path=file_data_path), + output_dir=self.cache_dir, + output_filename=self.get_hash(extras=[path.name]), ) return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path)) + + def get_hash(self, extras: Optional[list[str]]) -> str: + hashable_string = json.dumps( + self.process.upload_stager_config.to_dict(), sort_keys=True, ensure_ascii=True + ) + if extras: + hashable_string += "".join(extras) + return hashlib.sha256(hashable_string.encode()).hexdigest()[:12] diff --git a/unstructured/ingest/v2/processes/connectors/local.py b/unstructured/ingest/v2/processes/connectors/local.py index 00e7a4ab84..5cfeae7ef7 100644 --- a/unstructured/ingest/v2/processes/connectors/local.py +++ b/unstructured/ingest/v2/processes/connectors/local.py @@ -8,6 +8,8 @@ from unstructured.documents.elements import DataSourceMetadata from unstructured.ingest.v2.interfaces import ( + AccessConfig, + ConnectionConfig, Downloader, DownloaderConfig, FileData, @@ -29,6 +31,16 @@ CONNECTOR_TYPE = "local" +@dataclass +class LocalAccessConfig(AccessConfig): + pass + + +@dataclass +class LocalConnectionConfig(ConnectionConfig): + access_config: LocalAccessConfig = field(default_factory=lambda: LocalAccessConfig()) + + @dataclass class LocalIndexerConfig(IndexerConfig): input_path: str @@ -43,6 +55,9 @@ def path(self) -> Path: @dataclass class LocalIndexer(Indexer): index_config: LocalIndexerConfig + connection_config: LocalConnectionConfig = field( + default_factory=lambda: LocalConnectionConfig() + ) connector_type: str = CONNECTOR_TYPE def list_files(self) -> list[Path]: @@ -115,7 +130,10 @@ class LocalDownloaderConfig(DownloaderConfig): @dataclass class LocalDownloader(Downloader): connector_type: str = CONNECTOR_TYPE - download_config: Optional[LocalDownloaderConfig] = None + connection_config: LocalConnectionConfig = field( + default_factory=lambda: LocalConnectionConfig() + ) + download_config: LocalDownloaderConfig = field(default_factory=lambda: LocalDownloaderConfig()) def get_download_path(self, file_data: FileData) -> Path: return Path(file_data.source_identifiers.fullpath) @@ -139,7 +157,10 @@ def __post_init__(self): @dataclass class LocalUploader(Uploader): - upload_config: LocalUploaderConfig = field(default_factory=LocalUploaderConfig) + upload_config: LocalUploaderConfig = field(default_factory=lambda: LocalUploaderConfig()) + connection_config: LocalConnectionConfig = field( + default_factory=lambda: LocalConnectionConfig() + ) def is_async(self) -> bool: return False diff --git a/unstructured/ingest/v2/processes/connectors/weaviate.py b/unstructured/ingest/v2/processes/connectors/weaviate.py new file mode 100644 index 0000000000..c273df4ef8 --- /dev/null +++ b/unstructured/ingest/v2/processes/connectors/weaviate.py @@ -0,0 +1,236 @@ +import json +from dataclasses import dataclass, field +from datetime import date, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +from dateutil import parser + +from unstructured.ingest.enhanced_dataclass import enhanced_field +from unstructured.ingest.v2.interfaces import ( + AccessConfig, + ConnectionConfig, + FileData, + UploadContent, + Uploader, + UploaderConfig, + UploadStager, + UploadStagerConfig, +) +from unstructured.ingest.v2.logger import logger +from unstructured.ingest.v2.processes.connector_registry import ( + DestinationRegistryEntry, + add_destination_entry, +) + +if TYPE_CHECKING: + from weaviate import Client + +CONNECTOR_TYPE = "weaviate" + + +@dataclass +class WeaviateAccessConfig(AccessConfig): + access_token: Optional[str] + api_key: Optional[str] + client_secret: Optional[str] + password: Optional[str] + + +@dataclass +class WeaviateConnectionConfig(ConnectionConfig): + host_url: str + class_name: str + access_config: WeaviateAccessConfig = enhanced_field(sensitive=True) + username: Optional[str] = None + anonymous: bool = False + scope: Optional[list[str]] = None + refresh_token: Optional[str] = None + connector_type: str = CONNECTOR_TYPE + + +@dataclass +class WeaviateUploadStagerConfig(UploadStagerConfig): + pass + + +@dataclass +class WeaviateUploadStager(UploadStager): + upload_stager_config: WeaviateUploadStagerConfig = field( + default_factory=lambda: WeaviateUploadStagerConfig() + ) + + @staticmethod + def parse_date_string(date_string: str) -> date: + try: + timestamp = float(date_string) + return datetime.fromtimestamp(timestamp) + except Exception as e: + logger.debug(f"date {date_string} string not a timestamp: {e}") + return parser.parse(date_string) + + @classmethod + def conform_dict(cls, data: dict) -> None: + """ + Updates the element dictionary to conform to the Weaviate schema + """ + + # Dict as string formatting + if record_locator := data.get("metadata", {}).get("data_source", {}).get("record_locator"): + # Explicit casting otherwise fails schema type checking + data["metadata"]["data_source"]["record_locator"] = str(json.dumps(record_locator)) + + # Array of items as string formatting + if points := data.get("metadata", {}).get("coordinates", {}).get("points"): + data["metadata"]["coordinates"]["points"] = str(json.dumps(points)) + + if links := data.get("metadata", {}).get("links", {}): + data["metadata"]["links"] = str(json.dumps(links)) + + if permissions_data := ( + data.get("metadata", {}).get("data_source", {}).get("permissions_data") + ): + data["metadata"]["data_source"]["permissions_data"] = json.dumps(permissions_data) + + # Datetime formatting + if date_created := data.get("metadata", {}).get("data_source", {}).get("date_created"): + data["metadata"]["data_source"]["date_created"] = cls.parse_date_string( + date_created + ).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + if date_modified := data.get("metadata", {}).get("data_source", {}).get("date_modified"): + data["metadata"]["data_source"]["date_modified"] = cls.parse_date_string( + date_modified + ).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + if date_processed := data.get("metadata", {}).get("data_source", {}).get("date_processed"): + data["metadata"]["data_source"]["date_processed"] = cls.parse_date_string( + date_processed + ).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + if last_modified := data.get("metadata", {}).get("last_modified"): + data["metadata"]["last_modified"] = cls.parse_date_string(last_modified).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + # String casting + if version := data.get("metadata", {}).get("data_source", {}).get("version"): + data["metadata"]["data_source"]["version"] = str(version) + + if page_number := data.get("metadata", {}).get("page_number"): + data["metadata"]["page_number"] = str(page_number) + + if regex_metadata := data.get("metadata", {}).get("regex_metadata"): + data["metadata"]["regex_metadata"] = str(json.dumps(regex_metadata)) + + def run( + self, + elements_filepath: Path, + file_data: FileData, + output_dir: Path, + output_filename: str, + **kwargs: Any, + ) -> Path: + with open(elements_filepath) as elements_file: + elements_contents = json.load(elements_file) + for element in elements_contents: + self.conform_dict(data=element) + output_path = Path(output_dir) / Path(f"{output_filename}.json") + with open(output_path, "w") as output_file: + json.dump(elements_contents, output_file) + return output_path + + +@dataclass +class WeaviateUploaderConfig(UploaderConfig): + batch_size: int = 100 + + +@dataclass +class WeaviateUploader(Uploader): + upload_config: WeaviateUploaderConfig + connection_config: WeaviateConnectionConfig + client: Optional["Client"] = field(init=False) + + def __post_init__(self): + from weaviate import Client + + auth = self._resolve_auth_method() + self.client = Client(url=self.connection_config.host_url, auth_client_secret=auth) + + def is_async(self) -> bool: + return True + + def _resolve_auth_method(self): + access_configs = self.connection_config.access_config + connection_config = self.connection_config + if connection_config.anonymous: + return None + + if access_configs.access_token: + from weaviate.auth import AuthBearerToken + + return AuthBearerToken( + access_token=access_configs.access_token, + refresh_token=connection_config.refresh_token, + ) + elif access_configs.api_key: + from weaviate.auth import AuthApiKey + + return AuthApiKey(api_key=access_configs.api_key) + elif access_configs.client_secret: + from weaviate.auth import AuthClientCredentials + + return AuthClientCredentials( + client_secret=access_configs.client_secret, scope=connection_config.scope + ) + elif connection_config.username and access_configs.password: + from weaviate.auth import AuthClientPassword + + return AuthClientPassword( + username=connection_config.username, + password=access_configs.password, + scope=connection_config.scope, + ) + return None + + def run(self, contents: list[UploadContent], **kwargs: Any) -> None: + raise NotImplementedError + + async def run_async(self, path: Path, file_data: FileData, **kwargs: Any) -> None: + with open(path) as elements_file: + elements_dict = json.load(elements_file) + + logger.info( + f"writing {len(elements_dict)} objects to destination " + f"class {self.connection_config.class_name} " + f"at {self.connection_config.host_url}", + ) + + self.client.batch.configure(batch_size=self.upload_config.batch_size) + with self.client.batch as b: + for e in elements_dict: + vector = e.pop("embeddings", None) + b.add_data_object( + e, + self.connection_config.class_name, + vector=vector, + ) + + +add_destination_entry( + destination_type=CONNECTOR_TYPE, + entry=DestinationRegistryEntry( + connection_config=WeaviateConnectionConfig, + uploader=WeaviateUploader, + uploader_config=WeaviateUploaderConfig, + upload_stager=WeaviateUploadStager, + upload_stager_config=WeaviateUploadStagerConfig, + ), +)