Skip to content

Commit

Permalink
feat: migrate weaviate connector to new framework (#3160)
Browse files Browse the repository at this point in the history
### Description
Add weaviate output connector to those supported in the new v2 ingest
framework. Some fixes were needed to the upoad stager step as this was
the first connector moved over that leverages this part of the pipeline.
  • Loading branch information
rbiseck3 authored Jun 6, 2024
1 parent a883fc9 commit 0fe0f15
Show file tree
Hide file tree
Showing 13 changed files with 420 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## 0.14.5-dev6
## 0.14.5-dev7

### Enhancements

Expand Down
2 changes: 1 addition & 1 deletion unstructured/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.14.5-dev6" # pragma: no cover
__version__ = "0.14.5-dev7" # pragma: no cover
4 changes: 2 additions & 2 deletions unstructured/ingest/v2/cli/base/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_pipline(
f"setting destination on pipeline {dest} with options: {destination_options}"
)
if uploader_stager := self.get_upload_stager(dest=dest, options=destination_options):
pipeline_kwargs["upload_stager"] = uploader_stager
pipeline_kwargs["stager"] = uploader_stager
pipeline_kwargs["uploader"] = self.get_uploader(dest=dest, options=destination_options)
else:
# Default to local uploader
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_upload_stager(dest: str, options: dict[str, Any]) -> Optional[UploadStag
dest_entry = destination_registry[dest]
upload_stager_kwargs: dict[str, Any] = {}
if upload_stager_config_cls := dest_entry.upload_stager_config:
upload_stager_kwargs["config"] = extract_config(
upload_stager_kwargs["upload_stager_config"] = extract_config(
flat_data=options, config=upload_stager_config_cls
)
if upload_stager_cls := dest_entry.upload_stager:
Expand Down
2 changes: 2 additions & 0 deletions unstructured/ingest/v2/cli/cmds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .fsspec.s3 import s3_dest_cmd, s3_src_cmd
from .fsspec.sftp import sftp_dest_cmd, sftp_src_cmd
from .local import local_dest_cmd, local_src_cmd
from .weaviate import weaviate_dest_cmd

src_cmds = [
azure_src_cmd,
Expand Down Expand Up @@ -37,6 +38,7 @@
local_dest_cmd,
s3_dest_cmd,
sftp_dest_cmd,
weaviate_dest_cmd,
]

duplicate_dest_names = [
Expand Down
100 changes: 100 additions & 0 deletions unstructured/ingest/v2/cli/cmds/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from dataclasses import dataclass

import click

from unstructured.ingest.v2.cli.base import DestCmd
from unstructured.ingest.v2.cli.interfaces import CliConfig
from unstructured.ingest.v2.cli.utils import DelimitedString
from unstructured.ingest.v2.processes.connectors.weaviate import CONNECTOR_TYPE


@dataclass
class WeaviateCliConnectionConfig(CliConfig):
@staticmethod
def get_cli_options() -> list[click.Option]:
options = [
click.Option(
["--host-url"],
required=True,
help="Weaviate instance url",
),
click.Option(
["--class-name"],
default=None,
type=str,
help="Name of the class to push the records into, e.g: Pdf-elements",
),
click.Option(
["--access-token"], default=None, type=str, help="Used to create the bearer token."
),
click.Option(
["--refresh-token"],
default=None,
type=str,
help="Will tie this value to the bearer token. If not provided, "
"the authentication will expire once the lifetime of the access token is up.",
),
click.Option(
["--api-key"],
default=None,
type=str,
),
click.Option(
["--client-secret"],
default=None,
type=str,
),
click.Option(
["--scope"],
default=None,
type=DelimitedString(),
),
click.Option(
["--username"],
default=None,
type=str,
),
click.Option(
["--password"],
default=None,
type=str,
),
click.Option(
["--anonymous"],
is_flag=True,
default=False,
type=bool,
help="if set, all auth values will be ignored",
),
]
return options


@dataclass
class WeaviateCliUploaderConfig(CliConfig):
@staticmethod
def get_cli_options() -> list[click.Option]:
options = [
click.Option(
["--batch-size"],
default=100,
type=int,
help="Number of records per batch",
)
]
return options


@dataclass
class WeaviateCliUploadStagerConfig(CliConfig):
@staticmethod
def get_cli_options() -> list[click.Option]:
return []


weaviate_dest_cmd = DestCmd(
cmd_name=CONNECTOR_TYPE,
connection_config=WeaviateCliConnectionConfig,
uploader_config=WeaviateCliUploaderConfig,
upload_stager_config=WeaviateCliUploadStagerConfig,
)
1 change: 0 additions & 1 deletion unstructured/ingest/v2/cli/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def add_params(cmd: click.Command, params: list[click.Parameter]):
existing_opts = []
for param in cmd.params:
existing_opts.extend(param.opts)

for param in params:
for opt in param.opts:
if opt in existing_opts:
Expand Down
8 changes: 4 additions & 4 deletions unstructured/ingest/v2/interfaces/connector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any, Optional, TypeVar
from typing import Any, TypeVar

from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, enhanced_field
from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin


@dataclass
Expand All @@ -16,7 +16,7 @@ class AccessConfig(EnhancedDataClassJsonMixin):

@dataclass
class ConnectionConfig(EnhancedDataClassJsonMixin):
access_config: Optional[AccessConfigT] = enhanced_field(sensitive=True, default=None)
access_config: AccessConfigT

def get_access_config(self) -> dict[str, Any]:
if not self.access_config:
Expand All @@ -29,4 +29,4 @@ def get_access_config(self) -> dict[str, Any]:

@dataclass
class BaseConnector(ABC):
connection_config: Optional[ConnectionConfigT] = None
connection_config: ConnectionConfigT
4 changes: 2 additions & 2 deletions unstructured/ingest/v2/interfaces/downloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, TypeVar

Expand All @@ -19,7 +19,7 @@ class DownloaderConfig(EnhancedDataClassJsonMixin):

class Downloader(BaseProcess, BaseConnector, ABC):
connector_type: str
download_config: Optional[DownloaderConfigT] = field(default_factory=DownloaderConfig)
download_config: DownloaderConfigT

@property
def download_dir(self) -> Path:
Expand Down
26 changes: 23 additions & 3 deletions unstructured/ingest/v2/interfaces/upload_stager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,28 @@ class UploadStager(BaseProcess, ABC):
upload_stager_config: Optional[UploadStagerConfigT] = None

@abstractmethod
def run(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path:
def run(
self,
elements_filepath: Path,
file_data: FileData,
output_dir: Path,
output_filename: str,
**kwargs: Any
) -> Path:
pass

async def run_async(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path:
return self.run(elements_filepath=elements_filepath, file_data=file_data, **kwargs)
async def run_async(
self,
elements_filepath: Path,
file_data: FileData,
output_dir: Path,
output_filename: str,
**kwargs: Any
) -> Path:
return self.run(
elements_filepath=elements_filepath,
output_dir=output_dir,
output_filename=output_filename,
file_data=file_data,
**kwargs
)
4 changes: 2 additions & 2 deletions unstructured/ingest/v2/interfaces/uploader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypeVar

Expand All @@ -25,7 +25,7 @@ class UploadContent:

@dataclass
class Uploader(BaseProcess, BaseConnector, ABC):
upload_config: UploaderConfigT = field(default_factory=UploaderConfig)
upload_config: UploaderConfigT

def is_async(self) -> bool:
return False
Expand Down
28 changes: 24 additions & 4 deletions unstructured/ingest/v2/pipeline/steps/stage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import TypedDict
from typing import Optional, TypedDict

from unstructured.ingest.v2.interfaces.file_data import FileData
from unstructured.ingest.v2.interfaces.upload_stager import UploadStager
Expand Down Expand Up @@ -30,12 +32,16 @@ def __post_init__(self):
if self.process.upload_stager_config
else None
)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Created {self.identifier} with configs: {config}")

def _run(self, path: str, file_data_path: str) -> UploadStageStepResponse:
path = Path(path)
staged_output_path = self.process.run(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))

Expand All @@ -44,10 +50,24 @@ async def _run_async(self, path: str, file_data_path: str) -> UploadStageStepRes
if semaphore := self.context.semaphore:
async with semaphore:
staged_output_path = await self.process.run_async(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
else:
staged_output_path = await self.process.run_async(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))

def get_hash(self, extras: Optional[list[str]]) -> str:
hashable_string = json.dumps(
self.process.upload_stager_config.to_dict(), sort_keys=True, ensure_ascii=True
)
if extras:
hashable_string += "".join(extras)
return hashlib.sha256(hashable_string.encode()).hexdigest()[:12]
25 changes: 23 additions & 2 deletions unstructured/ingest/v2/processes/connectors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from unstructured.documents.elements import DataSourceMetadata
from unstructured.ingest.v2.interfaces import (
AccessConfig,
ConnectionConfig,
Downloader,
DownloaderConfig,
FileData,
Expand All @@ -29,6 +31,16 @@
CONNECTOR_TYPE = "local"


@dataclass
class LocalAccessConfig(AccessConfig):
pass


@dataclass
class LocalConnectionConfig(ConnectionConfig):
access_config: LocalAccessConfig = field(default_factory=lambda: LocalAccessConfig())


@dataclass
class LocalIndexerConfig(IndexerConfig):
input_path: str
Expand All @@ -43,6 +55,9 @@ def path(self) -> Path:
@dataclass
class LocalIndexer(Indexer):
index_config: LocalIndexerConfig
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)
connector_type: str = CONNECTOR_TYPE

def list_files(self) -> list[Path]:
Expand Down Expand Up @@ -115,7 +130,10 @@ class LocalDownloaderConfig(DownloaderConfig):
@dataclass
class LocalDownloader(Downloader):
connector_type: str = CONNECTOR_TYPE
download_config: Optional[LocalDownloaderConfig] = None
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)
download_config: LocalDownloaderConfig = field(default_factory=lambda: LocalDownloaderConfig())

def get_download_path(self, file_data: FileData) -> Path:
return Path(file_data.source_identifiers.fullpath)
Expand All @@ -139,7 +157,10 @@ def __post_init__(self):

@dataclass
class LocalUploader(Uploader):
upload_config: LocalUploaderConfig = field(default_factory=LocalUploaderConfig)
upload_config: LocalUploaderConfig = field(default_factory=lambda: LocalUploaderConfig())
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)

def is_async(self) -> bool:
return False
Expand Down
Loading

0 comments on commit 0fe0f15

Please sign in to comment.