Skip to content

Commit

Permalink
migrate weaviate connector to new framework
Browse files Browse the repository at this point in the history
  • Loading branch information
rbiseck3 committed Jun 6, 2024
1 parent f1cab24 commit a0134e4
Show file tree
Hide file tree
Showing 11 changed files with 418 additions and 20 deletions.
4 changes: 2 additions & 2 deletions unstructured/ingest/v2/cli/base/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_pipline(
f"setting destination on pipeline {dest} with options: {destination_options}"
)
if uploader_stager := self.get_upload_stager(dest=dest, options=destination_options):
pipeline_kwargs["upload_stager"] = uploader_stager
pipeline_kwargs["stager"] = uploader_stager
pipeline_kwargs["uploader"] = self.get_uploader(dest=dest, options=destination_options)
else:
# Default to local uploader
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_upload_stager(dest: str, options: dict[str, Any]) -> Optional[UploadStag
dest_entry = destination_registry[dest]
upload_stager_kwargs: dict[str, Any] = {}
if upload_stager_config_cls := dest_entry.upload_stager_config:
upload_stager_kwargs["config"] = extract_config(
upload_stager_kwargs["upload_stager_config"] = extract_config(
flat_data=options, config=upload_stager_config_cls
)
if upload_stager_cls := dest_entry.upload_stager:
Expand Down
2 changes: 2 additions & 0 deletions unstructured/ingest/v2/cli/cmds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .fsspec.s3 import s3_dest_cmd, s3_src_cmd
from .fsspec.sftp import sftp_dest_cmd, sftp_src_cmd
from .local import local_dest_cmd, local_src_cmd
from .weaviate import weaviate_dest_cmd

src_cmds = [
azure_src_cmd,
Expand Down Expand Up @@ -37,6 +38,7 @@
local_dest_cmd,
s3_dest_cmd,
sftp_dest_cmd,
weaviate_dest_cmd,
]

duplicate_dest_names = [
Expand Down
100 changes: 100 additions & 0 deletions unstructured/ingest/v2/cli/cmds/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from dataclasses import dataclass

import click

from unstructured.ingest.v2.cli.base import DestCmd
from unstructured.ingest.v2.cli.interfaces import CliConfig
from unstructured.ingest.v2.cli.utils import DelimitedString
from unstructured.ingest.v2.processes.connectors.weaviate import CONNECTOR_TYPE


@dataclass
class WeaviateCliConnectionConfig(CliConfig):
@staticmethod
def get_cli_options() -> list[click.Option]:
options = [
click.Option(
["--host-url"],
required=True,
help="Weaviate instance url",
),
click.Option(
["--class-name"],
default=None,
type=str,
help="Name of the class to push the records into, e.g: Pdf-elements",
),
click.Option(
["--access-token"], default=None, type=str, help="Used to create the bearer token."
),
click.Option(
["--refresh-token"],
default=None,
type=str,
help="Will tie this value to the bearer token. If not provided, "
"the authentication will expire once the lifetime of the access token is up.",
),
click.Option(
["--api-key"],
default=None,
type=str,
),
click.Option(
["--client-secret"],
default=None,
type=str,
),
click.Option(
["--scope"],
default=None,
type=DelimitedString(),
),
click.Option(
["--username"],
default=None,
type=str,
),
click.Option(
["--password"],
default=None,
type=str,
),
click.Option(
["--anonymous"],
is_flag=True,
default=False,
type=bool,
help="if set, all auth values will be ignored",
),
]
return options


@dataclass
class WeaviateCliUploaderConfig(CliConfig):
@staticmethod
def get_cli_options() -> list[click.Option]:
options = [
click.Option(
["--batch-size"],
default=100,
type=int,
help="Number of records per batch",
)
]
return options


@dataclass
class WeaviateCliUploadStagerConfig(CliConfig):
@staticmethod
def get_cli_options() -> list[click.Option]:
return []


weaviate_dest_cmd = DestCmd(
cmd_name=CONNECTOR_TYPE,
connection_config=WeaviateCliConnectionConfig,
uploader_config=WeaviateCliUploaderConfig,
upload_stager_config=WeaviateCliUploadStagerConfig,
)
1 change: 0 additions & 1 deletion unstructured/ingest/v2/cli/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def add_params(cmd: click.Command, params: list[click.Parameter]):
existing_opts = []
for param in cmd.params:
existing_opts.extend(param.opts)

for param in params:
for opt in param.opts:
if opt in existing_opts:
Expand Down
8 changes: 4 additions & 4 deletions unstructured/ingest/v2/interfaces/connector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any, Optional, TypeVar
from typing import Any, TypeVar

from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, enhanced_field
from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin


@dataclass
Expand All @@ -16,7 +16,7 @@ class AccessConfig(EnhancedDataClassJsonMixin):

@dataclass
class ConnectionConfig(EnhancedDataClassJsonMixin):
access_config: Optional[AccessConfigT] = enhanced_field(sensitive=True, default=None)
access_config: AccessConfigT

def get_access_config(self) -> dict[str, Any]:
if not self.access_config:
Expand All @@ -29,4 +29,4 @@ def get_access_config(self) -> dict[str, Any]:

@dataclass
class BaseConnector(ABC):
connection_config: Optional[ConnectionConfigT] = None
connection_config: ConnectionConfigT
4 changes: 2 additions & 2 deletions unstructured/ingest/v2/interfaces/downloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, TypeVar

Expand All @@ -19,7 +19,7 @@ class DownloaderConfig(EnhancedDataClassJsonMixin):

class Downloader(BaseProcess, BaseConnector, ABC):
connector_type: str
download_config: Optional[DownloaderConfigT] = field(default_factory=DownloaderConfig)
download_config: DownloaderConfigT

@property
def download_dir(self) -> Path:
Expand Down
26 changes: 23 additions & 3 deletions unstructured/ingest/v2/interfaces/upload_stager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,28 @@ class UploadStager(BaseProcess, ABC):
upload_stager_config: Optional[UploadStagerConfigT] = None

@abstractmethod
def run(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path:
def run(
self,
elements_filepath: Path,
file_data: FileData,
output_dir: Path,
output_filename: str,
**kwargs: Any
) -> Path:
pass

async def run_async(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path:
return self.run(elements_filepath=elements_filepath, file_data=file_data, **kwargs)
async def run_async(
self,
elements_filepath: Path,
file_data: FileData,
output_dir: Path,
output_filename: str,
**kwargs: Any
) -> Path:
return self.run(
elements_filepath=elements_filepath,
output_dir=output_dir,
output_filename=output_filename,
file_data=file_data,
**kwargs
)
4 changes: 2 additions & 2 deletions unstructured/ingest/v2/interfaces/uploader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypeVar

Expand All @@ -25,7 +25,7 @@ class UploadContent:

@dataclass
class Uploader(BaseProcess, BaseConnector, ABC):
upload_config: UploaderConfigT = field(default_factory=UploaderConfig)
upload_config: UploaderConfigT

def is_async(self) -> bool:
return False
Expand Down
28 changes: 24 additions & 4 deletions unstructured/ingest/v2/pipeline/steps/stage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import TypedDict
from typing import Optional, TypedDict

from unstructured.ingest.v2.interfaces.file_data import FileData
from unstructured.ingest.v2.interfaces.upload_stager import UploadStager
Expand Down Expand Up @@ -30,12 +32,16 @@ def __post_init__(self):
if self.process.upload_stager_config
else None
)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Created {self.identifier} with configs: {config}")

def _run(self, path: str, file_data_path: str) -> UploadStageStepResponse:
path = Path(path)
staged_output_path = self.process.run(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))

Expand All @@ -44,10 +50,24 @@ async def _run_async(self, path: str, file_data_path: str) -> UploadStageStepRes
if semaphore := self.context.semaphore:
async with semaphore:
staged_output_path = await self.process.run_async(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
else:
staged_output_path = await self.process.run_async(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))

def get_hash(self, extras: Optional[list[str]]) -> str:
hashable_string = json.dumps(
self.process.upload_stager_config.to_dict(), sort_keys=True, ensure_ascii=True
)
if extras:
hashable_string += "".join(extras)
return hashlib.sha256(hashable_string.encode()).hexdigest()[:12]
25 changes: 23 additions & 2 deletions unstructured/ingest/v2/processes/connectors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from unstructured.documents.elements import DataSourceMetadata
from unstructured.ingest.v2.interfaces import (
AccessConfig,
ConnectionConfig,
Downloader,
DownloaderConfig,
FileData,
Expand All @@ -29,6 +31,16 @@
CONNECTOR_TYPE = "local"


@dataclass
class LocalAccessConfig(AccessConfig):
pass


@dataclass
class LocalConnectionConfig(ConnectionConfig):
access_config: LocalAccessConfig = field(default_factory=lambda: LocalAccessConfig())


@dataclass
class LocalIndexerConfig(IndexerConfig):
input_path: str
Expand All @@ -43,6 +55,9 @@ def path(self) -> Path:
@dataclass
class LocalIndexer(Indexer):
index_config: LocalIndexerConfig
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)
connector_type: str = CONNECTOR_TYPE

def list_files(self) -> list[Path]:
Expand Down Expand Up @@ -115,7 +130,10 @@ class LocalDownloaderConfig(DownloaderConfig):
@dataclass
class LocalDownloader(Downloader):
connector_type: str = CONNECTOR_TYPE
download_config: Optional[LocalDownloaderConfig] = None
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)
download_config: LocalDownloaderConfig = field(default_factory=lambda: LocalDownloaderConfig())

def get_download_path(self, file_data: FileData) -> Path:
return Path(file_data.source_identifiers.fullpath)
Expand All @@ -139,7 +157,10 @@ def __post_init__(self):

@dataclass
class LocalUploader(Uploader):
upload_config: LocalUploaderConfig = field(default_factory=LocalUploaderConfig)
upload_config: LocalUploaderConfig = field(default_factory=lambda: LocalUploaderConfig())
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)

def is_async(self) -> bool:
return False
Expand Down
Loading

0 comments on commit a0134e4

Please sign in to comment.