diff --git a/.github/workflows/test_common.yml b/.github/workflows/test_common.yml index 6b79060f07..6efa7ffc4c 100644 --- a/.github/workflows/test_common.yml +++ b/.github/workflows/test_common.yml @@ -15,6 +15,10 @@ env: RUNTIME__LOG_LEVEL: ERROR RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + # we need the secrets only for the rest_api_pipeline tests which are in tests/sources + # so we inject them only at the end + DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + jobs: get_docs_changes: name: docs changes @@ -87,11 +91,11 @@ jobs: run: poetry install --no-interaction --with sentry-sdk - run: | - poetry run pytest tests/common tests/normalize tests/reflection tests/sources tests/load/test_dummy_client.py tests/extract/test_extract.py tests/extract/test_sources.py tests/pipeline/test_pipeline_state.py + poetry run pytest tests/common tests/normalize tests/reflection tests/load/test_dummy_client.py tests/extract/test_extract.py tests/extract/test_sources.py tests/pipeline/test_pipeline_state.py if: runner.os != 'Windows' name: Run common tests with minimum dependencies Linux/MAC - run: | - poetry run pytest tests/common tests/normalize tests/reflection tests/sources tests/load/test_dummy_client.py tests/extract/test_extract.py tests/extract/test_sources.py tests/pipeline/test_pipeline_state.py -m "not forked" + poetry run pytest tests/common tests/normalize tests/reflection tests/load/test_dummy_client.py tests/extract/test_extract.py tests/extract/test_sources.py tests/pipeline/test_pipeline_state.py -m "not forked" if: runner.os == 'Windows' name: Run common tests with minimum dependencies Windows shell: cmd @@ -122,15 +126,36 @@ jobs: name: Run pipeline tests with pyarrow but no pandas installed Windows shell: cmd - - name: Install pipeline dependencies - run: poetry install --no-interaction -E duckdb -E cli -E parquet --with sentry-sdk --with pipeline -E deltalake + - name: create secrets.toml for examples + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + - name: Install pipeline and sources dependencies + run: poetry install --no-interaction -E duckdb -E cli -E parquet --with sentry-sdk --with pipeline -E deltalake -E sql_database + + # TODO: this is needed for the filesystem tests, not sure if this should be in an extra? 
+ - name: Install openpyxl for excel tests + run: poetry run pip install openpyxl + + - run: | + poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations tests/sources + if: runner.os != 'Windows' + name: Run extract and pipeline tests Linux/MAC + - run: | + poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations tests/sources -m "not forked" + if: runner.os == 'Windows' + name: Run extract tests Windows + shell: cmd + + # here we upgrade sql alchemy to 2 an run the sql_database tests again + - name: Upgrade sql alchemy + run: poetry run pip install sqlalchemy==2.0.32 - run: | - poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations + poetry run pytest tests/sources/sql_database if: runner.os != 'Windows' name: Run extract and pipeline tests Linux/MAC - run: | - poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations -m "not forked" + poetry run pytest tests/sources/sql_database if: runner.os == 'Windows' name: Run extract tests Windows shell: cmd diff --git a/.github/workflows/test_destination_athena.yml b/.github/workflows/test_destination_athena.yml index c7aed6f70e..70a79cd218 100644 --- a/.github/workflows/test_destination_athena.yml +++ b/.github/workflows/test_destination_athena.yml @@ -73,11 +73,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || !github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_athena_iceberg.yml b/.github/workflows/test_destination_athena_iceberg.yml index 40514ce58e..2c35a99393 100644 --- a/.github/workflows/test_destination_athena_iceberg.yml +++ b/.github/workflows/test_destination_athena_iceberg.yml @@ -73,11 +73,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! 
(contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_bigquery.yml b/.github/workflows/test_destination_bigquery.yml index b3926fb18c..e0908892b3 100644 --- a/.github/workflows/test_destination_bigquery.yml +++ b/.github/workflows/test_destination_bigquery.yml @@ -72,5 +72,5 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux diff --git a/.github/workflows/test_destination_clickhouse.yml b/.github/workflows/test_destination_clickhouse.yml index 5b6848f2fe..89e189974c 100644 --- a/.github/workflows/test_destination_clickhouse.yml +++ b/.github/workflows/test_destination_clickhouse.yml @@ -75,7 +75,7 @@ jobs: name: Start ClickHouse OSS - - run: poetry run pytest tests/load -m "essential" + - run: poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux (ClickHouse OSS) if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} env: @@ -87,7 +87,7 @@ jobs: DESTINATION__CLICKHOUSE__CREDENTIALS__HTTP_PORT: 8123 DESTINATION__CLICKHOUSE__CREDENTIALS__SECURE: 0 - - run: poetry run pytest tests/load + - run: poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux (ClickHouse OSS) if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} env: @@ -105,12 +105,12 @@ jobs: # ClickHouse Cloud - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux (ClickHouse Cloud) if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux (ClickHouse Cloud) if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_databricks.yml b/.github/workflows/test_destination_databricks.yml index 81ec575145..b3d30bcefc 100644 --- a/.github/workflows/test_destination_databricks.yml +++ b/.github/workflows/test_destination_databricks.yml @@ -70,11 +70,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! 
(contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_dremio.yml b/.github/workflows/test_destination_dremio.yml index 7ec6c4f697..b78e67dc5c 100644 --- a/.github/workflows/test_destination_dremio.yml +++ b/.github/workflows/test_destination_dremio.yml @@ -68,7 +68,7 @@ jobs: run: poetry install --no-interaction -E s3 -E gs -E az -E parquet --with sentry-sdk --with pipeline - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources if: runner.os != 'Windows' name: Run tests Linux/MAC env: @@ -80,7 +80,7 @@ jobs: DESTINATION__MINIO__CREDENTIALS__ENDPOINT_URL: http://127.0.0.1:9010 - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources if: runner.os == 'Windows' name: Run tests Windows shell: cmd diff --git a/.github/workflows/test_destination_lancedb.yml b/.github/workflows/test_destination_lancedb.yml index 02b5ef66eb..b191f79465 100644 --- a/.github/workflows/test_destination_lancedb.yml +++ b/.github/workflows/test_destination_lancedb.yml @@ -71,11 +71,11 @@ jobs: run: poetry run pip install openai - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_motherduck.yml b/.github/workflows/test_destination_motherduck.yml index a51fb3cc8f..6c81dd28f7 100644 --- a/.github/workflows/test_destination_motherduck.yml +++ b/.github/workflows/test_destination_motherduck.yml @@ -70,11 +70,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! 
(contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_mssql.yml b/.github/workflows/test_destination_mssql.yml index 3b5bfd8d42..2065568a5e 100644 --- a/.github/workflows/test_destination_mssql.yml +++ b/.github/workflows/test_destination_mssql.yml @@ -75,5 +75,5 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml # always run full suite, also on branches - - run: poetry run pytest tests/load + - run: poetry run pytest tests/load --ignore tests/load/sources name: Run tests Linux diff --git a/.github/workflows/test_destination_qdrant.yml b/.github/workflows/test_destination_qdrant.yml index 168fe315ce..e231f4dbbb 100644 --- a/.github/workflows/test_destination_qdrant.yml +++ b/.github/workflows/test_destination_qdrant.yml @@ -69,11 +69,11 @@ jobs: run: poetry install --no-interaction -E qdrant -E parquet --with sentry-sdk --with pipeline - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_snowflake.yml b/.github/workflows/test_destination_snowflake.yml index 0c9a2b08d1..a2716fb597 100644 --- a/.github/workflows/test_destination_snowflake.yml +++ b/.github/workflows/test_destination_snowflake.yml @@ -70,11 +70,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destination_synapse.yml b/.github/workflows/test_destination_synapse.yml index 4d3049853c..be1b493916 100644 --- a/.github/workflows/test_destination_synapse.yml +++ b/.github/workflows/test_destination_synapse.yml @@ -73,11 +73,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! 
(contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index 7fae69ff9e..fc7eeadfe2 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -82,11 +82,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load -m "essential" + poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | - poetry run pytest tests/load + poetry run pytest tests/load --ignore tests/load/sources name: Run all tests Linux if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 78ea23ec1c..2d712814bd 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -97,11 +97,8 @@ jobs: - name: Install dependencies run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E weaviate -E qdrant --with sentry-sdk --with pipeline -E deltalake - - name: create secrets.toml - run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - # always run full suite, also on branches - - run: poetry run pytest tests/load && poetry run pytest tests/cli + - run: poetry run pytest tests/load --ignore tests/load/sources && poetry run pytest tests/cli name: Run tests Linux env: DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data diff --git a/.github/workflows/test_local_sources.yml b/.github/workflows/test_local_sources.yml new file mode 100644 index 0000000000..0178f59322 --- /dev/null +++ b/.github/workflows/test_local_sources.yml @@ -0,0 +1,101 @@ +# Tests sources against a couple of local destinations + +name: src | rest_api, sql_database, filesystem + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + + ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\"]" + ALL_FILESYSTEM_DRIVERS: "[\"file\"]" + +jobs: + get_docs_changes: + name: docs changes + uses: ./.github/workflows/get_docs_changes.yml + + run_loader: + name: src | rest_api, sql_database, filesystem + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + strategy: + fail-fast: false + defaults: + run: + shell: bash + runs-on: "ubuntu-latest" + + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres + # Provide the password for postgres + env: + POSTGRES_DB: dlt_data + POSTGRES_USER: loader + POSTGRES_PASSWORD: loader + ports: + - 
5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.10.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-sources + + # TODO: which deps should we enable? + - name: Install dependencies + run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E sql_database --with sentry-sdk --with pipeline + + # run sources tests in load against configured destinations + - run: poetry run pytest tests/load/sources + name: Run tests Linux + env: + DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data + + # here we upgrade sql alchemy to 2 an run the sql_database tests again + - name: Upgrade sql alchemy + run: poetry run pip install sqlalchemy==2.0.32 + + - run: poetry run pytest tests/load/sources/sql_database + name: Run tests Linux + env: + DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data \ No newline at end of file diff --git a/dlt/cli/_dlt.py b/dlt/cli/_dlt.py index 7c6526c0a2..9e7b12dc53 100644 --- a/dlt/cli/_dlt.py +++ b/dlt/cli/_dlt.py @@ -16,7 +16,7 @@ from dlt.cli.init_command import ( init_command, - list_verified_sources_command, + list_sources_command, DLT_INIT_DOCS_URL, DEFAULT_VERIFIED_SOURCES_REPO, ) @@ -54,12 +54,18 @@ def on_exception(ex: Exception, info: str) -> None: def init_command_wrapper( source_name: str, destination_type: str, - use_generic_template: bool, repo_location: str, branch: str, + omit_core_sources: bool = False, ) -> int: try: - init_command(source_name, destination_type, use_generic_template, repo_location, branch) + init_command( + source_name, + destination_type, + repo_location, + branch, + omit_core_sources, + ) except Exception as ex: on_exception(ex, DLT_INIT_DOCS_URL) return -1 @@ -67,9 +73,9 @@ def init_command_wrapper( @utils.track_command("list_sources", False) -def list_verified_sources_command_wrapper(repo_location: str, branch: str) -> int: +def list_sources_command_wrapper(repo_location: str, branch: str) -> int: try: - list_verified_sources_command(repo_location, branch) + list_sources_command(repo_location, branch) except Exception as ex: on_exception(ex, DLT_INIT_DOCS_URL) return -1 @@ -306,11 +312,11 @@ def main() -> int: ), ) init_cmd.add_argument( - "--list-verified-sources", + "--list-sources", "-l", default=False, action="store_true", - help="List available verified sources", + help="List available sources", ) init_cmd.add_argument( "source", @@ -334,14 +340,14 @@ def main() -> int: default=None, help="Advanced. Uses specific branch of the init repository to fetch the template.", ) + init_cmd.add_argument( - "--generic", + "--omit-core-sources", default=False, action="store_true", help=( - "When present uses a generic template with all the dlt loading code present will be" - " used. Otherwise a debug template is used that can be immediately run to get familiar" - " with the dlt sources." 
+ "When present, will not create the new pipeline with a core source of the given name" + " but will take a source of this name from the default or provided location." ), ) @@ -588,15 +594,19 @@ def main() -> int: del command_kwargs["list_pipelines"] return pipeline_command_wrapper(**command_kwargs) elif args.command == "init": - if args.list_verified_sources: - return list_verified_sources_command_wrapper(args.location, args.branch) + if args.list_sources: + return list_sources_command_wrapper(args.location, args.branch) else: if not args.source or not args.destination: init_cmd.print_usage() return -1 else: return init_command_wrapper( - args.source, args.destination, args.generic, args.location, args.branch + args.source, + args.destination, + args.location, + args.branch, + args.omit_core_sources, ) elif args.command == "deploy": try: diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index a1434133f0..797917a165 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -5,6 +5,8 @@ from types import ModuleType from typing import Dict, List, Sequence, Tuple from importlib.metadata import version as pkg_version +from pathlib import Path +from importlib import import_module from dlt.common import git from dlt.common.configuration.paths import get_dlt_settings_dir, make_dlt_settings_path @@ -23,6 +25,7 @@ from dlt.common.schema.utils import is_valid_schema_name from dlt.common.schema.exceptions import InvalidSchemaName from dlt.common.storages.file_storage import FileStorage +from dlt.sources import pipeline_templates as init_module import dlt.reflection.names as n from dlt.reflection.script_inspector import inspect_pipeline_script, load_script_module @@ -31,28 +34,44 @@ from dlt.cli import utils from dlt.cli.config_toml_writer import WritableConfigValue, write_values from dlt.cli.pipeline_files import ( - VerifiedSourceFiles, + SourceConfiguration, TVerifiedSourceFileEntry, TVerifiedSourceFileIndex, ) from dlt.cli.exceptions import CliCommandException from dlt.cli.requirements import SourceRequirements + DLT_INIT_DOCS_URL = "https://dlthub.com/docs/reference/command-line-interface#dlt-init" DEFAULT_VERIFIED_SOURCES_REPO = "https://github.com/dlt-hub/verified-sources.git" -INIT_MODULE_NAME = "init" +TEMPLATES_MODULE_NAME = "pipeline_templates" SOURCES_MODULE_NAME = "sources" -def _get_template_files( - command_module: ModuleType, use_generic_template: bool -) -> Tuple[str, List[str]]: - template_files: List[str] = command_module.TEMPLATE_FILES - pipeline_script: str = command_module.PIPELINE_SCRIPT - if use_generic_template: - pipeline_script, py = os.path.splitext(pipeline_script) - pipeline_script = f"{pipeline_script}_generic{py}" - return pipeline_script, template_files +def _get_core_sources_storage() -> FileStorage: + """Get FileStorage for core sources""" + local_path = Path(os.path.dirname(os.path.realpath(__file__))).parent / SOURCES_MODULE_NAME + return FileStorage(str(local_path)) + + +def _get_templates_storage() -> FileStorage: + """Get FileStorage for single file templates""" + # look up init storage in core + init_path = ( + Path(os.path.dirname(os.path.realpath(__file__))).parent + / SOURCES_MODULE_NAME + / TEMPLATES_MODULE_NAME + ) + return FileStorage(str(init_path)) + + +def _clone_and_get_verified_sources_storage(repo_location: str, branch: str = None) -> FileStorage: + """Clone and get FileStorage for verified sources templates""" + + fmt.echo("Looking up verified sources at %s..." 
% fmt.bold(repo_location)) + clone_storage = git.get_fresh_repo_files(repo_location, get_dlt_repos_dir(), branch=branch) + # copy dlt source files from here + return FileStorage(clone_storage.make_full_path(SOURCES_MODULE_NAME)) def _select_source_files( @@ -127,16 +146,38 @@ def _get_dependency_system(dest_storage: FileStorage) -> str: return None +def _list_template_sources() -> Dict[str, SourceConfiguration]: + template_storage = _get_templates_storage() + sources: Dict[str, SourceConfiguration] = {} + for source_name in files_ops.get_sources_names(template_storage, source_type="template"): + sources[source_name] = files_ops.get_template_configuration(template_storage, source_name) + return sources + + +def _list_core_sources() -> Dict[str, SourceConfiguration]: + core_sources_storage = _get_core_sources_storage() + + sources: Dict[str, SourceConfiguration] = {} + for source_name in files_ops.get_sources_names(core_sources_storage, source_type="core"): + sources[source_name] = files_ops.get_core_source_configuration( + core_sources_storage, source_name + ) + return sources + + def _list_verified_sources( repo_location: str, branch: str = None -) -> Dict[str, VerifiedSourceFiles]: - clone_storage = git.get_fresh_repo_files(repo_location, get_dlt_repos_dir(), branch=branch) - sources_storage = FileStorage(clone_storage.make_full_path(SOURCES_MODULE_NAME)) +) -> Dict[str, SourceConfiguration]: + verified_sources_storage = _clone_and_get_verified_sources_storage(repo_location, branch) - sources: Dict[str, VerifiedSourceFiles] = {} - for source_name in files_ops.get_verified_source_names(sources_storage): + sources: Dict[str, SourceConfiguration] = {} + for source_name in files_ops.get_sources_names( + verified_sources_storage, source_type="verified" + ): try: - sources[source_name] = files_ops.get_verified_source_files(sources_storage, source_name) + sources[source_name] = files_ops.get_verified_source_configuration( + verified_sources_storage, source_name + ) except Exception as ex: fmt.warning(f"Verified source {source_name} not available: {ex}") @@ -146,23 +187,23 @@ def _list_verified_sources( def _welcome_message( source_name: str, destination_type: str, - source_files: VerifiedSourceFiles, + source_configuration: SourceConfiguration, dependency_system: str, is_new_source: bool, ) -> None: fmt.echo() - if source_files.is_template: + if source_configuration.source_type in ["template", "core"]: fmt.echo("Your new pipeline %s is ready to be customized!" % fmt.bold(source_name)) fmt.echo( "* Review and change how dlt loads your data in %s" - % fmt.bold(source_files.dest_pipeline_script) + % fmt.bold(source_configuration.dest_pipeline_script) ) else: if is_new_source: fmt.echo("Verified source %s was added to your project!" % fmt.bold(source_name)) fmt.echo( "* See the usage examples and code snippets to copy from %s" - % fmt.bold(source_files.dest_pipeline_script) + % fmt.bold(source_configuration.dest_pipeline_script) ) else: fmt.echo( @@ -175,9 +216,16 @@ def _welcome_message( % (fmt.bold(destination_type), fmt.bold(make_dlt_settings_path(SECRETS_TOML))) ) + if destination_type == "destination": + fmt.echo( + "* You have selected the custom destination as your pipelines destination. Please refer" + " to our docs at https://dlthub.com/docs/dlt-ecosystem/destinations/destination on how" + " to add a destination function that will consume your data." 
+ ) + if dependency_system: fmt.echo("* Add the required dependencies to %s:" % fmt.bold(dependency_system)) - compiled_requirements = source_files.requirements.compiled() + compiled_requirements = source_configuration.requirements.compiled() for dep in compiled_requirements: fmt.echo(" " + fmt.bold(dep)) fmt.echo( @@ -212,37 +260,69 @@ def _welcome_message( ) -def list_verified_sources_command(repo_location: str, branch: str = None) -> None: - fmt.echo("Looking up for verified sources in %s..." % fmt.bold(repo_location)) - for source_name, source_files in _list_verified_sources(repo_location, branch).items(): - reqs = source_files.requirements +def list_sources_command(repo_location: str, branch: str = None) -> None: + fmt.echo("---") + fmt.echo("Available dlt core sources:") + fmt.echo("---") + core_sources = _list_core_sources() + for source_name, source_configuration in core_sources.items(): + msg = "%s: %s" % (fmt.bold(source_name), source_configuration.doc) + fmt.echo(msg) + + fmt.echo("---") + fmt.echo("Available dlt single file templates:") + fmt.echo("---") + template_sources = _list_template_sources() + for source_name, source_configuration in template_sources.items(): + msg = "%s: %s" % (fmt.bold(source_name), source_configuration.doc) + fmt.echo(msg) + + fmt.echo("---") + fmt.echo("Available verified sources:") + fmt.echo("---") + for source_name, source_configuration in _list_verified_sources(repo_location, branch).items(): + reqs = source_configuration.requirements dlt_req_string = str(reqs.dlt_requirement_base) - msg = "%s: %s" % (fmt.bold(source_name), source_files.doc) + msg = "%s: " % (fmt.bold(source_name)) + if source_name in core_sources.keys(): + msg += "(Deprecated since dlt 1.0.0 in favor of core source of the same name) " + msg += source_configuration.doc if not reqs.is_installed_dlt_compatible(): msg += fmt.warning_style(" [needs update: %s]" % (dlt_req_string)) + fmt.echo(msg) def init_command( source_name: str, destination_type: str, - use_generic_template: bool, repo_location: str, branch: str = None, + omit_core_sources: bool = False, ) -> None: # try to import the destination and get config spec destination_reference = Destination.from_reference(destination_type) destination_spec = destination_reference.spec - fmt.echo("Looking up the init scripts in %s..." 
% fmt.bold(repo_location)) - clone_storage = git.get_fresh_repo_files(repo_location, get_dlt_repos_dir(), branch=branch) - # copy init files from here - init_storage = FileStorage(clone_storage.make_full_path(INIT_MODULE_NAME)) - # copy dlt source files from here - sources_storage = FileStorage(clone_storage.make_full_path(SOURCES_MODULE_NAME)) - # load init module and get init files and script - init_module = load_script_module(clone_storage.storage_path, INIT_MODULE_NAME) - pipeline_script, template_files = _get_template_files(init_module, use_generic_template) + # lookup core storages + core_sources_storage = _get_core_sources_storage() + templates_storage = _get_templates_storage() + + # discover type of source + source_type: files_ops.TSourceType = "template" + if ( + source_name in files_ops.get_sources_names(core_sources_storage, source_type="core") + ) and not omit_core_sources: + source_type = "core" + else: + if omit_core_sources: + fmt.echo("Omitting dlt core sources.") + verified_sources_storage = _clone_and_get_verified_sources_storage(repo_location, branch) + if source_name in files_ops.get_sources_names( + verified_sources_storage, source_type="verified" + ): + source_type = "verified" + # prepare destination storage dest_storage = FileStorage(os.path.abspath(".")) if not dest_storage.has_folder(get_dlt_settings_dir()): @@ -256,16 +336,21 @@ def init_command( is_new_source = len(local_index["files"]) == 0 # look for existing source - source_files: VerifiedSourceFiles = None + source_configuration: SourceConfiguration = None remote_index: TVerifiedSourceFileIndex = None - if sources_storage.has_folder(source_name): + remote_modified: Dict[str, TVerifiedSourceFileEntry] = {} + remote_deleted: Dict[str, TVerifiedSourceFileEntry] = {} + + if source_type == "verified": # get pipeline files - source_files = files_ops.get_verified_source_files(sources_storage, source_name) + source_configuration = files_ops.get_verified_source_configuration( + verified_sources_storage, source_name + ) # get file index from remote verified source files being copied remote_index = files_ops.get_remote_source_index( - source_files.storage.storage_path, - source_files.files, - source_files.requirements.dlt_version_constraint(), + source_configuration.storage.storage_path, + source_configuration.files, + source_configuration.requirements.dlt_version_constraint(), ) # diff local and remote index to get modified and deleted files remote_new, remote_modified, remote_deleted = files_ops.gen_index_diff( @@ -292,39 +377,41 @@ def init_command( " update correctly in the future." 
) # add template files - source_files.files.extend(template_files) + source_configuration.files.extend(files_ops.TEMPLATE_FILES) else: - if not is_valid_schema_name(source_name): - raise InvalidSchemaName(source_name) - dest_pipeline_script = source_name + ".py" - source_files = VerifiedSourceFiles( - True, - init_storage, - pipeline_script, - dest_pipeline_script, - template_files, - SourceRequirements([]), - "", - ) - if dest_storage.has_file(dest_pipeline_script): - fmt.warning("Pipeline script %s already exist, exiting" % dest_pipeline_script) + if source_type == "core": + source_configuration = files_ops.get_core_source_configuration( + core_sources_storage, source_name + ) + else: + if not is_valid_schema_name(source_name): + raise InvalidSchemaName(source_name) + source_configuration = files_ops.get_template_configuration( + templates_storage, source_name + ) + + if dest_storage.has_file(source_configuration.dest_pipeline_script): + fmt.warning( + "Pipeline script %s already exists, exiting" + % source_configuration.dest_pipeline_script + ) return # add .dlt/*.toml files to be copied - source_files.files.extend( + source_configuration.files.extend( [make_dlt_settings_path(CONFIG_TOML), make_dlt_settings_path(SECRETS_TOML)] ) # add dlt extras line to requirements - source_files.requirements.update_dlt_extras(destination_type) + source_configuration.requirements.update_dlt_extras(destination_type) # Check compatibility with installed dlt - if not source_files.requirements.is_installed_dlt_compatible(): + if not source_configuration.requirements.is_installed_dlt_compatible(): msg = ( "This pipeline requires a newer version of dlt than your installed version" - f" ({source_files.requirements.current_dlt_version()}). Pipeline requires" - f" '{source_files.requirements.dlt_requirement_base}'" + f" ({source_configuration.requirements.current_dlt_version()}). Pipeline requires" + f" '{source_configuration.requirements.dlt_requirement_base}'" ) fmt.warning(msg) if not fmt.confirm( @@ -332,28 +419,29 @@ def init_command( ): fmt.echo( "You can update dlt with: pip3 install -U" - f' "{source_files.requirements.dlt_requirement_base}"' + f' "{source_configuration.requirements.dlt_requirement_base}"' ) return # read module source and parse it visitor = utils.parse_init_script( "init", - source_files.storage.load(source_files.pipeline_script), - source_files.pipeline_script, + source_configuration.storage.load(source_configuration.src_pipeline_script), + source_configuration.src_pipeline_script, ) if visitor.is_destination_imported: raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} import a destination from" - " dlt.destinations. You should specify destinations by name when calling dlt.pipeline" - " or dlt.run in init scripts.", + f"The pipeline script {source_configuration.src_pipeline_script} imports a destination" + " from dlt.destinations. You should specify destinations by name when calling" + " dlt.pipeline or dlt.run in init scripts.", ) if n.PIPELINE not in visitor.known_calls: raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} does not seem to initialize" - " pipeline with dlt.pipeline. Please initialize pipeline explicitly in init scripts.", + f"The pipeline script {source_configuration.src_pipeline_script} does not seem to" + " initialize a pipeline with dlt.pipeline. 
Please initialize pipeline explicitly in" + " your init scripts.", ) # find all arguments in all calls to replace @@ -364,18 +452,18 @@ def init_command( ("pipeline_name", source_name), ("dataset_name", source_name + "_data"), ], - source_files.pipeline_script, + source_configuration.src_pipeline_script, ) # inspect the script inspect_pipeline_script( - source_files.storage.storage_path, - source_files.storage.to_relative_path(source_files.pipeline_script), + source_configuration.storage.storage_path, + source_configuration.storage.to_relative_path(source_configuration.src_pipeline_script), ignore_missing_imports=True, ) # detect all the required secrets and configs that should go into tomls files - if source_files.is_template: + if source_configuration.source_type == "template": # replace destination, pipeline_name and dataset_name in templates transformed_nodes = source_detection.find_call_arguments_to_replace( visitor, @@ -384,21 +472,22 @@ def init_command( ("pipeline_name", source_name), ("dataset_name", source_name + "_data"), ], - source_files.pipeline_script, + source_configuration.src_pipeline_script, ) # template sources are always in module starting with "pipeline" # for templates, place config and secrets into top level section required_secrets, required_config, checked_sources = source_detection.detect_source_configs( - _SOURCES, "pipeline", () + _SOURCES, source_configuration.source_module_prefix, () ) # template has a strict rules where sources are placed for source_q_name, source_config in checked_sources.items(): if source_q_name not in visitor.known_sources_resources: raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} imports a source/resource" - f" {source_config.f.__name__} from module {source_config.module.__name__}. In" - " init scripts you must declare all sources and resources in single file.", + f"The pipeline script {source_configuration.src_pipeline_script} imports a" + f" source/resource {source_config.f.__name__} from module" + f" {source_config.module.__name__}. In init scripts you must declare all" + " sources and resources in single file.", ) # rename sources and resources transformed_nodes.extend( @@ -407,19 +496,22 @@ def init_command( else: # replace only destination for existing pipelines transformed_nodes = source_detection.find_call_arguments_to_replace( - visitor, [("destination", destination_type)], source_files.pipeline_script + visitor, [("destination", destination_type)], source_configuration.src_pipeline_script ) # pipeline sources are in module with name starting from {pipeline_name} # for verified pipelines place in the specific source section required_secrets, required_config, checked_sources = source_detection.detect_source_configs( - _SOURCES, source_name, (known_sections.SOURCES, source_name) + _SOURCES, + source_configuration.source_module_prefix, + (known_sections.SOURCES, source_name), ) - if len(checked_sources) == 0: + # the intro template does not use sources, for now allow it to pass here + if len(checked_sources) == 0 and source_name != "intro": raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} is not creating or importing any" - " sources or resources", + f"The pipeline script {source_configuration.src_pipeline_script} is not creating or" + " importing any sources or resources. 
Exiting...", ) # add destination spec to required secrets @@ -439,37 +531,57 @@ def init_command( # ask for confirmation if is_new_source: - if source_files.is_template: + if source_configuration.source_type == "core": + fmt.echo( + "Creating a new pipeline with the dlt core source %s (%s)" + % (fmt.bold(source_name), source_configuration.doc) + ) fmt.echo( - "A verified source %s was not found. Using a template to create a new source and" - " pipeline with name %s." % (fmt.bold(source_name), fmt.bold(source_name)) + "NOTE: Beginning with dlt 1.0.0, the source %s will no longer be copied from the" + " verified sources repo but imported from dlt.sources. You can provide the" + " --omit-core-sources flag to revert to the old behavior." % (fmt.bold(source_name)) + ) + elif source_configuration.source_type == "verified": + fmt.echo( + "Creating and configuring a new pipeline with the verified source %s (%s)" + % (fmt.bold(source_name), source_configuration.doc) ) else: + if source_configuration.is_default_template: + fmt.echo( + "NOTE: Could not find a dlt source or template wih the name %s. Selecting the" + " default template." % (fmt.bold(source_name)) + ) + fmt.echo( + "NOTE: In case you did not want to use the default template, run 'dlt init -l'" + " to see all available sources and templates." + ) fmt.echo( - "Cloning and configuring a verified source %s (%s)" - % (fmt.bold(source_name), source_files.doc) + "Creating and configuring a new pipeline with the dlt core template %s (%s)" + % (fmt.bold(source_configuration.src_pipeline_script), source_configuration.doc) ) - if use_generic_template: - fmt.warning("--generic parameter is meaningless if verified source is found") + if not fmt.confirm("Do you want to proceed?", default=True): raise CliCommandException("init", "Aborted") dependency_system = _get_dependency_system(dest_storage) - _welcome_message(source_name, destination_type, source_files, dependency_system, is_new_source) + _welcome_message( + source_name, destination_type, source_configuration, dependency_system, is_new_source + ) # copy files at the very end - for file_name in source_files.files: + for file_name in source_configuration.files: dest_path = dest_storage.make_full_path(file_name) # get files from init section first - if init_storage.has_file(file_name): + if templates_storage.has_file(file_name): if dest_storage.has_file(dest_path): # do not overwrite any init files continue - src_path = init_storage.make_full_path(file_name) + src_path = templates_storage.make_full_path(file_name) else: # only those that were modified should be copied from verified sources if file_name in remote_modified: - src_path = source_files.storage.make_full_path(file_name) + src_path = source_configuration.storage.make_full_path(file_name) else: continue os.makedirs(os.path.dirname(dest_path), exist_ok=True) @@ -484,8 +596,8 @@ def init_command( source_name, remote_index, remote_modified, remote_deleted ) # create script - if not dest_storage.has_file(source_files.dest_pipeline_script): - dest_storage.save(source_files.dest_pipeline_script, dest_script_source) + if not dest_storage.has_file(source_configuration.dest_pipeline_script): + dest_storage.save(source_configuration.dest_pipeline_script, dest_script_source) # generate tomls with comments secrets_prov = SecretsTomlProvider() @@ -504,5 +616,5 @@ def init_command( # if there's no dependency system write the requirements file if dependency_system is None: - requirements_txt = "\n".join(source_files.requirements.compiled()) + 
requirements_txt = "\n".join(source_configuration.requirements.compiled()) dest_storage.save(utils.REQUIREMENTS_TXT, requirements_txt) diff --git a/dlt/cli/pipeline_files.py b/dlt/cli/pipeline_files.py index 49c0f71b21..6ca39e0195 100644 --- a/dlt/cli/pipeline_files.py +++ b/dlt/cli/pipeline_files.py @@ -4,7 +4,7 @@ import yaml import posixpath from pathlib import Path -from typing import Dict, NamedTuple, Sequence, Tuple, TypedDict, List +from typing import Dict, NamedTuple, Sequence, Tuple, TypedDict, List, Literal from dlt.cli.exceptions import VerifiedSourceRepoError from dlt.common import git @@ -16,21 +16,35 @@ from dlt.cli import utils from dlt.cli.requirements import SourceRequirements +TSourceType = Literal["core", "verified", "template"] SOURCES_INIT_INFO_ENGINE_VERSION = 1 SOURCES_INIT_INFO_FILE = ".sources" IGNORE_FILES = ["*.py[cod]", "*$py.class", "__pycache__", "py.typed", "requirements.txt"] -IGNORE_SOURCES = [".*", "_*"] - - -class VerifiedSourceFiles(NamedTuple): - is_template: bool +IGNORE_VERIFIED_SOURCES = [".*", "_*"] +IGNORE_CORE_SOURCES = [ + ".*", + "_*", + "helpers", + "pipeline_templates", +] +PIPELINE_FILE_SUFFIX = "_pipeline.py" + +# hardcode default template files here +TEMPLATE_FILES = [".gitignore", ".dlt/config.toml", ".dlt/secrets.toml"] +DEFAULT_PIPELINE_TEMPLATE = "default_pipeline.py" + + +class SourceConfiguration(NamedTuple): + source_type: TSourceType + source_module_prefix: str storage: FileStorage - pipeline_script: str + src_pipeline_script: str dest_pipeline_script: str files: List[str] requirements: SourceRequirements doc: str + is_default_template: bool class TVerifiedSourceFileEntry(TypedDict): @@ -147,22 +161,88 @@ def get_remote_source_index( } -def get_verified_source_names(sources_storage: FileStorage) -> List[str]: +def get_sources_names(sources_storage: FileStorage, source_type: TSourceType) -> List[str]: candidates: List[str] = [] - for name in [ - n - for n in sources_storage.list_folder_dirs(".", to_root=False) - if not any(fnmatch.fnmatch(n, ignore) for ignore in IGNORE_SOURCES) - ]: - # must contain at least one valid python script - if any(f.endswith(".py") for f in sources_storage.list_folder_files(name, to_root=False)): - candidates.append(name) + + # for the templates we just find all the filenames + if source_type == "template": + for name in sources_storage.list_folder_files(".", to_root=False): + if name.endswith(PIPELINE_FILE_SUFFIX): + candidates.append(name.replace(PIPELINE_FILE_SUFFIX, "")) + else: + ignore_cases = IGNORE_VERIFIED_SOURCES if source_type == "verified" else IGNORE_CORE_SOURCES + for name in [ + n + for n in sources_storage.list_folder_dirs(".", to_root=False) + if not any(fnmatch.fnmatch(n, ignore) for ignore in ignore_cases) + ]: + # must contain at least one valid python script + if any( + f.endswith(".py") for f in sources_storage.list_folder_files(name, to_root=False) + ): + candidates.append(name) + + candidates.sort() return candidates -def get_verified_source_files( +def _get_docstring_for_module(sources_storage: FileStorage, source_name: str) -> str: + # read the docs + init_py = os.path.join(source_name, utils.MODULE_INIT) + docstring: str = "" + if sources_storage.has_file(init_py): + docstring = get_module_docstring(sources_storage.load(init_py)) + if docstring: + docstring = docstring.splitlines()[0] + return docstring + + +def get_template_configuration( sources_storage: FileStorage, source_name: str -) -> VerifiedSourceFiles: +) -> SourceConfiguration: + destination_pipeline_file_name = 
source_name + PIPELINE_FILE_SUFFIX + source_pipeline_file_name = destination_pipeline_file_name + + if not sources_storage.has_file(source_pipeline_file_name): + source_pipeline_file_name = DEFAULT_PIPELINE_TEMPLATE + + docstring = get_module_docstring(sources_storage.load(source_pipeline_file_name)) + if docstring: + docstring = docstring.splitlines()[0] + return SourceConfiguration( + "template", + source_pipeline_file_name.replace("pipeline.py", ""), + sources_storage, + source_pipeline_file_name, + destination_pipeline_file_name, + TEMPLATE_FILES, + SourceRequirements([]), + docstring, + source_pipeline_file_name == DEFAULT_PIPELINE_TEMPLATE, + ) + + +def get_core_source_configuration( + sources_storage: FileStorage, source_name: str +) -> SourceConfiguration: + pipeline_file = source_name + "_pipeline.py" + + return SourceConfiguration( + "core", + "dlt.sources." + source_name, + sources_storage, + pipeline_file, + pipeline_file, + [".gitignore"], + SourceRequirements([]), + _get_docstring_for_module(sources_storage, source_name), + False, + ) + + +def get_verified_source_configuration( + sources_storage: FileStorage, source_name: str +) -> SourceConfiguration: if not sources_storage.has_folder(source_name): raise VerifiedSourceRepoError( f"Verified source {source_name} could not be found in the repository", source_name @@ -189,13 +269,6 @@ def get_verified_source_files( if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES) ] ) - # read the docs - init_py = os.path.join(source_name, utils.MODULE_INIT) - docstring: str = "" - if sources_storage.has_file(init_py): - docstring = get_module_docstring(sources_storage.load(init_py)) - if docstring: - docstring = docstring.splitlines()[0] # read requirements requirements_path = os.path.join(source_name, utils.REQUIREMENTS_TXT) if sources_storage.has_file(requirements_path): @@ -203,8 +276,16 @@ def get_verified_source_files( else: requirements = SourceRequirements([]) # find requirements - return VerifiedSourceFiles( - False, sources_storage, example_script, example_script, files, requirements, docstring + return SourceConfiguration( + "verified", + source_name, + sources_storage, + example_script, + example_script, + files, + requirements, + _get_docstring_for_module(sources_storage, source_name), + False, ) diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index 5b9a4587c7..5d3ec689c4 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -1,7 +1,7 @@ import dataclasses from typing import Any, ClassVar, Dict, List, Optional, Union -from dlt.common.libs.sql_alchemy import URL, make_url +from dlt.common.libs.sql_alchemy_shims import URL, make_url from dlt.common.configuration.specs.exceptions import InvalidConnectionString from dlt.common.typing import TSecretValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec diff --git a/dlt/common/libs/sql_alchemy.py b/dlt/common/libs/sql_alchemy.py index 2f3b51ec0d..19ebbbc78a 100644 --- a/dlt/common/libs/sql_alchemy.py +++ b/dlt/common/libs/sql_alchemy.py @@ -1,446 +1,20 @@ -""" -Ports fragments of URL class from Sql Alchemy to use them when dependency is not available. 
-""" - -from typing import cast - +from dlt.common.exceptions import MissingDependencyException +from dlt import version try: - import sqlalchemy -except ImportError: - # port basic functionality without the whole Sql Alchemy - - import re - from typing import ( - Any, - Dict, - Iterable, - List, - Mapping, - NamedTuple, - Optional, - Sequence, - Tuple, - TypeVar, - Union, - overload, + from sqlalchemy import MetaData, Table, Column, create_engine + from sqlalchemy.engine import Engine, URL, make_url, Row + from sqlalchemy.sql import sqltypes, Select + from sqlalchemy.sql.sqltypes import TypeEngine + from sqlalchemy.exc import CompileError + import sqlalchemy as sa +except ModuleNotFoundError: + raise MissingDependencyException( + "dlt sql_database helpers ", + [f"{version.DLT_PKG_NAME}[sql_database]"], + "Install the sql_database helpers for loading from sql_database sources. Note that you may" + " need to install additional SQLAlchemy dialects for your source database.", ) - import collections.abc as collections_abc - from urllib.parse import ( - quote_plus, - parse_qsl, - quote, - unquote, - ) - - _KT = TypeVar("_KT", bound=Any) - _VT = TypeVar("_VT", bound=Any) - - class ImmutableDict(Dict[_KT, _VT]): - """Not a real immutable dict""" - - def __setitem__(self, __key: _KT, __value: _VT) -> None: - raise NotImplementedError("Cannot modify immutable dict") - - def __delitem__(self, _KT: Any) -> None: - raise NotImplementedError("Cannot modify immutable dict") - - def update(self, *arg: Any, **kw: Any) -> None: - raise NotImplementedError("Cannot modify immutable dict") - - EMPTY_DICT: ImmutableDict[Any, Any] = ImmutableDict() - - def to_list(value: Any, default: Optional[List[Any]] = None) -> List[Any]: - if value is None: - return default - if not isinstance(value, collections_abc.Iterable) or isinstance(value, str): - return [value] - elif isinstance(value, list): - return value - else: - return list(value) - - class URL(NamedTuple): - """ - Represent the components of a URL used to connect to a database. - - Based on SqlAlchemy URL class with copyright as below: - - # engine/url.py - # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors - # - # This module is part of SQLAlchemy and is released under - # the MIT License: https://www.opensource.org/licenses/mit-license.php - """ - - drivername: str - """database backend and driver name, such as `postgresql+psycopg2`""" - username: Optional[str] - "username string" - password: Optional[str] - """password, which is normally a string but may also be any object that has a `__str__()` method.""" - host: Optional[str] - """hostname or IP number. May also be a data source name for some drivers.""" - port: Optional[int] - """integer port number""" - database: Optional[str] - """database name""" - query: ImmutableDict[str, Union[Tuple[str, ...], str]] - """an immutable mapping representing the query string. 
contains strings - for keys and either strings or tuples of strings for values""" - - @classmethod - def create( - cls, - drivername: str, - username: Optional[str] = None, - password: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - database: Optional[str] = None, - query: Mapping[str, Union[Sequence[str], str]] = None, - ) -> "URL": - """Create a new `URL` object.""" - return cls( - cls._assert_str(drivername, "drivername"), - cls._assert_none_str(username, "username"), - password, - cls._assert_none_str(host, "host"), - cls._assert_port(port), - cls._assert_none_str(database, "database"), - cls._str_dict(query or EMPTY_DICT), - ) - - @classmethod - def _assert_port(cls, port: Optional[int]) -> Optional[int]: - if port is None: - return None - try: - return int(port) - except TypeError: - raise TypeError("Port argument must be an integer or None") - - @classmethod - def _assert_str(cls, v: str, paramname: str) -> str: - if not isinstance(v, str): - raise TypeError("%s must be a string" % paramname) - return v - - @classmethod - def _assert_none_str(cls, v: Optional[str], paramname: str) -> Optional[str]: - if v is None: - return v - - return cls._assert_str(v, paramname) - - @classmethod - def _str_dict( - cls, - dict_: Optional[ - Union[ - Sequence[Tuple[str, Union[Sequence[str], str]]], - Mapping[str, Union[Sequence[str], str]], - ] - ], - ) -> ImmutableDict[str, Union[Tuple[str, ...], str]]: - if dict_ is None: - return EMPTY_DICT - - @overload - def _assert_value( - val: str, - ) -> str: ... - - @overload - def _assert_value( - val: Sequence[str], - ) -> Union[str, Tuple[str, ...]]: ... - - def _assert_value( - val: Union[str, Sequence[str]], - ) -> Union[str, Tuple[str, ...]]: - if isinstance(val, str): - return val - elif isinstance(val, collections_abc.Sequence): - return tuple(_assert_value(elem) for elem in val) - else: - raise TypeError( - "Query dictionary values must be strings or sequences of strings" - ) - - def _assert_str(v: str) -> str: - if not isinstance(v, str): - raise TypeError("Query dictionary keys must be strings") - return v - - dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] - if isinstance(dict_, collections_abc.Sequence): - dict_items = dict_ - else: - dict_items = dict_.items() - - return ImmutableDict( - { - _assert_str(key): _assert_value( - value, - ) - for key, value in dict_items - } - ) - - def set( # noqa - self, - drivername: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - database: Optional[str] = None, - query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, - ) -> "URL": - """return a new `URL` object with modifications.""" - - kw: Dict[str, Any] = {} - if drivername is not None: - kw["drivername"] = drivername - if username is not None: - kw["username"] = username - if password is not None: - kw["password"] = password - if host is not None: - kw["host"] = host - if port is not None: - kw["port"] = port - if database is not None: - kw["database"] = database - if query is not None: - kw["query"] = query - - return self._assert_replace(**kw) - - def _assert_replace(self, **kw: Any) -> "URL": - """argument checks before calling _replace()""" - - if "drivername" in kw: - self._assert_str(kw["drivername"], "drivername") - for name in "username", "host", "database": - if name in kw: - self._assert_none_str(kw[name], name) - if "port" in kw: - self._assert_port(kw["port"]) - if "query" in kw: - 
kw["query"] = self._str_dict(kw["query"]) - - return self._replace(**kw) - - def update_query_string(self, query_string: str, append: bool = False) -> "URL": - return self.update_query_pairs(parse_qsl(query_string), append=append) - - def update_query_pairs( - self, - key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], - append: bool = False, - ) -> "URL": - """Return a new `URL` object with the `query` parameter dictionary updated by the given sequence of key/value pairs""" - existing_query = self.query - new_keys: Dict[str, Union[str, List[str]]] = {} - - for key, value in key_value_pairs: - if key in new_keys: - new_keys[key] = to_list(new_keys[key]) - cast("List[str]", new_keys[key]).append(cast(str, value)) - else: - new_keys[key] = to_list(value) if isinstance(value, (list, tuple)) else value - - new_query: Mapping[str, Union[str, Sequence[str]]] - if append: - new_query = {} - - for k in new_keys: - if k in existing_query: - new_query[k] = tuple(to_list(existing_query[k]) + to_list(new_keys[k])) - else: - new_query[k] = new_keys[k] - - new_query.update( - {k: existing_query[k] for k in set(existing_query).difference(new_keys)} - ) - else: - new_query = ImmutableDict( - { - **self.query, - **{k: tuple(v) if isinstance(v, list) else v for k, v in new_keys.items()}, - } - ) - return self.set(query=new_query) - - def update_query_dict( - self, - query_parameters: Mapping[str, Union[str, List[str]]], - append: bool = False, - ) -> "URL": - return self.update_query_pairs(query_parameters.items(), append=append) - - def render_as_string(self, hide_password: bool = True) -> str: - """Render this `URL` object as a string.""" - s = self.drivername + "://" - if self.username is not None: - s += quote(self.username, safe=" +") - if self.password is not None: - s += ":" + ("***" if hide_password else quote(str(self.password), safe=" +")) - s += "@" - if self.host is not None: - if ":" in self.host: - s += f"[{self.host}]" - else: - s += self.host - if self.port is not None: - s += ":" + str(self.port) - if self.database is not None: - s += "/" + self.database - if self.query: - keys = to_list(self.query) - keys.sort() - s += "?" + "&".join( - f"{quote_plus(k)}={quote_plus(element)}" - for k in keys - for element in to_list(self.query[k]) - ) - return s - - def __repr__(self) -> str: - return self.render_as_string() - - def __copy__(self) -> "URL": - return self.__class__.create( - self.drivername, - self.username, - self.password, - self.host, - self.port, - self.database, - self.query.copy(), - ) - - def __deepcopy__(self, memo: Any) -> "URL": - return self.__copy__() - - def __hash__(self) -> int: - return hash(str(self)) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, URL) - and self.drivername == other.drivername - and self.username == other.username - and self.password == other.password - and self.host == other.host - and self.database == other.database - and self.query == other.query - and self.port == other.port - ) - - def __ne__(self, other: Any) -> bool: - return not self == other - - def get_backend_name(self) -> str: - """Return the backend name. - - This is the name that corresponds to the database backend in - use, and is the portion of the `drivername` - that is to the left of the plus sign. - - """ - if "+" not in self.drivername: - return self.drivername - else: - return self.drivername.split("+")[0] - - def get_driver_name(self) -> str: - """Return the backend name. 
- - This is the name that corresponds to the DBAPI driver in - use, and is the portion of the `drivername` - that is to the right of the plus sign. - """ - - if "+" not in self.drivername: - return self.drivername - else: - return self.drivername.split("+")[1] - - def make_url(name_or_url: Union[str, URL]) -> URL: - """Given a string, produce a new URL instance. - - The format of the URL generally follows `RFC-1738`, with some exceptions, including - that underscores, and not dashes or periods, are accepted within the - "scheme" portion. - - If a `URL` object is passed, it is returned as is.""" - - if isinstance(name_or_url, str): - return _parse_url(name_or_url) - elif not isinstance(name_or_url, URL): - raise ValueError(f"Expected string or URL object, got {name_or_url!r}") - else: - return name_or_url - - def _parse_url(name: str) -> URL: - pattern = re.compile( - r""" - (?P[\w\+]+):// - (?: - (?P[^:/]*) - (?::(?P[^@]*))? - @)? - (?: - (?: - \[(?P[^/\?]+)\] | - (?P[^/:\?]+) - )? - (?::(?P[^/\?]*))? - )? - (?:/(?P[^\?]*))? - (?:\?(?P.*))? - """, - re.X, - ) - - m = pattern.match(name) - if m is not None: - components = m.groupdict() - query: Optional[Dict[str, Union[str, List[str]]]] - if components["query"] is not None: - query = {} - - for key, value in parse_qsl(components["query"]): - if key in query: - query[key] = to_list(query[key]) - cast("List[str]", query[key]).append(value) - else: - query[key] = value - else: - query = None - - components["query"] = query - if components["username"] is not None: - components["username"] = unquote(components["username"]) - - if components["password"] is not None: - components["password"] = unquote(components["password"]) - - ipv4host = components.pop("ipv4host") - ipv6host = components.pop("ipv6host") - components["host"] = ipv4host or ipv6host - name = components.pop("name") - - if components["port"]: - components["port"] = int(components["port"]) - - return URL.create(name, **components) # type: ignore - - else: - raise ValueError("Could not parse SQLAlchemy URL from string '%s'" % name) -else: - from sqlalchemy.engine import URL, make_url # type: ignore[assignment] +# TODO: maybe use sa.__version__? +IS_SQL_ALCHEMY_20 = hasattr(sa, "Double") diff --git a/dlt/common/libs/sql_alchemy_shims.py b/dlt/common/libs/sql_alchemy_shims.py new file mode 100644 index 0000000000..2f3b51ec0d --- /dev/null +++ b/dlt/common/libs/sql_alchemy_shims.py @@ -0,0 +1,446 @@ +""" +Ports fragments of URL class from Sql Alchemy to use them when dependency is not available. 
+""" + +from typing import cast + + +try: + import sqlalchemy +except ImportError: + # port basic functionality without the whole Sql Alchemy + + import re + from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + overload, + ) + import collections.abc as collections_abc + from urllib.parse import ( + quote_plus, + parse_qsl, + quote, + unquote, + ) + + _KT = TypeVar("_KT", bound=Any) + _VT = TypeVar("_VT", bound=Any) + + class ImmutableDict(Dict[_KT, _VT]): + """Not a real immutable dict""" + + def __setitem__(self, __key: _KT, __value: _VT) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def __delitem__(self, _KT: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def update(self, *arg: Any, **kw: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + EMPTY_DICT: ImmutableDict[Any, Any] = ImmutableDict() + + def to_list(value: Any, default: Optional[List[Any]] = None) -> List[Any]: + if value is None: + return default + if not isinstance(value, collections_abc.Iterable) or isinstance(value, str): + return [value] + elif isinstance(value, list): + return value + else: + return list(value) + + class URL(NamedTuple): + """ + Represent the components of a URL used to connect to a database. + + Based on SqlAlchemy URL class with copyright as below: + + # engine/url.py + # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors + # + # This module is part of SQLAlchemy and is released under + # the MIT License: https://www.opensource.org/licenses/mit-license.php + """ + + drivername: str + """database backend and driver name, such as `postgresql+psycopg2`""" + username: Optional[str] + "username string" + password: Optional[str] + """password, which is normally a string but may also be any object that has a `__str__()` method.""" + host: Optional[str] + """hostname or IP number. May also be a data source name for some drivers.""" + port: Optional[int] + """integer port number""" + database: Optional[str] + """database name""" + query: ImmutableDict[str, Union[Tuple[str, ...], str]] + """an immutable mapping representing the query string. 
contains strings + for keys and either strings or tuples of strings for values""" + + @classmethod + def create( + cls, + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = None, + ) -> "URL": + """Create a new `URL` object.""" + return cls( + cls._assert_str(drivername, "drivername"), + cls._assert_none_str(username, "username"), + password, + cls._assert_none_str(host, "host"), + cls._assert_port(port), + cls._assert_none_str(database, "database"), + cls._str_dict(query or EMPTY_DICT), + ) + + @classmethod + def _assert_port(cls, port: Optional[int]) -> Optional[int]: + if port is None: + return None + try: + return int(port) + except TypeError: + raise TypeError("Port argument must be an integer or None") + + @classmethod + def _assert_str(cls, v: str, paramname: str) -> str: + if not isinstance(v, str): + raise TypeError("%s must be a string" % paramname) + return v + + @classmethod + def _assert_none_str(cls, v: Optional[str], paramname: str) -> Optional[str]: + if v is None: + return v + + return cls._assert_str(v, paramname) + + @classmethod + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> ImmutableDict[str, Union[Tuple[str, ...], str]]: + if dict_ is None: + return EMPTY_DICT + + @overload + def _assert_value( + val: str, + ) -> str: ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: ... + + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: + if isinstance(val, str): + return val + elif isinstance(val, collections_abc.Sequence): + return tuple(_assert_value(elem) for elem in val) + else: + raise TypeError( + "Query dictionary values must be strings or sequences of strings" + ) + + def _assert_str(v: str) -> str: + if not isinstance(v, str): + raise TypeError("Query dictionary keys must be strings") + return v + + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] + if isinstance(dict_, collections_abc.Sequence): + dict_items = dict_ + else: + dict_items = dict_.items() + + return ImmutableDict( + { + _assert_str(key): _assert_value( + value, + ) + for key, value in dict_items + } + ) + + def set( # noqa + self, + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> "URL": + """return a new `URL` object with modifications.""" + + kw: Dict[str, Any] = {} + if drivername is not None: + kw["drivername"] = drivername + if username is not None: + kw["username"] = username + if password is not None: + kw["password"] = password + if host is not None: + kw["host"] = host + if port is not None: + kw["port"] = port + if database is not None: + kw["database"] = database + if query is not None: + kw["query"] = query + + return self._assert_replace(**kw) + + def _assert_replace(self, **kw: Any) -> "URL": + """argument checks before calling _replace()""" + + if "drivername" in kw: + self._assert_str(kw["drivername"], "drivername") + for name in "username", "host", "database": + if name in kw: + self._assert_none_str(kw[name], name) + if "port" in kw: + self._assert_port(kw["port"]) + if "query" in kw: + 
kw["query"] = self._str_dict(kw["query"]) + + return self._replace(**kw) + + def update_query_string(self, query_string: str, append: bool = False) -> "URL": + return self.update_query_pairs(parse_qsl(query_string), append=append) + + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> "URL": + """Return a new `URL` object with the `query` parameter dictionary updated by the given sequence of key/value pairs""" + existing_query = self.query + new_keys: Dict[str, Union[str, List[str]]] = {} + + for key, value in key_value_pairs: + if key in new_keys: + new_keys[key] = to_list(new_keys[key]) + cast("List[str]", new_keys[key]).append(cast(str, value)) + else: + new_keys[key] = to_list(value) if isinstance(value, (list, tuple)) else value + + new_query: Mapping[str, Union[str, Sequence[str]]] + if append: + new_query = {} + + for k in new_keys: + if k in existing_query: + new_query[k] = tuple(to_list(existing_query[k]) + to_list(new_keys[k])) + else: + new_query[k] = new_keys[k] + + new_query.update( + {k: existing_query[k] for k in set(existing_query).difference(new_keys)} + ) + else: + new_query = ImmutableDict( + { + **self.query, + **{k: tuple(v) if isinstance(v, list) else v for k, v in new_keys.items()}, + } + ) + return self.set(query=new_query) + + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> "URL": + return self.update_query_pairs(query_parameters.items(), append=append) + + def render_as_string(self, hide_password: bool = True) -> str: + """Render this `URL` object as a string.""" + s = self.drivername + "://" + if self.username is not None: + s += quote(self.username, safe=" +") + if self.password is not None: + s += ":" + ("***" if hide_password else quote(str(self.password), safe=" +")) + s += "@" + if self.host is not None: + if ":" in self.host: + s += f"[{self.host}]" + else: + s += self.host + if self.port is not None: + s += ":" + str(self.port) + if self.database is not None: + s += "/" + self.database + if self.query: + keys = to_list(self.query) + keys.sort() + s += "?" + "&".join( + f"{quote_plus(k)}={quote_plus(element)}" + for k in keys + for element in to_list(self.query[k]) + ) + return s + + def __repr__(self) -> str: + return self.render_as_string() + + def __copy__(self) -> "URL": + return self.__class__.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + self.query.copy(), + ) + + def __deepcopy__(self, memo: Any) -> "URL": + return self.__copy__() + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + and self.port == other.port + ) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def get_backend_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the database backend in + use, and is the portion of the `drivername` + that is to the left of the plus sign. + + """ + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[0] + + def get_driver_name(self) -> str: + """Return the backend name. 
+ + This is the name that corresponds to the DBAPI driver in + use, and is the portion of the `drivername` + that is to the right of the plus sign. + """ + + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[1] + + def make_url(name_or_url: Union[str, URL]) -> URL: + """Given a string, produce a new URL instance. + + The format of the URL generally follows `RFC-1738`, with some exceptions, including + that underscores, and not dashes or periods, are accepted within the + "scheme" portion. + + If a `URL` object is passed, it is returned as is.""" + + if isinstance(name_or_url, str): + return _parse_url(name_or_url) + elif not isinstance(name_or_url, URL): + raise ValueError(f"Expected string or URL object, got {name_or_url!r}") + else: + return name_or_url + + def _parse_url(name: str) -> URL: + pattern = re.compile( + r""" + (?P[\w\+]+):// + (?: + (?P[^:/]*) + (?::(?P[^@]*))? + @)? + (?: + (?: + \[(?P[^/\?]+)\] | + (?P[^/:\?]+) + )? + (?::(?P[^/\?]*))? + )? + (?:/(?P[^\?]*))? + (?:\?(?P.*))? + """, + re.X, + ) + + m = pattern.match(name) + if m is not None: + components = m.groupdict() + query: Optional[Dict[str, Union[str, List[str]]]] + if components["query"] is not None: + query = {} + + for key, value in parse_qsl(components["query"]): + if key in query: + query[key] = to_list(query[key]) + cast("List[str]", query[key]).append(value) + else: + query[key] = value + else: + query = None + + components["query"] = query + if components["username"] is not None: + components["username"] = unquote(components["username"]) + + if components["password"] is not None: + components["password"] = unquote(components["password"]) + + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") + + if components["port"]: + components["port"] = int(components["port"]) + + return URL.create(name, **components) # type: ignore + + else: + raise ValueError("Could not parse SQLAlchemy URL from string '%s'" % name) + +else: + from sqlalchemy.engine import URL, make_url # type: ignore[assignment] diff --git a/dlt/common/typing.py b/dlt/common/typing.py index ee11a77965..8d18d84400 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -42,6 +42,8 @@ get_original_bases, ) +from typing_extensions import is_typeddict as _is_typeddict + try: from types import UnionType # type: ignore[attr-defined] except ImportError: @@ -293,7 +295,7 @@ def is_newtype_type(t: Type[Any]) -> bool: def is_typeddict(t: Type[Any]) -> bool: - if isinstance(t, _TypedDict): + if _is_typeddict(t): return True if inner_t := extract_type_if_modifier(t): return is_typeddict(inner_t) @@ -425,3 +427,23 @@ def decorator(func: Callable[..., TReturnVal]) -> Callable[TInputArgs, TReturnVa return func return decorator + + +def copy_sig_any( + wrapper: Callable[Concatenate[TDataItem, TInputArgs], Any], +) -> Callable[ + [Callable[..., TReturnVal]], Callable[Concatenate[TDataItem, TInputArgs], TReturnVal] +]: + """Copies docstring and signature from wrapper to func but keeps the func return value type + + It converts the type of first argument of the wrapper to Any which allows to type transformers in DltSources. 
+ See filesystem source readers as example + """ + + def decorator( + func: Callable[..., TReturnVal] + ) -> Callable[Concatenate[Any, TInputArgs], TReturnVal]: + func.__doc__ = wrapper.__doc__ + return func + + return decorator diff --git a/dlt/destinations/impl/dremio/configuration.py b/dlt/destinations/impl/dremio/configuration.py index 9b1e52f292..d1893e76b7 100644 --- a/dlt/destinations/impl/dremio/configuration.py +++ b/dlt/destinations/impl/dremio/configuration.py @@ -4,7 +4,7 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -from dlt.common.libs.sql_alchemy import URL +from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.typing import TSecretStrValue from dlt.common.utils import digest128 diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index d5065f5bdd..5fa82f4977 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -1,4 +1,5 @@ import threading +import logging from typing import ClassVar, Dict, Optional from dlt.common.destination import DestinationCapabilitiesContext @@ -92,10 +93,9 @@ def to_db_datetime_type( precision = column.get("precision") if timezone and precision is not None: - raise TerminalValueError( + logging.warn( f"DuckDB does not support both timezone and precision for column '{column_name}' in" - f" table '{table_name}'. To resolve this issue, either set timezone to False or" - " None, or use the default precision." + f" table '{table_name}'. Will default to timezone." ) if timezone: diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index 64d87065f3..5b08546f73 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,6 +1,6 @@ import dataclasses from typing import Final, ClassVar, Any, List, Dict -from dlt.common.libs.sql_alchemy import URL +from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index 13bdc7f6b2..fab398fc21 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -2,7 +2,7 @@ from typing import Dict, Final, ClassVar, Any, List, Optional from dlt.common.data_writers.configuration import CsvFormatConfiguration -from dlt.common.libs.sql_alchemy import URL +from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.utils import digest128 diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 08fc132fc3..3fc479f237 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -4,7 +4,7 @@ from dlt import version from dlt.common.data_writers.configuration import CsvFormatConfiguration -from dlt.common.libs.sql_alchemy import URL +from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs import ConnectionStringCredentials diff --git 
a/dlt/extract/hints.py b/dlt/extract/hints.py index 67a6b3e83a..c828064288 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -1,4 +1,4 @@ -from copy import copy, deepcopy +from copy import deepcopy from typing import TypedDict, cast, Any, Optional, Dict from dlt.common import logger @@ -40,18 +40,21 @@ from dlt.extract.validation import create_item_validator -class TResourceHints(TypedDict, total=False): +class TResourceHintsBase(TypedDict, total=False): + write_disposition: Optional[TTableHintTemplate[TWriteDispositionConfig]] + parent: Optional[TTableHintTemplate[str]] + primary_key: Optional[TTableHintTemplate[TColumnNames]] + schema_contract: Optional[TTableHintTemplate[TSchemaContract]] + table_format: Optional[TTableHintTemplate[TTableFormat]] + merge_key: Optional[TTableHintTemplate[TColumnNames]] + + +class TResourceHints(TResourceHintsBase, total=False): name: TTableHintTemplate[str] # description: TTableHintTemplate[str] - write_disposition: TTableHintTemplate[TWriteDispositionConfig] # table_sealed: Optional[bool] - parent: TTableHintTemplate[str] columns: TTableHintTemplate[TTableSchemaColumns] - primary_key: TTableHintTemplate[TColumnNames] - merge_key: TTableHintTemplate[TColumnNames] incremental: Incremental[Any] - schema_contract: TTableHintTemplate[TSchemaContract] - table_format: TTableHintTemplate[TTableFormat] file_format: TTableHintTemplate[TFileFormat] validator: ValidateItem original_columns: TTableHintTemplate[TAnySchemaColumns] diff --git a/dlt/extract/incremental/typing.py b/dlt/extract/incremental/typing.py index a5e2612db4..6829e6b370 100644 --- a/dlt/extract/incremental/typing.py +++ b/dlt/extract/incremental/typing.py @@ -1,5 +1,10 @@ -from typing import TypedDict, Optional, Any, List, Literal, TypeVar, Callable, Sequence +from typing_extensions import TypedDict +from typing import Any, Callable, List, Literal, Optional, Sequence, TypeVar + +from dlt.common.schema.typing import TColumnNames +from dlt.common.typing import TSortOrder +from dlt.extract.items import TTableHintTemplate TCursorValue = TypeVar("TCursorValue", bound=Any) LastValueFunc = Callable[[Sequence[TCursorValue]], Any] @@ -10,3 +15,13 @@ class IncrementalColumnState(TypedDict): initial_value: Optional[Any] last_value: Optional[Any] unique_hashes: List[str] + + +class IncrementalArgs(TypedDict, total=False): + cursor_path: str + initial_value: Optional[str] + last_value_func: Optional[LastValueFunc[str]] + primary_key: Optional[TTableHintTemplate[TColumnNames]] + end_value: Optional[str] + row_order: Optional[TSortOrder] + allow_external_schedulers: Optional[bool] diff --git a/dlt/reflection/names.py b/dlt/reflection/names.py index dad7bdce92..4134e417ef 100644 --- a/dlt/reflection/names.py +++ b/dlt/reflection/names.py @@ -2,7 +2,7 @@ import dlt import dlt.destinations -from dlt import pipeline, attach, run, source, resource +from dlt import pipeline, attach, run, source, resource, transformer DLT = dlt.__name__ DESTINATIONS = dlt.destinations.__name__ @@ -11,12 +11,14 @@ RUN = run.__name__ SOURCE = source.__name__ RESOURCE = resource.__name__ +TRANSFORMER = transformer.__name__ -DETECTED_FUNCTIONS = [PIPELINE, SOURCE, RESOURCE, RUN] +DETECTED_FUNCTIONS = [PIPELINE, SOURCE, RESOURCE, RUN, TRANSFORMER] SIGNATURES = { PIPELINE: inspect.signature(pipeline), ATTACH: inspect.signature(attach), RUN: inspect.signature(run), SOURCE: inspect.signature(source), RESOURCE: inspect.signature(resource), + TRANSFORMER: inspect.signature(transformer), } diff --git 
a/dlt/reflection/script_visitor.py b/dlt/reflection/script_visitor.py index 52b19fe031..f4a5569ed0 100644 --- a/dlt/reflection/script_visitor.py +++ b/dlt/reflection/script_visitor.py @@ -80,6 +80,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: self.known_sources[str(node.name)] = node elif fn == n.RESOURCE: self.known_resources[str(node.name)] = node + elif fn == n.TRANSFORMER: + self.known_resources[str(node.name)] = node super().generic_visit(node) def visit_Call(self, node: ast.Call) -> Any: diff --git a/dlt/sources/.gitignore b/dlt/sources/.gitignore new file mode 100644 index 0000000000..3b28aa3f63 --- /dev/null +++ b/dlt/sources/.gitignore @@ -0,0 +1,10 @@ +# ignore secrets, virtual environments and typical python compilation artifacts +secrets.toml +# ignore basic python artifacts +.env +**/__pycache__/ +**/*.py[cod] +**/*$py.class +# ignore duckdb +*.duckdb +*.wal \ No newline at end of file diff --git a/dlt/sources/__init__.py b/dlt/sources/__init__.py index 465467db67..dcfc281160 100644 --- a/dlt/sources/__init__.py +++ b/dlt/sources/__init__.py @@ -3,7 +3,6 @@ from dlt.extract import DltSource, DltResource, Incremental as incremental from . import credentials from . import config -from . import filesystem __all__ = [ "DltSource", @@ -13,5 +12,4 @@ "incremental", "credentials", "config", - "filesystem", ] diff --git a/dlt/sources/filesystem.py b/dlt/sources/filesystem.py deleted file mode 100644 index 23fb6a9cf3..0000000000 --- a/dlt/sources/filesystem.py +++ /dev/null @@ -1,8 +0,0 @@ -from dlt.common.storages.fsspec_filesystem import ( - FileItem, - FileItemDict, - fsspec_filesystem, - glob_files, -) - -__all__ = ["FileItem", "FileItemDict", "fsspec_filesystem", "glob_files"] diff --git a/dlt/sources/filesystem/__init__.py b/dlt/sources/filesystem/__init__.py new file mode 100644 index 0000000000..80dabe7e66 --- /dev/null +++ b/dlt/sources/filesystem/__init__.py @@ -0,0 +1,102 @@ +"""Reads files in s3, gs or azure buckets using fsspec and provides convenience resources for chunked reading of various file formats""" +from typing import Iterator, List, Optional, Tuple, Union + +import dlt +from dlt.common.storages.fsspec_filesystem import ( + FileItem, + FileItemDict, + fsspec_filesystem, + glob_files, +) +from dlt.sources import DltResource +from dlt.sources.credentials import FileSystemCredentials + +from dlt.sources.filesystem.helpers import ( + AbstractFileSystem, + FilesystemConfigurationResource, +) +from dlt.sources.filesystem.readers import ( + ReadersSource, + _read_csv, + _read_csv_duckdb, + _read_jsonl, + _read_parquet, +) +from dlt.sources.filesystem.settings import DEFAULT_CHUNK_SIZE + + +@dlt.source(_impl_cls=ReadersSource, spec=FilesystemConfigurationResource) +def readers( + bucket_url: str = dlt.secrets.value, + credentials: Union[FileSystemCredentials, AbstractFileSystem] = dlt.secrets.value, + file_glob: Optional[str] = "*", +) -> Tuple[DltResource, ...]: + """This source provides a few resources that are chunked file readers. Readers can be further parametrized before use + read_csv(chunksize, **pandas_kwargs) + read_jsonl(chunksize) + read_parquet(chunksize) + + Args: + bucket_url (str): The url to the bucket. + credentials (FileSystemCredentials | AbstractFilesystem): The credentials to the filesystem of fsspec `AbstractFilesystem` instance. + file_glob (str, optional): The filter to apply to the files in glob format. 
by default lists all files in bucket_url non-recursively + + """ + return ( + filesystem(bucket_url, credentials, file_glob=file_glob) + | dlt.transformer(name="read_csv")(_read_csv), + filesystem(bucket_url, credentials, file_glob=file_glob) + | dlt.transformer(name="read_jsonl")(_read_jsonl), + filesystem(bucket_url, credentials, file_glob=file_glob) + | dlt.transformer(name="read_parquet")(_read_parquet), + filesystem(bucket_url, credentials, file_glob=file_glob) + | dlt.transformer(name="read_csv_duckdb")(_read_csv_duckdb), + ) + + +@dlt.resource(primary_key="file_url", spec=FilesystemConfigurationResource, standalone=True) +def filesystem( + bucket_url: str = dlt.secrets.value, + credentials: Union[FileSystemCredentials, AbstractFileSystem] = dlt.secrets.value, + file_glob: Optional[str] = "*", + files_per_page: int = DEFAULT_CHUNK_SIZE, + extract_content: bool = False, +) -> Iterator[List[FileItem]]: + """This resource lists files in `bucket_url` using `file_glob` pattern. The files are yielded as FileItem which also + provide methods to open and read file data. It should be combined with transformers that further process (ie. load files) + + Args: + bucket_url (str): The url to the bucket. + credentials (FileSystemCredentials | AbstractFilesystem): The credentials to the filesystem of fsspec `AbstractFilesystem` instance. + file_glob (str, optional): The filter to apply to the files in glob format. by default lists all files in bucket_url non-recursively + files_per_page (int, optional): The number of files to process at once, defaults to 100. + extract_content (bool, optional): If true, the content of the file will be extracted if + false it will return a fsspec file, defaults to False. + + Returns: + Iterator[List[FileItem]]: The list of files. 
+ """ + if isinstance(credentials, AbstractFileSystem): + fs_client = credentials + else: + fs_client = fsspec_filesystem(bucket_url, credentials)[0] + + files_chunk: List[FileItem] = [] + for file_model in glob_files(fs_client, bucket_url, file_glob): + file_dict = FileItemDict(file_model, credentials) + if extract_content: + file_dict["file_content"] = file_dict.read_bytes() + files_chunk.append(file_dict) # type: ignore + + # wait for the chunk to be full + if len(files_chunk) >= files_per_page: + yield files_chunk + files_chunk = [] + if files_chunk: + yield files_chunk + + +read_csv = dlt.transformer(standalone=True)(_read_csv) +read_jsonl = dlt.transformer(standalone=True)(_read_jsonl) +read_parquet = dlt.transformer(standalone=True)(_read_parquet) +read_csv_duckdb = dlt.transformer(standalone=True)(_read_csv_duckdb) diff --git a/dlt/sources/filesystem/helpers.py b/dlt/sources/filesystem/helpers.py new file mode 100644 index 0000000000..ebfb491197 --- /dev/null +++ b/dlt/sources/filesystem/helpers.py @@ -0,0 +1,98 @@ +"""Helpers for the filesystem resource.""" +from typing import Any, Dict, Iterable, List, Optional, Type, Union +from fsspec import AbstractFileSystem + +import dlt +from dlt.common.configuration import resolve_type +from dlt.common.typing import TDataItem + +from dlt.sources import DltResource +from dlt.sources.filesystem import fsspec_filesystem +from dlt.sources.config import configspec, with_config +from dlt.sources.credentials import ( + CredentialsConfiguration, + FilesystemConfiguration, + FileSystemCredentials, +) + +from .settings import DEFAULT_CHUNK_SIZE + + +@configspec +class FilesystemConfigurationResource(FilesystemConfiguration): + credentials: Union[FileSystemCredentials, AbstractFileSystem] = None + file_glob: Optional[str] = "*" + files_per_page: int = DEFAULT_CHUNK_SIZE + extract_content: bool = False + + @resolve_type("credentials") + def resolve_credentials_type(self) -> Type[CredentialsConfiguration]: + # use known credentials or empty credentials for unknown protocol + return Union[self.PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration], AbstractFileSystem] # type: ignore[return-value] + + +def fsspec_from_resource(filesystem_instance: DltResource) -> AbstractFileSystem: + """Extract authorized fsspec client from a filesystem resource""" + + @with_config( + spec=FilesystemConfiguration, + sections=("sources", filesystem_instance.section, filesystem_instance.name), + ) + def _get_fsspec( + bucket_url: str, credentials: Optional[FileSystemCredentials] + ) -> AbstractFileSystem: + return fsspec_filesystem(bucket_url, credentials)[0] + + return _get_fsspec( + filesystem_instance.explicit_args.get("bucket_url", dlt.config.value), + filesystem_instance.explicit_args.get("credentials", dlt.secrets.value), + ) + + +def add_columns(columns: List[str], rows: List[List[Any]]) -> List[Dict[str, Any]]: + """Adds column names to the given rows. + + Args: + columns (List[str]): The column names. + rows (List[List[Any]]): The rows. + + Returns: + List[Dict[str, Any]]: The rows with column names. + """ + result = [] + for row in rows: + result.append(dict(zip(columns, row))) + + return result + + +def fetch_arrow(file_data, chunk_size: int) -> Iterable[TDataItem]: # type: ignore + """Fetches data from the given CSV file. + + Args: + file_data (DuckDBPyRelation): The CSV file data. + chunk_size (int): The number of rows to read at once. + + Yields: + Iterable[TDataItem]: Data items, read from the given CSV file. 
+ """ + batcher = file_data.fetch_arrow_reader(batch_size=chunk_size) + yield from batcher + + +def fetch_json(file_data, chunk_size: int) -> List[Dict[str, Any]]: # type: ignore + """Fetches data from the given CSV file. + + Args: + file_data (DuckDBPyRelation): The CSV file data. + chunk_size (int): The number of rows to read at once. + + Yields: + Iterable[TDataItem]: Data items, read from the given CSV file. + """ + while True: + batch = file_data.fetchmany(chunk_size) + if not batch: + break + + yield add_columns(file_data.columns, batch) diff --git a/dlt/sources/filesystem/readers.py b/dlt/sources/filesystem/readers.py new file mode 100644 index 0000000000..405948b515 --- /dev/null +++ b/dlt/sources/filesystem/readers.py @@ -0,0 +1,129 @@ +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional + +from dlt.common import json +from dlt.common.typing import copy_sig_any +from dlt.sources import TDataItems, DltResource, DltSource +from dlt.sources.filesystem import FileItemDict + +from .helpers import fetch_arrow, fetch_json + + +def _read_csv( + items: Iterator[FileItemDict], chunksize: int = 10000, **pandas_kwargs: Any +) -> Iterator[TDataItems]: + """Reads csv file with Pandas chunk by chunk. + + Args: + chunksize (int): Number of records to read in one chunk + **pandas_kwargs: Additional keyword arguments passed to Pandas.read_csv + Returns: + TDataItem: The file content + """ + import pandas as pd + + # apply defaults to pandas kwargs + kwargs = {**{"header": "infer", "chunksize": chunksize}, **pandas_kwargs} + + for file_obj in items: + # Here we use pandas chunksize to read the file in chunks and avoid loading the whole file + # in memory. + with file_obj.open() as file: + for df in pd.read_csv(file, **kwargs): + yield df.to_dict(orient="records") + + +def _read_jsonl(items: Iterator[FileItemDict], chunksize: int = 1000) -> Iterator[TDataItems]: + """Reads jsonl file content and extract the data. + + Args: + chunksize (int, optional): The number of JSON lines to load and yield at once, defaults to 1000 + + Returns: + TDataItem: The file content + """ + for file_obj in items: + with file_obj.open() as f: + lines_chunk = [] + for line in f: + lines_chunk.append(json.loadb(line)) + if len(lines_chunk) >= chunksize: + yield lines_chunk + lines_chunk = [] + if lines_chunk: + yield lines_chunk + + +def _read_parquet( + items: Iterator[FileItemDict], + chunksize: int = 10, +) -> Iterator[TDataItems]: + """Reads parquet file content and extract the data. + + Args: + chunksize (int, optional): The number of files to process at once, defaults to 10. + + Returns: + TDataItem: The file content + """ + from pyarrow import parquet as pq + + for file_obj in items: + with file_obj.open() as f: + parquet_file = pq.ParquetFile(f) + for rows in parquet_file.iter_batches(batch_size=chunksize): + yield rows.to_pylist() + + +def _read_csv_duckdb( + items: Iterator[FileItemDict], + chunk_size: Optional[int] = 5000, + use_pyarrow: bool = False, + **duckdb_kwargs: Any +) -> Iterator[TDataItems]: + """A resource to extract data from the given CSV files. + + Uses DuckDB engine to import and cast CSV data. + + Args: + items (Iterator[FileItemDict]): CSV files to read. + chunk_size (Optional[int]): + The number of rows to read at once. Defaults to 5000. + use_pyarrow (bool): + Whether to use `pyarrow` to read the data and designate + data schema. If set to False (by default), JSON is used. + duckdb_kwargs (Dict): + Additional keyword arguments to pass to the `read_csv()`. 
+ + Returns: + Iterable[TDataItem]: Data items, read from the given CSV files. + """ + import duckdb + + helper = fetch_arrow if use_pyarrow else fetch_json + + for item in items: + with item.open() as f: + file_data = duckdb.from_csv_auto(f, **duckdb_kwargs) # type: ignore + + yield from helper(file_data, chunk_size) + + +if TYPE_CHECKING: + + class ReadersSource(DltSource): + """This is a typing stub that provides docstrings and signatures to the resources in `readers" source""" + + @copy_sig_any(_read_csv) + def read_csv(self) -> DltResource: ... + + @copy_sig_any(_read_jsonl) + def read_jsonl(self) -> DltResource: ... + + @copy_sig_any(_read_parquet) + def read_parquet(self) -> DltResource: ... + + @copy_sig_any(_read_csv_duckdb) + def read_csv_duckdb(self) -> DltResource: ... + +else: + ReadersSource = DltSource diff --git a/dlt/sources/filesystem/settings.py b/dlt/sources/filesystem/settings.py new file mode 100644 index 0000000000..33fcb55b5f --- /dev/null +++ b/dlt/sources/filesystem/settings.py @@ -0,0 +1 @@ +DEFAULT_CHUNK_SIZE = 100 diff --git a/dlt/sources/filesystem_pipeline.py b/dlt/sources/filesystem_pipeline.py new file mode 100644 index 0000000000..db570487ef --- /dev/null +++ b/dlt/sources/filesystem_pipeline.py @@ -0,0 +1,196 @@ +# flake8: noqa +import os +from typing import Iterator + +import dlt +from dlt.sources import TDataItems +from dlt.sources.filesystem import FileItemDict, filesystem, readers, read_csv + + +# where the test files are, those examples work with (url) +TESTS_BUCKET_URL = "samples" + + +def stream_and_merge_csv() -> None: + """Demonstrates how to scan folder with csv files, load them in chunk and merge on date column with the previous load""" + pipeline = dlt.pipeline( + pipeline_name="standard_filesystem_csv", + destination="duckdb", + dataset_name="met_data", + ) + # met_data contains 3 columns, where "date" column contain a date on which we want to merge + # load all csvs in A801 + met_files = readers(bucket_url=TESTS_BUCKET_URL, file_glob="met_csv/A801/*.csv").read_csv() + # tell dlt to merge on date + met_files.apply_hints(write_disposition="merge", merge_key="date") + # NOTE: we load to met_csv table + load_info = pipeline.run(met_files.with_name("met_csv")) + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + # now let's simulate loading on next day. 
not only current data appears but also updated record for the previous day are present + # all the records for previous day will be replaced with new records + met_files = readers(bucket_url=TESTS_BUCKET_URL, file_glob="met_csv/A801/*.csv").read_csv() + met_files.apply_hints(write_disposition="merge", merge_key="date") + load_info = pipeline.run(met_files.with_name("met_csv")) + + # you can also do dlt pipeline standard_filesystem_csv show to confirm that all A801 were replaced with A803 records for overlapping day + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + +def read_csv_with_duckdb() -> None: + pipeline = dlt.pipeline( + pipeline_name="standard_filesystem", + destination="duckdb", + dataset_name="met_data_duckdb", + ) + + # load all the CSV data, excluding headers + met_files = readers( + bucket_url=TESTS_BUCKET_URL, file_glob="met_csv/A801/*.csv" + ).read_csv_duckdb(chunk_size=1000, header=True) + + load_info = pipeline.run(met_files) + + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + +def read_csv_duckdb_compressed() -> None: + pipeline = dlt.pipeline( + pipeline_name="standard_filesystem", + destination="duckdb", + dataset_name="taxi_data", + full_refresh=True, + ) + + met_files = readers( + bucket_url=TESTS_BUCKET_URL, + file_glob="gzip/*", + ).read_csv_duckdb() + + load_info = pipeline.run(met_files) + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + +def read_parquet_and_jsonl_chunked() -> None: + pipeline = dlt.pipeline( + pipeline_name="standard_filesystem", + destination="duckdb", + dataset_name="teams_data", + ) + # When using the readers resource, you can specify a filter to select only the files you + # want to load including a glob pattern. If you use a recursive glob pattern, the filenames + # will include the path to the file inside the bucket_url. + + # JSONL reading (in large chunks!) 
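    # A minimal sketch, not part of the patch (the alias and variable name are illustrative):
    # readers(...).read_jsonl(...) used next is the pre-wired form of piping the bare `filesystem`
    # listing into the standalone `read_jsonl` transformer that dlt.sources.filesystem also exports;
    # both yield pages of parsed records.
    from dlt.sources.filesystem import read_jsonl as standalone_read_jsonl
    jsonl_equivalent = filesystem(TESTS_BUCKET_URL, file_glob="**/*.jsonl") | standalone_read_jsonl(
        chunksize=10000
    )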
+ jsonl_reader = readers(TESTS_BUCKET_URL, file_glob="**/*.jsonl").read_jsonl(chunksize=10000) + # PARQUET reading + parquet_reader = readers(TESTS_BUCKET_URL, file_glob="**/*.parquet").read_parquet() + # load both folders together to specified tables + load_info = pipeline.run( + [ + jsonl_reader.with_name("jsonl_team_data"), + parquet_reader.with_name("parquet_team_data"), + ] + ) + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + +def read_custom_file_type_excel() -> None: + """Here we create an extract pipeline using filesystem resource and read_csv transformer""" + + # instantiate filesystem directly to get list of files (FileItems) and then use read_excel transformer to get + # content of excel via pandas + + @dlt.transformer(standalone=True) + def read_excel(items: Iterator[FileItemDict], sheet_name: str) -> Iterator[TDataItems]: + import pandas as pd + + for file_obj in items: + with file_obj.open() as file: + yield pd.read_excel(file, sheet_name).to_dict(orient="records") + + freshman_xls = filesystem( + bucket_url=TESTS_BUCKET_URL, file_glob="../custom/freshman_kgs.xlsx" + ) | read_excel("freshman_table") + + load_info = dlt.run( + freshman_xls.with_name("freshman"), + destination="duckdb", + dataset_name="freshman_data", + ) + print(load_info) + + +def copy_files_resource(local_folder: str) -> None: + """Demonstrates how to copy files locally by adding a step to filesystem resource and the to load the download listing to db""" + pipeline = dlt.pipeline( + pipeline_name="standard_filesystem_copy", + destination="duckdb", + dataset_name="standard_filesystem_data", + ) + + # a step that copies files into test storage + def _copy(item: FileItemDict) -> FileItemDict: + # instantiate fsspec and copy file + dest_file = os.path.join(local_folder, item["relative_path"]) + # create dest folder + os.makedirs(os.path.dirname(dest_file), exist_ok=True) + # download file + item.fsspec.download(item["file_url"], dest_file) + # return file item unchanged + return item + + # use recursive glob pattern and add file copy step + downloader = filesystem(TESTS_BUCKET_URL, file_glob="**").add_map(_copy) + + # NOTE: you do not need to load any data to execute extract, below we obtain + # a list of files in a bucket and also copy them locally + # listing = list(downloader) + # print(listing) + + # download to table "listing" + # downloader = filesystem(TESTS_BUCKET_URL, file_glob="**").add_map(_copy) + load_info = pipeline.run(downloader.with_name("listing"), write_disposition="replace") + # pretty print the information on data that was loaded + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + +def read_files_incrementally_mtime() -> None: + pipeline = dlt.pipeline( + pipeline_name="standard_filesystem_incremental", + destination="duckdb", + dataset_name="file_tracker", + ) + + # here we modify filesystem resource so it will track only new csv files + # such resource may be then combined with transformer doing further processing + new_files = filesystem(bucket_url=TESTS_BUCKET_URL, file_glob="csv/*") + # add incremental on modification time + new_files.apply_hints(incremental=dlt.sources.incremental("modification_date")) + load_info = pipeline.run((new_files | read_csv()).with_name("csv_files")) + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + # load again - no new files! 
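    # A minimal sketch, not part of the patch: the incremental hint above stores the maximum
    # "modification_date" seen so far in the pipeline state, which is why the reload just below
    # yields no already-loaded files. The listing can also be narrowed further before any
    # transformer runs, e.g. with add_filter (the ".csv" suffix check is an assumed example):
    fresh_csvs = filesystem(bucket_url=TESTS_BUCKET_URL, file_glob="csv/*")
    fresh_csvs.apply_hints(incremental=dlt.sources.incremental("modification_date"))
    fresh_csvs.add_filter(lambda item: item["file_name"].endswith(".csv"))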
+ new_files = filesystem(bucket_url=TESTS_BUCKET_URL, file_glob="csv/*") + # add incremental on modification time + new_files.apply_hints(incremental=dlt.sources.incremental("modification_date")) + load_info = pipeline.run((new_files | read_csv()).with_name("csv_files")) + print(load_info) + print(pipeline.last_trace.last_normalize_info) + + +if __name__ == "__main__": + copy_files_resource("_storage") + stream_and_merge_csv() + read_parquet_and_jsonl_chunked() + read_custom_file_type_excel() + read_files_incrementally_mtime() + read_csv_with_duckdb() + read_csv_duckdb_compressed() diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index d2ca1c1ca6..31c52527da 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,6 +1,5 @@ import math import dataclasses -from abc import abstractmethod from base64 import b64encode from typing import ( TYPE_CHECKING, @@ -157,7 +156,7 @@ class OAuth2ClientCredentials(OAuth2AuthBase): def __init__( self, - access_token_url: TSecretStrValue, + access_token_url: str, client_id: TSecretStrValue, client_secret: TSecretStrValue, access_token_request_data: Dict[str, Any] = None, diff --git a/dlt/sources/pipeline_templates/.dlt/config.toml b/dlt/sources/pipeline_templates/.dlt/config.toml new file mode 100644 index 0000000000..634427baa6 --- /dev/null +++ b/dlt/sources/pipeline_templates/.dlt/config.toml @@ -0,0 +1,5 @@ +# put your configuration values here + +[runtime] +log_level="WARNING" # the system log level of dlt +# use the dlthub_telemetry setting to enable/disable anonymous usage data reporting, see https://dlthub.com/docs/telemetry diff --git a/dlt/sources/pipeline_templates/.gitignore b/dlt/sources/pipeline_templates/.gitignore new file mode 100644 index 0000000000..3b28aa3f63 --- /dev/null +++ b/dlt/sources/pipeline_templates/.gitignore @@ -0,0 +1,10 @@ +# ignore secrets, virtual environments and typical python compilation artifacts +secrets.toml +# ignore basic python artifacts +.env +**/__pycache__/ +**/*.py[cod] +**/*$py.class +# ignore duckdb +*.duckdb +*.wal \ No newline at end of file diff --git a/dlt/sources/pipeline_templates/__init__.py b/dlt/sources/pipeline_templates/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dlt/sources/pipeline_templates/arrow_pipeline.py b/dlt/sources/pipeline_templates/arrow_pipeline.py new file mode 100644 index 0000000000..92ed0664b9 --- /dev/null +++ b/dlt/sources/pipeline_templates/arrow_pipeline.py @@ -0,0 +1,60 @@ +"""The Arrow Pipeline Template will show how to load and transform arrow tables.""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +import dlt +import time +import pyarrow as pa + + +def create_example_arrow_table() -> pa.Table: + return pa.Table.from_pylist([{"name": "tom", "age": 25}, {"name": "angela", "age": 23}]) + + +@dlt.resource(write_disposition="append", name="people") +def resource(): + """One resource function will materialize as a table in the destination, wie yield example data here""" + yield create_example_arrow_table() + + +def add_updated_at(item: pa.Table): + """Map function to add an updated at column to your incoming data.""" + column_count = len(item.columns) + # you will receive and return and arrow table + return item.set_column(column_count, "updated_at", [[time.time()] * item.num_rows]) + + +# apply tranformer to resource +resource.add_map(add_updated_at) + + +@dlt.source +def source(): + """A source function groups all resources into one 
schema.""" + # return resources + return resource() + + +def load_arrow_tables() -> None: + # specify the pipeline name, destination and dataset name when configuring pipeline, + # otherwise the defaults will be used that are derived from the current script name + pipeline = dlt.pipeline( + pipeline_name="arrow", + destination="duckdb", + dataset_name="arrow_data", + ) + + data = list(source().people) + + # print the data yielded from resource without loading it + print(data) # noqa: T201 + + # run the pipeline with your parameters + load_info = pipeline.run(source()) + + # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_arrow_tables() diff --git a/dlt/sources/pipeline_templates/dataframe_pipeline.py b/dlt/sources/pipeline_templates/dataframe_pipeline.py new file mode 100644 index 0000000000..f9f7746098 --- /dev/null +++ b/dlt/sources/pipeline_templates/dataframe_pipeline.py @@ -0,0 +1,62 @@ +"""The DataFrame Pipeline Template will show how to load and transform pandas dataframes.""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +import dlt +import time +import pandas as pd + + +def create_example_dataframe() -> pd.DataFrame: + return pd.DataFrame({"name": ["tom", "angela"], "age": [25, 23]}, columns=["name", "age"]) + + +@dlt.resource(write_disposition="append", name="people") +def resource(): + """One resource function will materialize as a table in the destination, wie yield example data here""" + yield create_example_dataframe() + + +def add_updated_at(item: pd.DataFrame): + """Map function to add an updated at column to your incoming data.""" + column_count = len(item.columns) + # you will receive and return and arrow table + item.insert(column_count, "updated_at", [time.time()] * 2, True) + return item + + +# apply tranformer to resource +resource.add_map(add_updated_at) + + +@dlt.source +def source(): + """A source function groups all resources into one schema.""" + + # return resources + return resource() + + +def load_dataframe() -> None: + # specify the pipeline name, destination and dataset name when configuring pipeline, + # otherwise the defaults will be used that are derived from the current script name + pipeline = dlt.pipeline( + pipeline_name="dataframe", + destination="duckdb", + dataset_name="dataframe_data", + ) + + data = list(source().people) + + # print the data yielded from resource without loading it + print(data) # noqa: T201 + + # run the pipeline with your parameters + load_info = pipeline.run(source()) + + # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_dataframe() diff --git a/dlt/sources/pipeline_templates/debug_pipeline.py b/dlt/sources/pipeline_templates/debug_pipeline.py new file mode 100644 index 0000000000..3699198684 --- /dev/null +++ b/dlt/sources/pipeline_templates/debug_pipeline.py @@ -0,0 +1,64 @@ +"""The Debug Pipeline Template will load a column with each datatype to your destination.""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +import dlt + +from dlt.common import Decimal + + +@dlt.resource(write_disposition="append", name="all_datatypes") +def resource(): + """this is the test data for loading validation, delete it once you yield actual data""" + yield [ + { + "col1": 989127831, + "col2": 898912.821982, + "col3": True, + "col4": "2022-05-23T13:26:45.176451+00:00", + "col5": "string data \n \r 🦆", + "col6": Decimal("2323.34"), + "col7": b"binary data \n \r ", + 
"col8": 2**56 + 92093890840, + "col9": { + "complex": [1, 2, 3, "a"], + "link": ( + "?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\012 \6" + " \\vity%3A69'08444473\n\n551163392%2C6n \r 9085" + ), + }, + "col10": "2023-02-27", + "col11": "13:26:45.176451", + } + ] + + +@dlt.source +def source(): + """A source function groups all resources into one schema.""" + return resource() + + +def load_all_datatypes() -> None: + # specify the pipeline name, destination and dataset name when configuring pipeline, + # otherwise the defaults will be used that are derived from the current script name + pipeline = dlt.pipeline( + pipeline_name="debug", + destination="duckdb", + dataset_name="debug_data", + ) + + data = list(source().all_datatypes) + + # print the data yielded from resource without loading it + print(data) # noqa: T201 + + # run the pipeline with your parameters + load_info = pipeline.run(source()) + + # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_all_datatypes() diff --git a/dlt/sources/pipeline_templates/default_pipeline.py b/dlt/sources/pipeline_templates/default_pipeline.py new file mode 100644 index 0000000000..9fa03f9ce5 --- /dev/null +++ b/dlt/sources/pipeline_templates/default_pipeline.py @@ -0,0 +1,51 @@ +"""The Default Pipeline Template provides a simple starting point for your dlt pipeline""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +import dlt +from dlt.common import Decimal + + +@dlt.resource(name="customers", primary_key="id") +def customers(): + """Load customer data from a simple python list.""" + yield [ + {"id": 1, "name": "simon", "city": "berlin"}, + {"id": 2, "name": "violet", "city": "london"}, + {"id": 3, "name": "tammo", "city": "new york"}, + ] + + +@dlt.resource(name="inventory", primary_key="id") +def inventory(): + """Load inventory data from a simple python list.""" + yield [ + {"id": 1, "name": "apple", "price": Decimal("1.50")}, + {"id": 2, "name": "banana", "price": Decimal("1.70")}, + {"id": 3, "name": "pear", "price": Decimal("2.50")}, + ] + + +@dlt.source(name="my_fruitshop") +def source(): + """A source function groups all resources into one schema.""" + return customers(), inventory() + + +def load_stuff() -> None: + # specify the pipeline name, destination and dataset name when configuring pipeline, + # otherwise the defaults will be used that are derived from the current script name + p = dlt.pipeline( + pipeline_name="fruitshop", + destination="duckdb", + dataset_name="fruitshop_data", + ) + + load_info = p.run(source()) + + # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_stuff() diff --git a/dlt/sources/pipeline_templates/intro_pipeline.py b/dlt/sources/pipeline_templates/intro_pipeline.py new file mode 100644 index 0000000000..a4de18daba --- /dev/null +++ b/dlt/sources/pipeline_templates/intro_pipeline.py @@ -0,0 +1,82 @@ +"""The Intro Pipeline Template contains the example from the docs intro page""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +import pandas as pd +import sqlalchemy as sa + +import dlt +from dlt.sources.helpers import requests + + +def load_api_data() -> None: + """Load data from the chess api, for more complex examples use our rest_api source""" + + # Create a dlt pipeline that will load + # chess player data to the DuckDB destination + pipeline = dlt.pipeline( + pipeline_name="chess_pipeline", destination="duckdb", dataset_name="player_data" + ) 
+ # Grab some player data from Chess.com API + data = [] + for player in ["magnuscarlsen", "rpragchess"]: + response = requests.get(f"https://api.chess.com/pub/player/{player}") + response.raise_for_status() + data.append(response.json()) + + # Extract, normalize, and load the data + load_info = pipeline.run(data, table_name="player") + print(load_info) # noqa: T201 + + +def load_pandas_data() -> None: + """Load data from a public csv via pandas""" + + owid_disasters_csv = ( + "https://raw.githubusercontent.com/owid/owid-datasets/master/datasets/" + "Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020)/" + "Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020).csv" + ) + df = pd.read_csv(owid_disasters_csv) + data = df.to_dict(orient="records") + + pipeline = dlt.pipeline( + pipeline_name="from_csv", + destination="duckdb", + dataset_name="mydata", + ) + load_info = pipeline.run(data, table_name="natural_disasters") + + print(load_info) # noqa: T201 + + +def load_sql_data() -> None: + """Load data from a sql database with sqlalchemy, for more complex examples use our sql_database source""" + + # Use any SQL database supported by SQLAlchemy, below we use a public + # MySQL instance to get data. + # NOTE: you'll need to install pymysql with `pip install pymysql` + # NOTE: loading data from public mysql instance may take several seconds + engine = sa.create_engine("mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam") + + with engine.connect() as conn: + # Select genome table, stream data in batches of 100 elements + query = "SELECT * FROM genome LIMIT 1000" + rows = conn.execution_options(yield_per=100).exec_driver_sql(query) + + pipeline = dlt.pipeline( + pipeline_name="from_database", + destination="duckdb", + dataset_name="genome_data", + ) + + # Convert the rows into dictionaries on the fly with a map function + load_info = pipeline.run(map(lambda row: dict(row._mapping), rows), table_name="genome") + + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_api_data() + load_pandas_data() + load_sql_data() diff --git a/dlt/sources/pipeline_templates/requests_pipeline.py b/dlt/sources/pipeline_templates/requests_pipeline.py new file mode 100644 index 0000000000..19acaa1fdb --- /dev/null +++ b/dlt/sources/pipeline_templates/requests_pipeline.py @@ -0,0 +1,61 @@ +"""The Requests Pipeline Template provides a simple starting point for a dlt pipeline with the requests library""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +from typing import Iterator, Any + +import dlt + +from dlt.sources.helpers import requests +from dlt.sources import TDataItems + + +YEAR = 2022 +MONTH = 10 +BASE_PATH = "https://api.chess.com/pub/player" + + +@dlt.resource(name="players", primary_key="player_id") +def players(): + """Load player profiles from the chess api.""" + for player_name in ["magnuscarlsen", "rpragchess"]: + path = f"{BASE_PATH}/{player_name}" + response = requests.get(path) + response.raise_for_status() + yield response.json() + + +# this resource takes data from players and returns games for the configured +@dlt.transformer(data_from=players, write_disposition="append") +def players_games(player: Any) -> Iterator[TDataItems]: + """Load all games for each player in october 2022""" + player_name = player["username"] + path = f"{BASE_PATH}/{player_name}/games/{YEAR:04d}/{MONTH:02d}" + response = requests.get(path) + response.raise_for_status() + yield response.json()["games"] + + +@dlt.source(name="chess") +def source(): + """A source 
function groups all resources into one schema.""" + return players(), players_games() + + +def load_chess_data() -> None: + # specify the pipeline name, destination and dataset name when configuring pipeline, + # otherwise the defaults will be used that are derived from the current script name + p = dlt.pipeline( + pipeline_name="chess", + destination="duckdb", + dataset_name="chess_data", + ) + + load_info = p.run(source()) + + # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_chess_data() diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py new file mode 100644 index 0000000000..fa6b691933 --- /dev/null +++ b/dlt/sources/rest_api/__init__.py @@ -0,0 +1,465 @@ +"""Generic API Source""" +from copy import deepcopy +from typing import Type, Any, Dict, List, Optional, Generator, Callable, cast, Union +import graphlib # type: ignore[import,unused-ignore] +from requests.auth import AuthBase + +import dlt +from dlt.common.validation import validate_dict +from dlt.common import jsonpath +from dlt.common.schema.schema import Schema +from dlt.common.schema.typing import TSchemaContract +from dlt.common.configuration.specs import BaseConfiguration + +from dlt.extract.incremental import Incremental +from dlt.extract.source import DltResource, DltSource + +from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client.paginators import BasePaginator +from dlt.sources.helpers.rest_client.auth import ( + HttpBasicAuth, + BearerTokenAuth, + APIKeyAuth, + AuthConfigBase, +) +from dlt.sources.helpers.rest_client.typing import HTTPMethodBasic +from .typing import ( + AuthConfig, + ClientConfig, + ResolvedParam, + ResolveParamConfig, + Endpoint, + EndpointResource, + IncrementalParamConfig, + RESTAPIConfig, + ParamBindType, + ProcessingSteps, +) +from .config_setup import ( + IncrementalParam, + create_auth, + create_paginator, + build_resource_dependency_graph, + process_parent_data_item, + setup_incremental_object, + create_response_hooks, +) +from .utils import check_connection, exclude_keys # noqa: F401 + +PARAM_TYPES: List[ParamBindType] = ["incremental", "resolve"] +MIN_SECRET_MASKING_LENGTH = 3 +SENSITIVE_KEYS: List[str] = [ + "token", + "api_key", + "username", + "password", +] + + +def rest_api_source( + config: RESTAPIConfig, + name: str = None, + section: str = None, + max_table_nesting: int = None, + root_key: bool = False, + schema: Schema = None, + schema_contract: TSchemaContract = None, + spec: Type[BaseConfiguration] = None, +) -> DltSource: + """Creates and configures a REST API source for data extraction. + + Args: + config (RESTAPIConfig): Configuration for the REST API source. + name (str, optional): Name of the source. + section (str, optional): Section of the configuration file. + max_table_nesting (int, optional): Maximum depth of nested table above which + the remaining nodes are loaded as structs or JSON. + root_key (bool, optional): Enables merging on all resources by propagating + root foreign key to child tables. This option is most useful if you + plan to change write disposition of a resource to disable/enable merge. + Defaults to False. + schema (Schema, optional): An explicit `Schema` instance to be associated + with the source. If not present, `dlt` creates a new `Schema` object + with provided `name`. If such `Schema` already exists in the same + folder as the module containing the decorated function, such schema + will be loaded from file. 
+ schema_contract (TSchemaContract, optional): Schema contract settings + that will be applied to this resource. + spec (Type[BaseConfiguration], optional): A specification of configuration + and secret values required by the source. + + Returns: + DltSource: A configured dlt source. + + Example: + pokemon_source = rest_api_source({ + "client": { + "base_url": "https://pokeapi.co/api/v2/", + "paginator": "json_link", + }, + "endpoints": { + "pokemon": { + "params": { + "limit": 100, # Default page size is 20 + }, + "resource": { + "primary_key": "id", + } + }, + }, + }) + """ + decorated = dlt.source( + rest_api_resources, + name, + section, + max_table_nesting, + root_key, + schema, + schema_contract, + spec, + ) + + return decorated(config) + + +def rest_api_resources(config: RESTAPIConfig) -> List[DltResource]: + """Creates a list of resources from a REST API configuration. + + Args: + config (RESTAPIConfig): Configuration for the REST API source. + + Returns: + List[DltResource]: List of dlt resources. + + Example: + github_source = rest_api_resources({ + "client": { + "base_url": "https://api.github.com/repos/dlt-hub/dlt/", + "auth": { + "token": dlt.secrets["token"], + }, + }, + "resource_defaults": { + "primary_key": "id", + "write_disposition": "merge", + "endpoint": { + "params": { + "per_page": 100, + }, + }, + }, + "resources": [ + { + "name": "issues", + "endpoint": { + "path": "issues", + "params": { + "sort": "updated", + "direction": "desc", + "state": "open", + "since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-25T11:21:28Z", + }, + }, + }, + }, + { + "name": "issue_comments", + "endpoint": { + "path": "issues/{issue_number}/comments", + "params": { + "issue_number": { + "type": "resolve", + "resource": "issues", + "field": "number", + } + }, + }, + }, + ], + }) + """ + + _validate_config(config) + + client_config = config["client"] + resource_defaults = config.get("resource_defaults", {}) + resource_list = config["resources"] + + ( + dependency_graph, + endpoint_resource_map, + resolved_param_map, + ) = build_resource_dependency_graph( + resource_defaults, + resource_list, + ) + + resources = create_resources( + client_config, + dependency_graph, + endpoint_resource_map, + resolved_param_map, + ) + + return list(resources.values()) + + +def create_resources( + client_config: ClientConfig, + dependency_graph: graphlib.TopologicalSorter, + endpoint_resource_map: Dict[str, EndpointResource], + resolved_param_map: Dict[str, Optional[ResolvedParam]], +) -> Dict[str, DltResource]: + resources = {} + + for resource_name in dependency_graph.static_order(): + resource_name = cast(str, resource_name) + endpoint_resource = endpoint_resource_map[resource_name] + endpoint_config = cast(Endpoint, endpoint_resource["endpoint"]) + request_params = endpoint_config.get("params", {}) + request_json = endpoint_config.get("json", None) + paginator = create_paginator(endpoint_config.get("paginator")) + processing_steps = endpoint_resource.pop("processing_steps", []) + + resolved_param: ResolvedParam = resolved_param_map[resource_name] + + include_from_parent: List[str] = endpoint_resource.get("include_from_parent", []) + if not resolved_param and include_from_parent: + raise ValueError( + f"Resource {resource_name} has include_from_parent but is not " + "dependent on another resource" + ) + _validate_param_type(request_params) + ( + incremental_object, + incremental_param, + incremental_cursor_transform, + ) = setup_incremental_object(request_params, 
endpoint_config.get("incremental")) + + client = RESTClient( + base_url=client_config["base_url"], + headers=client_config.get("headers"), + auth=create_auth(client_config.get("auth")), + paginator=create_paginator(client_config.get("paginator")), + ) + + hooks = create_response_hooks(endpoint_config.get("response_actions")) + + resource_kwargs = exclude_keys(endpoint_resource, {"endpoint", "include_from_parent"}) + + def process( + resource: DltResource, + processing_steps: List[ProcessingSteps], + ) -> Any: + for step in processing_steps: + if "filter" in step: + resource.add_filter(step["filter"]) + if "map" in step: + resource.add_map(step["map"]) + return resource + + if resolved_param is None: + + def paginate_resource( + method: HTTPMethodBasic, + path: str, + params: Dict[str, Any], + json: Optional[Dict[str, Any]], + paginator: Optional[BasePaginator], + data_selector: Optional[jsonpath.TJsonPath], + hooks: Optional[Dict[str, Any]], + client: RESTClient = client, + incremental_object: Optional[Incremental[Any]] = incremental_object, + incremental_param: Optional[IncrementalParam] = incremental_param, + incremental_cursor_transform: Optional[ + Callable[..., Any] + ] = incremental_cursor_transform, + ) -> Generator[Any, None, None]: + if incremental_object: + params = _set_incremental_params( + params, + incremental_object, + incremental_param, + incremental_cursor_transform, + ) + + yield from client.paginate( + method=method, + path=path, + params=params, + json=json, + paginator=paginator, + data_selector=data_selector, + hooks=hooks, + ) + + resources[resource_name] = dlt.resource( + paginate_resource, + **resource_kwargs, # TODO: implement typing.Unpack + )( + method=endpoint_config.get("method", "get"), + path=endpoint_config.get("path"), + params=request_params, + json=request_json, + paginator=paginator, + data_selector=endpoint_config.get("data_selector"), + hooks=hooks, + ) + + resources[resource_name] = process(resources[resource_name], processing_steps) + + else: + predecessor = resources[resolved_param.resolve_config["resource"]] + + base_params = exclude_keys(request_params, {resolved_param.param_name}) + + def paginate_dependent_resource( + items: List[Dict[str, Any]], + method: HTTPMethodBasic, + path: str, + params: Dict[str, Any], + paginator: Optional[BasePaginator], + data_selector: Optional[jsonpath.TJsonPath], + hooks: Optional[Dict[str, Any]], + client: RESTClient = client, + resolved_param: ResolvedParam = resolved_param, + include_from_parent: List[str] = include_from_parent, + incremental_object: Optional[Incremental[Any]] = incremental_object, + incremental_param: Optional[IncrementalParam] = incremental_param, + incremental_cursor_transform: Optional[ + Callable[..., Any] + ] = incremental_cursor_transform, + ) -> Generator[Any, None, None]: + if incremental_object: + params = _set_incremental_params( + params, + incremental_object, + incremental_param, + incremental_cursor_transform, + ) + + for item in items: + formatted_path, parent_record = process_parent_data_item( + path, item, resolved_param, include_from_parent + ) + + for child_page in client.paginate( + method=method, + path=formatted_path, + params=params, + paginator=paginator, + data_selector=data_selector, + hooks=hooks, + ): + if parent_record: + for child_record in child_page: + child_record.update(parent_record) + yield child_page + + resources[resource_name] = dlt.resource( # type: ignore[call-overload] + paginate_dependent_resource, + data_from=predecessor, + **resource_kwargs, # TODO: 
implement typing.Unpack + )( + method=endpoint_config.get("method", "get"), + path=endpoint_config.get("path"), + params=base_params, + paginator=paginator, + data_selector=endpoint_config.get("data_selector"), + hooks=hooks, + ) + + resources[resource_name] = process(resources[resource_name], processing_steps) + + return resources + + +def _validate_config(config: RESTAPIConfig) -> None: + c = deepcopy(config) + client_config = c.get("client") + if client_config: + auth = client_config.get("auth") + if auth: + auth = _mask_secrets(auth) + + validate_dict(RESTAPIConfig, c, path=".") + + +def _mask_secrets(auth_config: AuthConfig) -> AuthConfig: + if isinstance(auth_config, AuthBase) and not isinstance(auth_config, AuthConfigBase): + return auth_config + + has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS) + if isinstance(auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth)) or has_sensitive_key: + return _mask_secrets_dict(auth_config) + # Here, we assume that OAuth2 and other custom classes that don't implement __get__() + # also don't print secrets in __str__() + # TODO: call auth_config.mask_secrets() when that is implemented in dlt-core + return auth_config + + +def _mask_secrets_dict(auth_config: AuthConfig) -> AuthConfig: + for sensitive_key in SENSITIVE_KEYS: + try: + auth_config[sensitive_key] = _mask_secret(auth_config[sensitive_key]) # type: ignore[literal-required, index] + except KeyError: + continue + return auth_config + + +def _mask_secret(secret: Optional[str]) -> str: + if secret is None: + return "None" + if len(secret) < MIN_SECRET_MASKING_LENGTH: + return "*****" + return f"{secret[0]}*****{secret[-1]}" + + +def _set_incremental_params( + params: Dict[str, Any], + incremental_object: Incremental[Any], + incremental_param: IncrementalParam, + transform: Optional[Callable[..., Any]], +) -> Dict[str, Any]: + def identity_func(x: Any) -> Any: + return x + + if transform is None: + transform = identity_func + params[incremental_param.start] = transform(incremental_object.last_value) + if incremental_param.end: + params[incremental_param.end] = transform(incremental_object.end_value) + return params + + +def _validate_param_type( + request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]] +) -> None: + for _, value in request_params.items(): + if isinstance(value, dict) and value.get("type") not in PARAM_TYPES: + raise ValueError( + f"Invalid param type: {value.get('type')}. 
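# A minimal sketch of the secret masking applied above before config validation
# (token values are invented): everything but the first and last character is
# hidden, and very short secrets are masked completely, so validation errors do
# not echo credentials.
from dlt.sources.rest_api import _mask_secrets

assert _mask_secrets({"type": "bearer", "token": "ghp_16C7e42F292c69"})["token"] == "g*****9"
assert _mask_secrets({"type": "bearer", "token": "ab"})["token"] == "*****"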
Available options: {PARAM_TYPES}" + ) + + +# XXX: This is a workaround pass test_dlt_init.py +# since the source uses dlt.source as a function +def _register_source(source_func: Callable[..., DltSource]) -> None: + import inspect + from dlt.common.configuration import get_fun_spec + from dlt.common.source import _SOURCES, SourceInfo + + spec = get_fun_spec(source_func) + func_module = inspect.getmodule(source_func) + _SOURCES[source_func.__name__] = SourceInfo( + SPEC=spec, + f=source_func, + module=func_module, + ) + + +_register_source(rest_api_source) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py new file mode 100644 index 0000000000..7bf6c81634 --- /dev/null +++ b/dlt/sources/rest_api/config_setup.py @@ -0,0 +1,634 @@ +import warnings +from copy import copy +from typing import ( + Type, + Any, + Dict, + Tuple, + List, + Optional, + Union, + Callable, + cast, + NamedTuple, +) +import graphlib # type: ignore[import,unused-ignore] +import string + +import dlt +from dlt.common import logger +from dlt.common.configuration import resolve_configuration +from dlt.common.schema.utils import merge_columns +from dlt.common.utils import update_dict_nested +from dlt.common import jsonpath + +from dlt.extract.incremental import Incremental +from dlt.extract.utils import ensure_table_schema_columns + +from dlt.sources.helpers.requests import Response +from dlt.sources.helpers.rest_client.paginators import ( + BasePaginator, + SinglePagePaginator, + HeaderLinkPaginator, + JSONResponseCursorPaginator, + OffsetPaginator, + PageNumberPaginator, +) + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + from dlt.sources.helpers.rest_client.paginators import ( + JSONResponsePaginator as JSONLinkPaginator, + ) + +from dlt.sources.helpers.rest_client.detector import single_entity_path +from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException +from dlt.sources.helpers.rest_client.auth import ( + AuthConfigBase, + HttpBasicAuth, + BearerTokenAuth, + APIKeyAuth, + OAuth2ClientCredentials, +) + +from .typing import ( + EndpointResourceBase, + AuthConfig, + IncrementalConfig, + PaginatorConfig, + ResolvedParam, + ResponseAction, + ResponseActionDict, + Endpoint, + EndpointResource, +) +from .utils import exclude_keys + + +PAGINATOR_MAP: Dict[str, Type[BasePaginator]] = { + "json_link": JSONLinkPaginator, + "json_response": ( + JSONLinkPaginator + ), # deprecated. Use json_link instead. Will be removed in upcoming release + "header_link": HeaderLinkPaginator, + "auto": None, + "single_page": SinglePagePaginator, + "cursor": JSONResponseCursorPaginator, + "offset": OffsetPaginator, + "page_number": PageNumberPaginator, +} + +AUTH_MAP: Dict[str, Type[AuthConfigBase]] = { + "bearer": BearerTokenAuth, + "api_key": APIKeyAuth, + "http_basic": HttpBasicAuth, + "oauth2_client_credentials": OAuth2ClientCredentials, +} + + +class IncrementalParam(NamedTuple): + start: str + end: Optional[str] + + +def register_paginator( + paginator_name: str, + paginator_class: Type[BasePaginator], +) -> None: + if not issubclass(paginator_class, BasePaginator): + raise ValueError( + f"Invalid paginator: {paginator_class.__name__}. 
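# A short sketch of the paginator shorthands resolved through PAGINATOR_MAP
# (URLs and parameter values are illustrative): a bare string creates the
# paginator with defaults, a dict passes constructor arguments, and "auto"
# defers detection to the REST client.
from dlt.sources.rest_api.config_setup import create_paginator

header_link = create_paginator("header_link")  # HeaderLinkPaginator with defaults
offset = create_paginator(
    {"type": "offset", "limit": 100, "offset_param": "offset", "limit_param": "limit", "total_path": "total"}
)
assert create_paginator("auto") is None  # "auto": pagination is detected by RESTClient later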
" + "Your custom paginator has to be a subclass of BasePaginator" + ) + PAGINATOR_MAP[paginator_name] = paginator_class + + +def get_paginator_class(paginator_name: str) -> Type[BasePaginator]: + try: + return PAGINATOR_MAP[paginator_name] + except KeyError: + available_options = ", ".join(PAGINATOR_MAP.keys()) + raise ValueError( + f"Invalid paginator: {paginator_name}. Available options: {available_options}." + ) + + +def create_paginator( + paginator_config: Optional[PaginatorConfig], +) -> Optional[BasePaginator]: + if isinstance(paginator_config, BasePaginator): + return paginator_config + + if isinstance(paginator_config, str): + paginator_class = get_paginator_class(paginator_config) + try: + # `auto` has no associated class in `PAGINATOR_MAP` + return paginator_class() if paginator_class else None + except TypeError: + raise ValueError( + f"Paginator {paginator_config} requires arguments to create an instance. Use" + f" {paginator_class} instance instead." + ) + + if isinstance(paginator_config, dict): + paginator_type = paginator_config.get("type", "auto") + paginator_class = get_paginator_class(paginator_type) + return ( + paginator_class(**exclude_keys(paginator_config, {"type"})) if paginator_class else None + ) + + return None + + +def register_auth( + auth_name: str, + auth_class: Type[AuthConfigBase], +) -> None: + if not issubclass(auth_class, AuthConfigBase): + raise ValueError( + f"Invalid auth: {auth_class.__name__}. " + "Your custom auth has to be a subclass of AuthConfigBase" + ) + AUTH_MAP[auth_name] = auth_class + + +def get_auth_class(auth_type: str) -> Type[AuthConfigBase]: + try: + return AUTH_MAP[auth_type] + except KeyError: + available_options = ", ".join(AUTH_MAP.keys()) + raise ValueError( + f"Invalid authentication: {auth_type}. Available options: {available_options}." + ) + + +def create_auth(auth_config: Optional[AuthConfig]) -> Optional[AuthConfigBase]: + auth: AuthConfigBase = None + if isinstance(auth_config, AuthConfigBase): + auth = auth_config + + if isinstance(auth_config, str): + auth_class = get_auth_class(auth_config) + auth = auth_class() + + if isinstance(auth_config, dict): + auth_type = auth_config.get("type", "bearer") + auth_class = get_auth_class(auth_type) + auth = auth_class(**exclude_keys(auth_config, {"type"})) + + if auth: + # TODO: provide explicitly (non-default) values as explicit explicit_value=dict(auth) + # this will resolve auth which is a configuration using current section context + return resolve_configuration(auth, accept_partial=True) + + return None + + +def setup_incremental_object( + request_params: Dict[str, Any], + incremental_config: Optional[IncrementalConfig] = None, +) -> Tuple[Optional[Incremental[Any]], Optional[IncrementalParam], Optional[Callable[..., Any]]]: + incremental_params: List[str] = [] + for param_name, param_config in request_params.items(): + if ( + isinstance(param_config, dict) + and param_config.get("type") == "incremental" + or isinstance(param_config, dlt.sources.incremental) + ): + incremental_params.append(param_name) + if len(incremental_params) > 1: + raise ValueError( + "Only a single incremental parameter is allower per endpoint. Found:" + f" {incremental_params}" + ) + convert: Optional[Callable[..., Any]] + for param_name, param_config in request_params.items(): + if isinstance(param_config, dlt.sources.incremental): + if param_config.end_value is not None: + raise ValueError( + f"Only initial_value is allowed in the configuration of param: {param_name}. 
To" + " set end_value too use the incremental configuration at the resource level." + " See" + " https://dlthub.com/docs/dlt-ecosystem/verified-sources/rest_api#incremental-loading/" + ) + return param_config, IncrementalParam(start=param_name, end=None), None + if isinstance(param_config, dict) and param_config.get("type") == "incremental": + if param_config.get("end_value") or param_config.get("end_param"): + raise ValueError( + "Only start_param and initial_value are allowed in the configuration of param:" + f" {param_name}. To set end_value too use the incremental configuration at the" + " resource level. See" + " https://dlthub.com/docs/dlt-ecosystem/verified-sources/rest_api#incremental-loading" + ) + convert = parse_convert_or_deprecated_transform(param_config) + + config = exclude_keys(param_config, {"type", "convert", "transform"}) + # TODO: implement param type to bind incremental to + return ( + dlt.sources.incremental(**config), + IncrementalParam(start=param_name, end=None), + convert, + ) + if incremental_config: + convert = parse_convert_or_deprecated_transform(incremental_config) + config = exclude_keys( + incremental_config, {"start_param", "end_param", "convert", "transform"} + ) + return ( + dlt.sources.incremental(**config), + IncrementalParam( + start=incremental_config["start_param"], + end=incremental_config.get("end_param"), + ), + convert, + ) + + return None, None, None + + +def parse_convert_or_deprecated_transform( + config: Union[IncrementalConfig, Dict[str, Any]] +) -> Optional[Callable[..., Any]]: + convert = config.get("convert", None) + deprecated_transform = config.get("transform", None) + if deprecated_transform: + warnings.warn( + "The key `transform` is deprecated in the incremental configuration and it will be" + " removed. 
Use `convert` instead", + DeprecationWarning, + stacklevel=2, + ) + convert = deprecated_transform + return convert + + +def make_parent_key_name(resource_name: str, field_name: str) -> str: + return f"_{resource_name}_{field_name}" + + +def build_resource_dependency_graph( + resource_defaults: EndpointResourceBase, + resource_list: List[Union[str, EndpointResource]], +) -> Tuple[Any, Dict[str, EndpointResource], Dict[str, Optional[ResolvedParam]]]: + dependency_graph = graphlib.TopologicalSorter() + endpoint_resource_map: Dict[str, EndpointResource] = {} + resolved_param_map: Dict[str, ResolvedParam] = {} + + # expand all resources and index them + for resource_kwargs in resource_list: + if isinstance(resource_kwargs, dict): + # clone resource here, otherwise it needs to be cloned in several other places + # note that this clones only dict structure, keeping all instances without deepcopy + resource_kwargs = update_dict_nested({}, resource_kwargs) # type: ignore + + endpoint_resource = _make_endpoint_resource(resource_kwargs, resource_defaults) + assert isinstance(endpoint_resource["endpoint"], dict) + _setup_single_entity_endpoint(endpoint_resource["endpoint"]) + _bind_path_params(endpoint_resource) + + resource_name = endpoint_resource["name"] + assert isinstance( + resource_name, str + ), f"Resource name must be a string, got {type(resource_name)}" + + if resource_name in endpoint_resource_map: + raise ValueError(f"Resource {resource_name} has already been defined") + endpoint_resource_map[resource_name] = endpoint_resource + + # create dependency graph + for resource_name, endpoint_resource in endpoint_resource_map.items(): + assert isinstance(endpoint_resource["endpoint"], dict) + # connect transformers to resources via resolved params + resolved_params = _find_resolved_params(endpoint_resource["endpoint"]) + if len(resolved_params) > 1: + raise ValueError( + f"Multiple resolved params for resource {resource_name}: {resolved_params}" + ) + elif len(resolved_params) == 1: + resolved_param = resolved_params[0] + predecessor = resolved_param.resolve_config["resource"] + if predecessor not in endpoint_resource_map: + raise ValueError( + f"A transformer resource {resource_name} refers to non existing parent resource" + f" {predecessor} on {resolved_param}" + ) + dependency_graph.add(resource_name, predecessor) + resolved_param_map[resource_name] = resolved_param + else: + dependency_graph.add(resource_name) + resolved_param_map[resource_name] = None + + return dependency_graph, endpoint_resource_map, resolved_param_map + + +def _make_endpoint_resource( + resource: Union[str, EndpointResource], default_config: EndpointResourceBase +) -> EndpointResource: + """ + Creates an EndpointResource object based on the provided resource + definition and merges it with the default configuration. + + This function supports defining a resource in multiple formats: + - As a string: The string is interpreted as both the resource name + and its endpoint path. + - As a dictionary: The dictionary must include `name` and `endpoint` + keys. The `endpoint` can be a string representing the path, + or a dictionary for more complex configurations. If the `endpoint` + is missing the `path` key, the resource name is used as the `path`. 
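# The shorthand forms accepted for a resource definition, as described in the
# docstring above (endpoint names and params are illustrative):
resources = [
    "pokemon",                                 # name and endpoint path are both "pokemon"
    {"name": "berries", "endpoint": "berry"},  # endpoint given as a path string
    {
        "name": "locations",                   # endpoint given as a full dict
        "endpoint": {"path": "location", "params": {"limit": 100}},
    },
]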
+ """ + if isinstance(resource, str): + resource = {"name": resource, "endpoint": {"path": resource}} + return _merge_resource_endpoints(default_config, resource) + + if "endpoint" in resource: + if isinstance(resource["endpoint"], str): + resource["endpoint"] = {"path": resource["endpoint"]} + else: + # endpoint is optional + resource["endpoint"] = {} + + if "path" not in resource["endpoint"]: + resource["endpoint"]["path"] = resource["name"] # type: ignore + + return _merge_resource_endpoints(default_config, resource) + + +def _bind_path_params(resource: EndpointResource) -> None: + """Binds params declared in path to params available in `params`. Pops the + bound params but. Params of type `resolve` and `incremental` are skipped + and bound later. + """ + path_params: Dict[str, Any] = {} + assert isinstance(resource["endpoint"], dict) # type guard + resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])] + path = resource["endpoint"]["path"] + for format_ in string.Formatter().parse(path): + name = format_[1] + if name: + params = resource["endpoint"].get("params", {}) + if name not in params and name not in path_params: + raise ValueError( + f"The path {path} defined in resource {resource['name']} requires param with" + f" name {name} but it is not found in {params}" + ) + if name in resolve_params: + resolve_params.remove(name) + if name in params: + if not isinstance(params[name], dict): + # bind resolved param and pop it from endpoint + path_params[name] = params.pop(name) + else: + param_type = params[name].get("type") + if param_type != "resolve": + raise ValueError( + f"The path {path} defined in resource {resource['name']} tries to bind" + f" param {name} with type {param_type}. Paths can only bind 'resource'" + " type params." + ) + # resolved params are bound later + path_params[name] = "{" + name + "}" + + if len(resolve_params) > 0: + raise NotImplementedError( + f"Resource {resource['name']} defines resolve params {resolve_params} that are not" + f" bound in path {path}. Resolve query params not supported yet." + ) + + resource["endpoint"]["path"] = path.format(**path_params) + + +def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint: + """Tries to guess if the endpoint refers to a single entity and when detected: + * if `data_selector` was not specified (or is None), "$" is selected + * if `paginator` was not specified (or is None), SinglePagePaginator is selected + + Endpoint is modified in place and returned + """ + # try to guess if list of entities or just single entity is returned + if single_entity_path(endpoint["path"]): + if endpoint.get("data_selector") is None: + endpoint["data_selector"] = "$" + if endpoint.get("paginator") is None: + endpoint["paginator"] = SinglePagePaginator() + return endpoint + + +def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]: + """ + Find all resolved params in the endpoint configuration and return + a list of ResolvedParam objects. + + Resolved params are of type ResolveParamConfig (bound param with a key "type" set to "resolve".) 
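# How plain path params are bound (illustrative endpoint): "org" and "repo" are
# popped from `params` and substituted into the path, "state" stays a query
# param, and a param of type "resolve" would be left as a "{placeholder}" to be
# filled from the parent resource later.
repo_issues = {
    "name": "repo_issues",
    "endpoint": {
        "path": "repos/{org}/{repo}/issues",
        "params": {"org": "dlt-hub", "repo": "dlt", "state": "open"},
    },
}
# after _bind_path_params(repo_issues): path == "repos/dlt-hub/dlt/issues"
# and params == {"state": "open"}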
+ """ + return [ + ResolvedParam(key, value) # type: ignore[arg-type] + for key, value in endpoint_config.get("params", {}).items() + if (isinstance(value, dict) and value.get("type") == "resolve") + ] + + +def _action_type_unless_custom_hook( + action_type: Optional[str], custom_hook: Optional[List[Callable[..., Any]]] +) -> Union[Tuple[str, Optional[List[Callable[..., Any]]]], Tuple[None, List[Callable[..., Any]]],]: + if custom_hook: + return (None, custom_hook) + return (action_type, None) + + +def _handle_response_action( + response: Response, + action: ResponseAction, +) -> Union[ + Tuple[str, Optional[List[Callable[..., Any]]]], + Tuple[None, List[Callable[..., Any]]], + Tuple[None, None], +]: + """ + Checks, based on the response, if the provided action applies. + """ + content: str = response.text + status_code = None + content_substr = None + action_type = None + custom_hooks = None + response_action = None + if callable(action): + custom_hooks = [action] + else: + action = cast(ResponseActionDict, action) + status_code = action.get("status_code") + content_substr = action.get("content") + response_action = action.get("action") + if isinstance(response_action, str): + action_type = response_action + elif callable(response_action): + custom_hooks = [response_action] + elif isinstance(response_action, list) and all( + callable(action) for action in response_action + ): + custom_hooks = response_action + else: + raise ValueError( + f"Action {response_action} does not conform to expected type. Expected: str or" + f" Callable or List[Callable]. Found: {type(response_action)}" + ) + + if status_code is not None and content_substr is not None: + if response.status_code == status_code and content_substr in content: + return _action_type_unless_custom_hook(action_type, custom_hooks) + + elif status_code is not None: + if response.status_code == status_code: + return _action_type_unless_custom_hook(action_type, custom_hooks) + + elif content_substr is not None: + if content_substr in content: + return _action_type_unless_custom_hook(action_type, custom_hooks) + + elif status_code is None and content_substr is None and custom_hooks is not None: + return (None, custom_hooks) + + return (None, None) + + +def _create_response_action_hook( + response_action: ResponseAction, +) -> Callable[[Response, Any, Any], None]: + def response_action_hook(response: Response, *args: Any, **kwargs: Any) -> None: + """ + This is the hook executed by the requests library + """ + (action_type, custom_hooks) = _handle_response_action(response, response_action) + if custom_hooks: + for hook in custom_hooks: + hook(response) + elif action_type == "ignore": + logger.info( + f"Ignoring response with code {response.status_code} " + f"and content '{response.json()}'." + ) + raise IgnoreResponseException + + # If there are hooks, then the REST client does not raise for status + # If no action has been taken and the status code indicates an error, + # raise an HTTP error based on the response status + elif not action_type: + response.raise_for_status() + + return response_action_hook + + +def create_response_hooks( + response_actions: Optional[List[ResponseAction]], +) -> Optional[Dict[str, Any]]: + """Create response hooks based on the provided response actions. Note + that if the error status code is not handled by the response actions, + the default behavior is to raise an HTTP error. 
+ + Example: + def set_encoding(response, *args, **kwargs): + response.encoding = 'windows-1252' + return response + + def remove_field(response: Response, *args, **kwargs) -> Response: + payload = response.json() + for record in payload: + record.pop("email", None) + modified_content: bytes = json.dumps(payload).encode("utf-8") + response._content = modified_content + return response + + response_actions = [ + set_encoding, + {"status_code": 404, "action": "ignore"}, + {"content": "Not found", "action": "ignore"}, + {"status_code": 200, "content": "some text", "action": "ignore"}, + {"status_code": 200, "action": remove_field}, + ] + hooks = create_response_hooks(response_actions) + """ + if response_actions: + hooks = [_create_response_action_hook(action) for action in response_actions] + return {"response": hooks} + return None + + +def process_parent_data_item( + path: str, + item: Dict[str, Any], + resolved_param: ResolvedParam, + include_from_parent: List[str], +) -> Tuple[str, Dict[str, Any]]: + parent_resource_name = resolved_param.resolve_config["resource"] + + field_values = jsonpath.find_values(resolved_param.field_path, item) + + if not field_values: + field_path = resolved_param.resolve_config["field"] + raise ValueError( + f"Transformer expects a field '{field_path}' to be present in the incoming data from" + f" resource {parent_resource_name} in order to bind it to path param" + f" {resolved_param.param_name}. Available parent fields are {', '.join(item.keys())}" + ) + bound_path = path.format(**{resolved_param.param_name: field_values[0]}) + + parent_record: Dict[str, Any] = {} + if include_from_parent: + for parent_key in include_from_parent: + child_key = make_parent_key_name(parent_resource_name, parent_key) + if parent_key not in item: + raise ValueError( + f"Transformer expects a field '{parent_key}' to be present in the incoming data" + f" from resource {parent_resource_name} in order to include it in child records" + f" under {child_key}. 
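# A worked example of the parent/child binding implemented here (record values
# are invented): the "{issue_number}" placeholder is filled from the parent
# "number" field and the parent "id" is propagated to child records as
# "_issues_id".
from dlt.sources.rest_api.config_setup import process_parent_data_item
from dlt.sources.rest_api.typing import ResolvedParam

bound_path, parent_fields = process_parent_data_item(
    path="issues/{issue_number}/comments",
    item={"number": 123, "id": 42},
    resolved_param=ResolvedParam(
        "issue_number", {"type": "resolve", "resource": "issues", "field": "number"}
    ),
    include_from_parent=["id"],
)
assert bound_path == "issues/123/comments"
assert parent_fields == {"_issues_id": 42}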
Available parent fields are {', '.join(item.keys())}" + ) + parent_record[child_key] = item[parent_key] + + return bound_path, parent_record + + +def _merge_resource_endpoints( + default_config: EndpointResourceBase, config: EndpointResource +) -> EndpointResource: + """Merges `default_config` and `config`, returns new instance of EndpointResource""" + # NOTE: config is normalized and always has "endpoint" field which is a dict + # TODO: could deep merge paginators and auths of the same type + + default_endpoint = default_config.get("endpoint", Endpoint()) + assert isinstance(default_endpoint, dict) + config_endpoint = config["endpoint"] + assert isinstance(config_endpoint, dict) + + merged_endpoint: Endpoint = { + **default_endpoint, + **{k: v for k, v in config_endpoint.items() if k not in ("json", "params")}, # type: ignore[typeddict-item] + } + # merge endpoint, only params and json are allowed to deep merge + if "json" in config_endpoint: + merged_endpoint["json"] = { + **(merged_endpoint.get("json", {})), + **config_endpoint["json"], + } + if "params" in config_endpoint: + merged_endpoint["params"] = { + **(merged_endpoint.get("params", {})), + **config_endpoint["params"], + } + # merge columns + if (default_columns := default_config.get("columns")) and (columns := config.get("columns")): + # merge only native dlt formats, skip pydantic and others + if isinstance(columns, (list, dict)) and isinstance(default_columns, (list, dict)): + # normalize columns + columns = ensure_table_schema_columns(columns) + default_columns = ensure_table_schema_columns(default_columns) + # merge columns with deep merging hints + config["columns"] = merge_columns(copy(default_columns), columns, merge_columns=True) + + # no need to deep merge resources + merged_resource: EndpointResource = { + **default_config, + **config, + "endpoint": merged_endpoint, + } + return merged_resource diff --git a/dlt/sources/rest_api/exceptions.py b/dlt/sources/rest_api/exceptions.py new file mode 100644 index 0000000000..24fd5b31b0 --- /dev/null +++ b/dlt/sources/rest_api/exceptions.py @@ -0,0 +1,8 @@ +from dlt.common.exceptions import DltException + + +class RestApiException(DltException): + pass + + +# class Paginator diff --git a/dlt/sources/rest_api/typing.py b/dlt/sources/rest_api/typing.py new file mode 100644 index 0000000000..22a9560433 --- /dev/null +++ b/dlt/sources/rest_api/typing.py @@ -0,0 +1,280 @@ +from dataclasses import dataclass, field +from typing_extensions import TypedDict + +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Union, +) + +from dlt.common import jsonpath +from dlt.common.schema.typing import ( + TAnySchemaColumns, +) +from dlt.extract.incremental.typing import IncrementalArgs +from dlt.extract.items import TTableHintTemplate +from dlt.extract.hints import TResourceHintsBase +from dlt.sources.helpers.rest_client.auth import AuthConfigBase, TApiKeyLocation + +from dataclasses import dataclass, field + +from dlt.common import jsonpath +from dlt.common.typing import TSortOrder +from dlt.common.schema.typing import ( + TColumnNames, + TTableFormat, + TAnySchemaColumns, + TWriteDispositionConfig, + TSchemaContract, +) + +from dlt.extract.items import TTableHintTemplate +from dlt.extract.incremental.typing import LastValueFunc + +from dlt.sources.helpers.rest_client.typing import HTTPMethodBasic + +from dlt.sources.helpers.rest_client.paginators import ( + BasePaginator, + HeaderLinkPaginator, + JSONResponseCursorPaginator, + OffsetPaginator, + PageNumberPaginator, 
+ SinglePagePaginator, +) +from dlt.sources.helpers.rest_client.typing import HTTPMethodBasic + + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + from dlt.sources.helpers.rest_client.paginators import ( + JSONResponsePaginator as JSONLinkPaginator, + ) + +from dlt.sources.helpers.rest_client.auth import ( + HttpBasicAuth, + BearerTokenAuth, + APIKeyAuth, +) + +PaginatorType = Literal[ + "json_link", + "json_response", # deprecated. Use json_link instead. Will be removed in upcoming release + "header_link", + "auto", + "single_page", + "cursor", + "offset", + "page_number", +] + + +class PaginatorTypeConfig(TypedDict, total=True): + type: PaginatorType # noqa + + +class PageNumberPaginatorConfig(PaginatorTypeConfig, total=False): + """A paginator that uses page number-based pagination strategy.""" + + base_page: Optional[int] + page_param: Optional[str] + total_path: Optional[jsonpath.TJsonPath] + maximum_page: Optional[int] + + +class OffsetPaginatorConfig(PaginatorTypeConfig, total=False): + """A paginator that uses offset-based pagination strategy.""" + + limit: int + offset: Optional[int] + offset_param: Optional[str] + limit_param: Optional[str] + total_path: Optional[jsonpath.TJsonPath] + maximum_offset: Optional[int] + + +class HeaderLinkPaginatorConfig(PaginatorTypeConfig, total=False): + """A paginator that uses the 'Link' header in HTTP responses + for pagination.""" + + links_next_key: Optional[str] + + +class JSONLinkPaginatorConfig(PaginatorTypeConfig, total=False): + """Locates the next page URL within the JSON response body. The key + containing the URL can be specified using a JSON path.""" + + next_url_path: Optional[jsonpath.TJsonPath] + + +class JSONResponseCursorPaginatorConfig(PaginatorTypeConfig, total=False): + """Uses a cursor parameter for pagination, with the cursor value found in + the JSON response body.""" + + cursor_path: Optional[jsonpath.TJsonPath] + cursor_param: Optional[str] + + +PaginatorConfig = Union[ + PaginatorType, + PageNumberPaginatorConfig, + OffsetPaginatorConfig, + HeaderLinkPaginatorConfig, + JSONLinkPaginatorConfig, + JSONResponseCursorPaginatorConfig, + BasePaginator, + SinglePagePaginator, + HeaderLinkPaginator, + JSONLinkPaginator, + JSONResponseCursorPaginator, + OffsetPaginator, + PageNumberPaginator, +] + + +AuthType = Literal["bearer", "api_key", "http_basic"] + + +class AuthTypeConfig(TypedDict, total=True): + type: AuthType # noqa + + +class BearerTokenAuthConfig(TypedDict, total=False): + """Uses `token` for Bearer authentication in "Authorization" header.""" + + # we allow for a shorthand form of bearer auth, without a type + type: Optional[AuthType] # noqa + token: str + + +class ApiKeyAuthConfig(AuthTypeConfig, total=False): + """Uses provided `api_key` to create authorization data in the specified `location` (query, param, header, cookie) under specified `name`""" + + name: Optional[str] + api_key: str + location: Optional[TApiKeyLocation] + + +class HttpBasicAuthConfig(AuthTypeConfig, total=True): + """Uses HTTP basic authentication""" + + username: str + password: str + + +# TODO: add later +# class OAuthJWTAuthConfig(AuthTypeConfig, total=True): + + +AuthConfig = Union[ + AuthConfigBase, + AuthType, + BearerTokenAuthConfig, + ApiKeyAuthConfig, + HttpBasicAuthConfig, + BearerTokenAuth, + APIKeyAuth, + HttpBasicAuth, +] + + +class ClientConfig(TypedDict, total=False): + base_url: str + headers: Optional[Dict[str, str]] + auth: Optional[AuthConfig] + paginator: 
Optional[PaginatorConfig] + + +class IncrementalRESTArgs(IncrementalArgs, total=False): + convert: Optional[Callable[..., Any]] + + +class IncrementalConfig(IncrementalRESTArgs, total=False): + start_param: str + end_param: Optional[str] + + +ParamBindType = Literal["resolve", "incremental"] + + +class ParamBindConfig(TypedDict): + type: ParamBindType # noqa + + +class ResolveParamConfig(ParamBindConfig): + resource: str + field: str + + +class IncrementalParamConfig(ParamBindConfig, IncrementalRESTArgs): + pass + # TODO: implement param type to bind incremental to + # param_type: Optional[Literal["start_param", "end_param"]] + + +@dataclass +class ResolvedParam: + param_name: str + resolve_config: ResolveParamConfig + field_path: jsonpath.TJsonPath = field(init=False) + + def __post_init__(self) -> None: + self.field_path = jsonpath.compile_path(self.resolve_config["field"]) + + +class ResponseActionDict(TypedDict, total=False): + status_code: Optional[Union[int, str]] + content: Optional[str] + action: Optional[Union[str, Union[Callable[..., Any], List[Callable[..., Any]]]]] + + +ResponseAction = Union[ResponseActionDict, Callable[..., Any]] + + +class Endpoint(TypedDict, total=False): + path: Optional[str] + method: Optional[HTTPMethodBasic] + params: Optional[Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]] + json: Optional[Dict[str, Any]] + paginator: Optional[PaginatorConfig] + data_selector: Optional[jsonpath.TJsonPath] + response_actions: Optional[List[ResponseAction]] + incremental: Optional[IncrementalConfig] + + +class ProcessingSteps(TypedDict): + filter: Optional[Callable[[Any], bool]] # noqa: A003 + map: Optional[Callable[[Any], Any]] # noqa: A003 + + +class ResourceBase(TResourceHintsBase, total=False): + """Defines hints that may be passed to `dlt.resource` decorator""" + + table_name: Optional[TTableHintTemplate[str]] + max_table_nesting: Optional[int] + columns: Optional[TTableHintTemplate[TAnySchemaColumns]] + selected: Optional[bool] + parallelized: Optional[bool] + processing_steps: Optional[List[ProcessingSteps]] + + +class EndpointResourceBase(ResourceBase, total=False): + endpoint: Optional[Union[str, Endpoint]] + include_from_parent: Optional[List[str]] + + +class EndpointResource(EndpointResourceBase, total=False): + name: TTableHintTemplate[str] + + +class RESTAPIConfigBase(TypedDict): + client: ClientConfig + resources: List[Union[str, EndpointResource]] + + +class RESTAPIConfig(RESTAPIConfigBase, total=False): + resource_defaults: Optional[EndpointResourceBase] diff --git a/dlt/sources/rest_api/utils.py b/dlt/sources/rest_api/utils.py new file mode 100644 index 0000000000..c1ef181cca --- /dev/null +++ b/dlt/sources/rest_api/utils.py @@ -0,0 +1,35 @@ +from typing import Tuple, Dict, Any, Mapping, Iterable + +from dlt.common import logger +from dlt.extract.source import DltSource + + +def join_url(base_url: str, path: str) -> str: + if not base_url.endswith("/"): + base_url += "/" + return base_url + path.lstrip("/") + + +def exclude_keys(d: Mapping[str, Any], keys: Iterable[str]) -> Dict[str, Any]: + """Removes specified keys from a dictionary and returns a new dictionary. + + Args: + d (Mapping[str, Any]): The dictionary to remove keys from. + keys (Iterable[str]): The keys to remove. + + Returns: + Dict[str, Any]: A new dictionary with the specified keys removed. 
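# A resource using the processing steps declared above (endpoint, params and
# record fields are illustrative): records are filtered and reshaped before
# they are loaded.
users_resource = {
    "name": "users",
    "endpoint": {"path": "users", "params": {"per_page": 100}},
    "processing_steps": [
        {"filter": lambda user: not user.get("is_bot", False)},
        {"map": lambda user: {"id": user["id"], "login": user["login"]}},
    ],
}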
+ """ + return {k: v for k, v in d.items() if k not in keys} + + +def check_connection( + source: DltSource, + *resource_names: str, +) -> Tuple[bool, str]: + try: + list(source.with_resources(*resource_names).add_limit(1)) + return (True, "") + except Exception as e: + logger.error(f"Error checking connection: {e}") + return (False, str(e)) diff --git a/dlt/sources/rest_api_pipeline.py b/dlt/sources/rest_api_pipeline.py new file mode 100644 index 0000000000..01a8828fcd --- /dev/null +++ b/dlt/sources/rest_api_pipeline.py @@ -0,0 +1,158 @@ +from typing import Any, Optional + +import dlt +from dlt.common.pendulum import pendulum +from dlt.sources.rest_api import ( + RESTAPIConfig, + check_connection, + rest_api_resources, + rest_api_source, +) + + +@dlt.source(name="github") +def github_source(access_token: Optional[str] = dlt.secrets.value) -> Any: + # Create a REST API configuration for the GitHub API + # Use RESTAPIConfig to get autocompletion and type checking + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.github.com/repos/dlt-hub/dlt/", + # we add an auth config if the auth token is present + "auth": ( + { + "type": "bearer", + "token": access_token, + } + if access_token + else None + ), + }, + # The default configuration for all resources and their endpoints + "resource_defaults": { + "primary_key": "id", + "write_disposition": "merge", + "endpoint": { + "params": { + "per_page": 100, + }, + }, + }, + "resources": [ + # This is a simple resource definition, + # that uses the endpoint path as a resource name: + # "pulls", + # Alternatively, you can define the endpoint as a dictionary + # { + # "name": "pulls", # <- Name of the resource + # "endpoint": "pulls", # <- This is the endpoint path + # } + # Or use a more detailed configuration: + { + "name": "issues", + "endpoint": { + "path": "issues", + # Query parameters for the endpoint + "params": { + "sort": "updated", + "direction": "desc", + "state": "open", + # Define `since` as a special parameter + # to incrementally load data from the API. + # This works by getting the updated_at value + # from the previous response data and using this value + # for the `since` query parameter in the next request. + "since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": pendulum.today().subtract(days=30).to_iso8601_string(), + }, + }, + }, + }, + # The following is an example of a resource that uses + # a parent resource (`issues`) to get the `issue_number` + # and include it in the endpoint path: + { + "name": "issue_comments", + "endpoint": { + # The placeholder {issue_number} will be resolved + # from the parent resource + "path": "issues/{issue_number}/comments", + "params": { + # The value of `issue_number` will be taken + # from the `number` field in the `issues` resource + "issue_number": { + "type": "resolve", + "resource": "issues", + "field": "number", + } + }, + }, + # Include data from `id` field of the parent resource + # in the child data. 
The field name in the child data + # will be called `_issues_id` (_{resource_name}_{field_name}) + "include_from_parent": ["id"], + }, + ], + } + + yield from rest_api_resources(config) + + +def load_github() -> None: + pipeline = dlt.pipeline( + pipeline_name="rest_api_github", + destination="duckdb", + dataset_name="rest_api_data", + ) + + load_info = pipeline.run(github_source()) + print(load_info) # noqa: T201 + + +def load_pokemon() -> None: + pipeline = dlt.pipeline( + pipeline_name="rest_api_pokemon", + destination="duckdb", + dataset_name="rest_api_data", + ) + + pokemon_source = rest_api_source( + { + "client": { + "base_url": "https://pokeapi.co/api/v2/", + # If you leave out the paginator, it will be inferred from the API: + # "paginator": "json_link", + }, + "resource_defaults": { + "endpoint": { + "params": { + "limit": 1000, + }, + }, + }, + "resources": [ + "pokemon", + "berry", + "location", + ], + } + ) + + def check_network_and_authentication() -> None: + (can_connect, error_msg) = check_connection( + pokemon_source, + "not_existing_endpoint", + ) + if not can_connect: + pass # do something with the error message + + check_network_and_authentication() + + load_info = pipeline.run(pokemon_source) + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_github() + load_pokemon() diff --git a/dlt/sources/sql_database/__init__.py b/dlt/sources/sql_database/__init__.py new file mode 100644 index 0000000000..d102fc9a46 --- /dev/null +++ b/dlt/sources/sql_database/__init__.py @@ -0,0 +1,216 @@ +"""Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads.""" + +from typing import Callable, Dict, List, Optional, Union, Iterable, Any + +from dlt.common.libs.sql_alchemy import MetaData, Table, Engine + +import dlt +from dlt.sources import DltResource + + +from dlt.sources.credentials import ConnectionStringCredentials +from dlt.common.configuration.specs.config_section_context import ConfigSectionContext + +from .helpers import ( + table_rows, + engine_from_credentials, + TableBackend, + SqlDatabaseTableConfiguration, + SqlTableResourceConfiguration, + _detect_precision_hints_deprecated, + TQueryAdapter, +) +from .schema_types import ( + default_table_adapter, + table_to_columns, + get_primary_key, + ReflectionLevel, + TTypeAdapter, +) + + +@dlt.source +def sql_database( + credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, + schema: Optional[str] = dlt.config.value, + metadata: Optional[MetaData] = None, + table_names: Optional[List[str]] = dlt.config.value, + chunk_size: int = 50000, + backend: TableBackend = "sqlalchemy", + detect_precision_hints: Optional[bool] = False, + reflection_level: Optional[ReflectionLevel] = "full", + defer_table_reflect: Optional[bool] = None, + table_adapter_callback: Callable[[Table], None] = None, + backend_kwargs: Dict[str, Any] = None, + include_views: bool = False, + type_adapter_callback: Optional[TTypeAdapter] = None, + query_adapter_callback: Optional[TQueryAdapter] = None, +) -> Iterable[DltResource]: + """ + A dlt source which loads data from an SQL database using SQLAlchemy. + Resources are automatically created for each table in the schema or from the given list of tables. + + Args: + credentials (Union[ConnectionStringCredentials, Engine, str]): Database credentials or an `sqlalchemy.Engine` instance. + schema (Optional[str]): Name of the database schema to load (if different from default). 
+ metadata (Optional[MetaData]): Optional `sqlalchemy.MetaData` instance. `schema` argument is ignored when this is used. + table_names (Optional[List[str]]): A list of table names to load. By default, all tables in the schema are loaded. + chunk_size (int): Number of rows yielded in one batch. SQL Alchemy will create additional internal rows buffer twice the chunk size. + backend (TableBackend): Type of backend to generate table data. One of: "sqlalchemy", "pyarrow", "pandas" and "connectorx". + "sqlalchemy" yields batches as lists of Python dictionaries, "pyarrow" and "connectorx" yield batches as arrow tables, "pandas" yields panda frames. + "sqlalchemy" is the default and does not require additional dependencies, "pyarrow" creates stable destination schemas with correct data types, + "connectorx" is typically the fastest but ignores the "chunk_size" so you must deal with large tables yourself. + detect_precision_hints (bool): Deprecated. Use `reflection_level`. Set column precision and scale hints for supported data types in the target schema based on the columns in the source tables. + This is disabled by default. + reflection_level: (ReflectionLevel): Specifies how much information should be reflected from the source database schema. + "minimal": Only table names, nullability and primary keys are reflected. Data types are inferred from the data. + "full": Data types will be reflected on top of "minimal". `dlt` will coerce the data into reflected types if necessary. This is the default option. + "full_with_precision": Sets precision and scale on supported data types (ie. decimal, text, binary). Creates big and regular integer types. + defer_table_reflect (bool): Will connect and reflect table schema only when yielding data. Requires table_names to be explicitly passed. + Enable this option when running on Airflow. Available on dlt 0.4.4 and later. + table_adapter_callback: (Callable): Receives each reflected table. May be used to modify the list of columns that will be selected. + backend_kwargs (**kwargs): kwargs passed to table backend ie. "conn" is used to pass specialized connection string to connectorx. + include_views (bool): Reflect views as well as tables. Note view names included in `table_names` are always included regardless of this setting. + type_adapter_callback(Optional[Callable]): Callable to override type inference when reflecting columns. + Argument is a single sqlalchemy data type (`TypeEngine` instance) and it should return another sqlalchemy data type, or `None` (type will be inferred from data) + query_adapter_callback(Optional[Callable[Select, Table], Select]): Callable to override the SELECT query used to fetch data from the table. + The callback receives the sqlalchemy `Select` and corresponding `Table` objects and should return the modified `Select`. + + Returns: + Iterable[DltResource]: A list of DLT resources for each table to be loaded. 
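# A minimal sketch of using this source (assumes pymysql and pyarrow are
# installed; the connection string is the public Rfam instance used in the
# pipeline templates, and the table names are assumptions):
import dlt
from dlt.sources.sql_database import sql_database

pipeline = dlt.pipeline(pipeline_name="rfam", destination="duckdb", dataset_name="rfam_data")
source = sql_database(
    "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam",
    table_names=["family", "clan"],
    reflection_level="full_with_precision",
    backend="pyarrow",
    chunk_size=10000,
)
print(pipeline.run(source))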
+ """ + # detect precision hints is deprecated + _detect_precision_hints_deprecated(detect_precision_hints) + + if detect_precision_hints: + reflection_level = "full_with_precision" + else: + reflection_level = reflection_level or "minimal" + + # set up alchemy engine + engine = engine_from_credentials(credentials) + engine.execution_options(stream_results=True, max_row_buffer=2 * chunk_size) + metadata = metadata or MetaData(schema=schema) + + # use provided tables or all tables + if table_names: + tables = [ + Table(name, metadata, autoload_with=None if defer_table_reflect else engine) + for name in table_names + ] + else: + if defer_table_reflect: + raise ValueError("You must pass table names to defer table reflection") + metadata.reflect(bind=engine, views=include_views) + tables = list(metadata.tables.values()) + + for table in tables: + yield sql_table( + credentials=credentials, + table=table.name, + schema=table.schema, + metadata=metadata, + chunk_size=chunk_size, + backend=backend, + reflection_level=reflection_level, + defer_table_reflect=defer_table_reflect, + table_adapter_callback=table_adapter_callback, + backend_kwargs=backend_kwargs, + type_adapter_callback=type_adapter_callback, + query_adapter_callback=query_adapter_callback, + ) + + +@dlt.resource(name=lambda args: args["table"], standalone=True, spec=SqlTableResourceConfiguration) +def sql_table( + credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, + table: str = dlt.config.value, + schema: Optional[str] = dlt.config.value, + metadata: Optional[MetaData] = None, + incremental: Optional[dlt.sources.incremental[Any]] = None, + chunk_size: int = 50000, + backend: TableBackend = "sqlalchemy", + detect_precision_hints: Optional[bool] = None, + reflection_level: Optional[ReflectionLevel] = "full", + defer_table_reflect: Optional[bool] = None, + table_adapter_callback: Callable[[Table], None] = None, + backend_kwargs: Dict[str, Any] = None, + type_adapter_callback: Optional[TTypeAdapter] = None, + included_columns: Optional[List[str]] = None, + query_adapter_callback: Optional[TQueryAdapter] = None, +) -> DltResource: + """ + A dlt resource which loads data from an SQL database table using SQLAlchemy. + + Args: + credentials (Union[ConnectionStringCredentials, Engine, str]): Database credentials or an `Engine` instance representing the database connection. + table (str): Name of the table or view to load. + schema (Optional[str]): Optional name of the schema the table belongs to. + metadata (Optional[MetaData]): Optional `sqlalchemy.MetaData` instance. If provided, the `schema` argument is ignored. + incremental (Optional[dlt.sources.incremental[Any]]): Option to enable incremental loading for the table. + E.g., `incremental=dlt.sources.incremental('updated_at', pendulum.parse('2022-01-01T00:00:00Z'))` + chunk_size (int): Number of rows yielded in one batch. SQL Alchemy will create additional internal rows buffer twice the chunk size. + backend (TableBackend): Type of backend to generate table data. One of: "sqlalchemy", "pyarrow", "pandas" and "connectorx". + "sqlalchemy" yields batches as lists of Python dictionaries, "pyarrow" and "connectorx" yield batches as arrow tables, "pandas" yields panda frames. + "sqlalchemy" is the default and does not require additional dependencies, "pyarrow" creates stable destination schemas with correct data types, + "connectorx" is typically the fastest but ignores the "chunk_size" so you must deal with large tables yourself. 
+ reflection_level: (ReflectionLevel): Specifies how much information should be reflected from the source database schema. + "minimal": Only table names, nullability and primary keys are reflected. Data types are inferred from the data. + "full": Data types will be reflected on top of "minimal". `dlt` will coerce the data into reflected types if necessary. This is the default option. + "full_with_precision": Sets precision and scale on supported data types (ie. decimal, text, binary). Creates big and regular integer types. + detect_precision_hints (bool): Deprecated. Use `reflection_level`. Set column precision and scale hints for supported data types in the target schema based on the columns in the source tables. + This is disabled by default. + defer_table_reflect (bool): Will connect and reflect table schema only when yielding data. Enable this option when running on Airflow. Available + on dlt 0.4.4 and later + table_adapter_callback: (Callable): Receives each reflected table. May be used to modify the list of columns that will be selected. + backend_kwargs (**kwargs): kwargs passed to table backend ie. "conn" is used to pass specialized connection string to connectorx. + type_adapter_callback(Optional[Callable]): Callable to override type inference when reflecting columns. + Argument is a single sqlalchemy data type (`TypeEngine` instance) and it should return another sqlalchemy data type, or `None` (type will be inferred from data) + included_columns (Optional[List[str]): List of column names to select from the table. If not provided, all columns are loaded. + query_adapter_callback(Optional[Callable[Select, Table], Select]): Callable to override the SELECT query used to fetch data from the table. + The callback receives the sqlalchemy `Select` and corresponding `Table` objects and should return the modified `Select`. + + Returns: + DltResource: The dlt resource for loading data from the SQL database table. 
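# A minimal sketch of a standalone table resource with incremental loading and a
# query adapter; the connection string, table name and cursor column are
# assumptions, not part of the source.
import dlt
from dlt.sources.sql_database import sql_table

def exclude_deleted(select, table):
    # narrow the generated SELECT before dlt applies its own incremental filter
    return select.where(table.c.deleted_at.is_(None))

messages = sql_table(
    credentials="postgresql://loader:loader@localhost:5432/chat_db",
    table="chat_message",
    incremental=dlt.sources.incremental("updated_at"),
    reflection_level="full_with_precision",
    query_adapter_callback=exclude_deleted,
)
pipeline = dlt.pipeline(pipeline_name="chat", destination="duckdb", dataset_name="chat_data")
print(pipeline.run(messages))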
+ """ + _detect_precision_hints_deprecated(detect_precision_hints) + + if detect_precision_hints: + reflection_level = "full_with_precision" + else: + reflection_level = reflection_level or "minimal" + + engine = engine_from_credentials(credentials, may_dispose_after_use=True) + engine.execution_options(stream_results=True, max_row_buffer=2 * chunk_size) + metadata = metadata or MetaData(schema=schema) + + table_obj = metadata.tables.get("table") or Table( + table, metadata, autoload_with=None if defer_table_reflect else engine + ) + if not defer_table_reflect: + default_table_adapter(table_obj, included_columns) + if table_adapter_callback: + table_adapter_callback(table_obj) + + skip_complex_on_minimal = backend == "sqlalchemy" + return dlt.resource( + table_rows, + name=table_obj.name, + primary_key=get_primary_key(table_obj), + columns=table_to_columns( + table_obj, reflection_level, type_adapter_callback, skip_complex_on_minimal + ), + )( + engine, + table_obj, + chunk_size, + backend, + incremental=incremental, + reflection_level=reflection_level, + defer_table_reflect=defer_table_reflect, + table_adapter_callback=table_adapter_callback, + backend_kwargs=backend_kwargs, + type_adapter_callback=type_adapter_callback, + included_columns=included_columns, + query_adapter_callback=query_adapter_callback, + ) diff --git a/dlt/sources/sql_database/arrow_helpers.py b/dlt/sources/sql_database/arrow_helpers.py new file mode 100644 index 0000000000..898d8c3280 --- /dev/null +++ b/dlt/sources/sql_database/arrow_helpers.py @@ -0,0 +1,150 @@ +from typing import Any, Sequence, Optional + +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common import logger, json +from dlt.common.configuration import with_config +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.json import custom_encode, map_nested_in_place + +from .schema_types import RowAny + + +@with_config +def columns_to_arrow( + columns_schema: TTableSchemaColumns, + caps: DestinationCapabilitiesContext = None, + tz: str = "UTC", +) -> Any: + """Converts `column_schema` to arrow schema using `caps` and `tz`. `caps` are injected from the container - which + is always the case if run within the pipeline. This will generate arrow schema compatible with the destination. + Otherwise generic capabilities are used + """ + from dlt.common.libs.pyarrow import pyarrow as pa, get_py_arrow_datatype + from dlt.common.destination.capabilities import DestinationCapabilitiesContext + + return pa.schema( + [ + pa.field( + name, + get_py_arrow_datatype( + schema_item, + caps or DestinationCapabilitiesContext.generic_capabilities(), + tz, + ), + nullable=schema_item.get("nullable", True), + ) + for name, schema_item in columns_schema.items() + if schema_item.get("data_type") is not None + ] + ) + + +def row_tuples_to_arrow(rows: Sequence[RowAny], columns: TTableSchemaColumns, tz: str) -> Any: + """Converts the rows to an arrow table using the columns schema. + Columns missing `data_type` will be inferred from the row data. + Columns with object types not supported by arrow are excluded from the resulting table. 
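# An illustrative call (values invented): columns without a `data_type` ("age")
# get their arrow type inferred from the row data, while typed columns ("name")
# follow the provided schema.
from dlt.sources.sql_database.arrow_helpers import row_tuples_to_arrow

columns = {
    "name": {"name": "name", "data_type": "text", "nullable": False},
    "age": {"name": "age", "nullable": True},  # no data_type: inferred from rows
}
arrow_table = row_tuples_to_arrow([("ana", 30), ("bob", None)], columns=columns, tz="UTC")
print(arrow_table.schema)  # name: string not null, age: int64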
+ """ + from dlt.common.libs.pyarrow import pyarrow as pa + import numpy as np + + try: + from pandas._libs import lib + + pivoted_rows = lib.to_object_array_tuples(rows).T + except ImportError: + logger.info( + "Pandas not installed, reverting to numpy.asarray to create a table which is slower" + ) + pivoted_rows = np.asarray(rows, dtype="object", order="k").T # type: ignore[call-overload] + + columnar = { + col: dat.ravel() for col, dat in zip(columns, np.vsplit(pivoted_rows, len(columns))) + } + columnar_known_types = { + col["name"]: columnar[col["name"]] + for col in columns.values() + if col.get("data_type") is not None + } + columnar_unknown_types = { + col["name"]: columnar[col["name"]] + for col in columns.values() + if col.get("data_type") is None + } + + arrow_schema = columns_to_arrow(columns, tz=tz) + + for idx in range(0, len(arrow_schema.names)): + field = arrow_schema.field(idx) + py_type = type(rows[0][idx]) + # cast double / float ndarrays to decimals if type mismatch, looks like decimals and floats are often mixed up in dialects + if pa.types.is_decimal(field.type) and issubclass(py_type, (str, float)): + logger.warning( + f"Field {field.name} was reflected as decimal type, but rows contains" + f" {py_type.__name__}. Additional cast is required which may slow down arrow table" + " generation." + ) + float_array = pa.array(columnar_known_types[field.name], type=pa.float64()) + columnar_known_types[field.name] = float_array.cast(field.type, safe=False) + if issubclass(py_type, (dict, list)): + logger.warning( + f"Field {field.name} was reflected as JSON type and needs to be serialized back to" + " string to be placed in arrow table. This will slow data extraction down. You" + " should cast JSON field to STRING in your database system ie. by creating and" + " extracting an SQL VIEW that selects with cast." + ) + json_str_array = pa.array( + [None if s is None else json.dumps(s) for s in columnar_known_types[field.name]] + ) + columnar_known_types[field.name] = json_str_array + + # If there are unknown type columns, first create a table to infer their types + if columnar_unknown_types: + new_schema_fields = [] + for key in list(columnar_unknown_types): + arrow_col: Optional[pa.Array] = None + try: + arrow_col = pa.array(columnar_unknown_types[key]) + if pa.types.is_null(arrow_col.type): + logger.warning( + f"Column {key} contains only NULL values and data type could not be" + " inferred. This column is removed from a arrow table" + ) + continue + + except pa.ArrowInvalid as e: + # Try coercing types not supported by arrow to a json friendly format + # E.g. dataclasses -> dict, UUID -> str + try: + arrow_col = pa.array( + map_nested_in_place(custom_encode, list(columnar_unknown_types[key])) + ) + logger.warning( + f"Column {key} contains a data type which is not supported by pyarrow and" + f" got converted into {arrow_col.type}. This slows down arrow table" + " generation." + ) + except (pa.ArrowInvalid, TypeError): + logger.warning( + f"Column {key} contains a data type which is not supported by pyarrow. This" + f" column will be ignored. 
Error: {e}" + ) + if arrow_col is not None: + columnar_known_types[key] = arrow_col + new_schema_fields.append( + pa.field( + key, + arrow_col.type, + nullable=columns[key]["nullable"], + ) + ) + + # New schema + column_order = {name: idx for idx, name in enumerate(columns)} + arrow_schema = pa.schema( + sorted( + list(arrow_schema) + new_schema_fields, + key=lambda x: column_order[x.name], + ) + ) + + return pa.Table.from_pydict(columnar_known_types, schema=arrow_schema) diff --git a/dlt/sources/sql_database/helpers.py b/dlt/sources/sql_database/helpers.py new file mode 100644 index 0000000000..1d758fe882 --- /dev/null +++ b/dlt/sources/sql_database/helpers.py @@ -0,0 +1,311 @@ +"""SQL database source helpers""" + +import warnings +from typing import ( + Callable, + Any, + Dict, + List, + Literal, + Optional, + Iterator, + Union, +) +import operator + +import dlt +from dlt.common.configuration.specs import BaseConfiguration, configspec +from dlt.common.exceptions import MissingDependencyException +from dlt.common.schema import TTableSchemaColumns +from dlt.common.typing import TDataItem, TSortOrder + +from dlt.sources.credentials import ConnectionStringCredentials + +from .arrow_helpers import row_tuples_to_arrow +from .schema_types import ( + default_table_adapter, + table_to_columns, + get_primary_key, + Table, + SelectAny, + ReflectionLevel, + TTypeAdapter, +) + +from dlt.common.libs.sql_alchemy import Engine, CompileError, create_engine + + +TableBackend = Literal["sqlalchemy", "pyarrow", "pandas", "connectorx"] +TQueryAdapter = Callable[[SelectAny, Table], SelectAny] + + +class TableLoader: + def __init__( + self, + engine: Engine, + backend: TableBackend, + table: Table, + columns: TTableSchemaColumns, + chunk_size: int = 1000, + incremental: Optional[dlt.sources.incremental[Any]] = None, + query_adapter_callback: Optional[TQueryAdapter] = None, + ) -> None: + self.engine = engine + self.backend = backend + self.table = table + self.columns = columns + self.chunk_size = chunk_size + self.query_adapter_callback = query_adapter_callback + self.incremental = incremental + if incremental: + try: + self.cursor_column = table.c[incremental.cursor_path] + except KeyError as e: + raise KeyError( + f"Cursor column '{incremental.cursor_path}' does not exist in table" + f" '{table.name}'" + ) from e + self.last_value = incremental.last_value + self.end_value = incremental.end_value + self.row_order: TSortOrder = self.incremental.row_order + else: + self.cursor_column = None + self.last_value = None + self.end_value = None + self.row_order = None + + def _make_query(self) -> SelectAny: + table = self.table + query = table.select() + if not self.incremental: + return query # type: ignore[no-any-return] + last_value_func = self.incremental.last_value_func + + # generate where + if last_value_func is max: # Query ordered and filtered according to last_value function + filter_op = operator.ge + filter_op_end = operator.lt + elif last_value_func is min: + filter_op = operator.le + filter_op_end = operator.gt + else: # Custom last_value, load everything and let incremental handle filtering + return query # type: ignore[no-any-return] + + if self.last_value is not None: + query = query.where(filter_op(self.cursor_column, self.last_value)) + if self.end_value is not None: + query = query.where(filter_op_end(self.cursor_column, self.end_value)) + + # generate order by from declared row order + order_by = None + if (self.row_order == "asc" and last_value_func is max) or ( + self.row_order == "desc" and 
last_value_func is min + ): + order_by = self.cursor_column.asc() + elif (self.row_order == "asc" and last_value_func is min) or ( + self.row_order == "desc" and last_value_func is max + ): + order_by = self.cursor_column.desc() + if order_by is not None: + query = query.order_by(order_by) + + return query # type: ignore[no-any-return] + + def make_query(self) -> SelectAny: + if self.query_adapter_callback: + return self.query_adapter_callback(self._make_query(), self.table) + return self._make_query() + + def load_rows(self, backend_kwargs: Dict[str, Any] = None) -> Iterator[TDataItem]: + # make copy of kwargs + backend_kwargs = dict(backend_kwargs or {}) + query = self.make_query() + if self.backend == "connectorx": + yield from self._load_rows_connectorx(query, backend_kwargs) + else: + yield from self._load_rows(query, backend_kwargs) + + def _load_rows(self, query: SelectAny, backend_kwargs: Dict[str, Any]) -> TDataItem: + with self.engine.connect() as conn: + result = conn.execution_options(yield_per=self.chunk_size).execute(query) + # NOTE: cursor returns not normalized column names! may be quite useful in case of Oracle dialect + # that normalizes columns + # columns = [c[0] for c in result.cursor.description] + columns = list(result.keys()) + for partition in result.partitions(size=self.chunk_size): + if self.backend == "sqlalchemy": + yield [dict(row._mapping) for row in partition] + elif self.backend == "pandas": + from dlt.common.libs.pandas_sql import _wrap_result + + df = _wrap_result( + partition, + columns, + **{"dtype_backend": "pyarrow", **backend_kwargs}, + ) + yield df + elif self.backend == "pyarrow": + yield row_tuples_to_arrow( + partition, self.columns, tz=backend_kwargs.get("tz", "UTC") + ) + + def _load_rows_connectorx( + self, query: SelectAny, backend_kwargs: Dict[str, Any] + ) -> Iterator[TDataItem]: + try: + import connectorx as cx + except ImportError: + raise MissingDependencyException("Connector X table backend", ["connectorx"]) + + # default settings + backend_kwargs = { + "return_type": "arrow2", + "protocol": "binary", + **backend_kwargs, + } + conn = backend_kwargs.pop( + "conn", + self.engine.url._replace( + drivername=self.engine.url.get_backend_name() + ).render_as_string(hide_password=False), + ) + try: + query_str = str(query.compile(self.engine, compile_kwargs={"literal_binds": True})) + except CompileError as ex: + raise NotImplementedError( + f"Query for table {self.table.name} could not be compiled to string to execute it" + " on ConnectorX. 
If you are on SQLAlchemy 1.4.x the causing exception is due to" + f" literals that cannot be rendered, upgrade to 2.x: {str(ex)}" + ) from ex + df = cx.read_sql(conn, query_str, **backend_kwargs) + yield df + + +def table_rows( + engine: Engine, + table: Table, + chunk_size: int, + backend: TableBackend, + incremental: Optional[dlt.sources.incremental[Any]] = None, + defer_table_reflect: bool = False, + table_adapter_callback: Callable[[Table], None] = None, + reflection_level: ReflectionLevel = "minimal", + backend_kwargs: Dict[str, Any] = None, + type_adapter_callback: Optional[TTypeAdapter] = None, + included_columns: Optional[List[str]] = None, + query_adapter_callback: Optional[TQueryAdapter] = None, +) -> Iterator[TDataItem]: + columns: TTableSchemaColumns = None + if defer_table_reflect: + table = Table(table.name, table.metadata, autoload_with=engine, extend_existing=True) # type: ignore[attr-defined] + default_table_adapter(table, included_columns) + if table_adapter_callback: + table_adapter_callback(table) + columns = table_to_columns(table, reflection_level, type_adapter_callback) + + # set the primary_key in the incremental + if incremental and incremental.primary_key is None: + primary_key = get_primary_key(table) + if primary_key is not None: + incremental.primary_key = primary_key + + # yield empty record to set hints + yield dlt.mark.with_hints( + [], + dlt.mark.make_hints( + primary_key=get_primary_key(table), + columns=columns, + ), + ) + else: + # table was already reflected + columns = table_to_columns(table, reflection_level, type_adapter_callback) + + loader = TableLoader( + engine, + backend, + table, + columns, + incremental=incremental, + chunk_size=chunk_size, + query_adapter_callback=query_adapter_callback, + ) + try: + yield from loader.load_rows(backend_kwargs) + finally: + # dispose the engine if created for this particular table + # NOTE: database wide engines are not disposed, not externally provided + if getattr(engine, "may_dispose_after_use", False): + engine.dispose() + + +def engine_from_credentials( + credentials: Union[ConnectionStringCredentials, Engine, str], + may_dispose_after_use: bool = False, + **backend_kwargs: Any, +) -> Engine: + if isinstance(credentials, Engine): + return credentials + if isinstance(credentials, ConnectionStringCredentials): + credentials = credentials.to_native_representation() + engine = create_engine(credentials, **backend_kwargs) + setattr(engine, "may_dispose_after_use", may_dispose_after_use) # noqa + return engine # type: ignore[no-any-return] + + +def unwrap_json_connector_x(field: str) -> TDataItem: + """Creates a transform function to be added with `add_map` that will unwrap JSON columns + ingested via connectorx. Such columns are additionally quoted and translate SQL NULL to json "null" + """ + import pyarrow.compute as pc + import pyarrow as pa + + def _unwrap(table: TDataItem) -> TDataItem: + col_index = table.column_names.index(field) + # remove quotes + column = table[field] # pc.replace_substring_regex(table[field], '"(.*)"', "\\1") + # convert json null to null + column = pc.replace_with_mask( + column, + pc.equal(column, "null").combine_chunks(), + pa.scalar(None, pa.large_string()), + ) + return table.set_column(col_index, table.schema.field(col_index), column) + + return _unwrap + + +def _detect_precision_hints_deprecated(value: Optional[bool]) -> None: + if value is None: + return + + msg = ( + "`detect_precision_hints` argument is deprecated and will be removed in a future release. 
" + ) + if value: + msg += "Use `reflection_level='full_with_precision'` which has the same effect instead." + + warnings.warn( + msg, + DeprecationWarning, + ) + + +@configspec +class SqlDatabaseTableConfiguration(BaseConfiguration): + incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg] + included_columns: Optional[List[str]] = None + + +@configspec +class SqlTableResourceConfiguration(BaseConfiguration): + credentials: Union[ConnectionStringCredentials, Engine, str] = None + table: str = None + schema: Optional[str] = None + incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg] + chunk_size: int = 50000 + backend: TableBackend = "sqlalchemy" + detect_precision_hints: Optional[bool] = None + defer_table_reflect: Optional[bool] = False + reflection_level: Optional[ReflectionLevel] = "full" + included_columns: Optional[List[str]] = None diff --git a/dlt/sources/sql_database/schema_types.py b/dlt/sources/sql_database/schema_types.py new file mode 100644 index 0000000000..6ea2b9d54b --- /dev/null +++ b/dlt/sources/sql_database/schema_types.py @@ -0,0 +1,164 @@ +from typing import ( + Optional, + Any, + Type, + TYPE_CHECKING, + Literal, + List, + Callable, + Union, +) +from typing_extensions import TypeAlias +from dlt.common.libs.sql_alchemy import Table, Column, Row, sqltypes, Select, TypeEngine + + +from dlt.common import logger +from dlt.common.schema.typing import TColumnSchema, TTableSchemaColumns + +ReflectionLevel = Literal["minimal", "full", "full_with_precision"] + + +# optionally create generics with any so they can be imported by dlt importer +if TYPE_CHECKING: + SelectAny: TypeAlias = Select[Any] # type: ignore[type-arg] + ColumnAny: TypeAlias = Column[Any] # type: ignore[type-arg] + RowAny: TypeAlias = Row[Any] # type: ignore[type-arg] + TypeEngineAny = TypeEngine[Any] # type: ignore[type-arg] +else: + SelectAny: TypeAlias = Type[Any] + ColumnAny: TypeAlias = Type[Any] + RowAny: TypeAlias = Type[Any] + TypeEngineAny = Type[Any] + + +TTypeAdapter = Callable[[TypeEngineAny], Optional[Union[TypeEngineAny, Type[TypeEngineAny]]]] + + +def default_table_adapter(table: Table, included_columns: Optional[List[str]]) -> None: + """Default table adapter being always called before custom one""" + if included_columns is not None: + # Delete columns not included in the load + for col in list(table._columns): # type: ignore[attr-defined] + if col.name not in included_columns: + table._columns.remove(col) # type: ignore[attr-defined] + for col in table._columns: # type: ignore[attr-defined] + sql_t = col.type + if hasattr(sqltypes, "Uuid") and isinstance(sql_t, sqltypes.Uuid): + # emit uuids as string by default + sql_t.as_uuid = False + + +def sqla_col_to_column_schema( + sql_col: ColumnAny, + reflection_level: ReflectionLevel, + type_adapter_callback: Optional[TTypeAdapter] = None, + skip_complex_columns_on_minimal: bool = False, +) -> Optional[TColumnSchema]: + """Infer dlt schema column type from an sqlalchemy type. + + If `add_precision` is set, precision and scale is inferred from that types that support it, + such as numeric, varchar, int, bigint. Numeric (decimal) types have always precision added. 
+ """ + col: TColumnSchema = { + "name": sql_col.name, + "nullable": sql_col.nullable, + } + if reflection_level == "minimal": + # TODO: when we have a complex column, it should not be added to the schema as it will be + # normalized into subtables + if isinstance(sql_col.type, sqltypes.JSON) and skip_complex_columns_on_minimal: + return None + return col + + sql_t = sql_col.type + + if type_adapter_callback: + sql_t = type_adapter_callback(sql_t) + # Check if sqla type class rather than instance is returned + if sql_t is not None and isinstance(sql_t, type): + sql_t = sql_t() + + if sql_t is None: + # Column ignored by callback + return col + + add_precision = reflection_level == "full_with_precision" + + if hasattr(sqltypes, "Uuid") and isinstance(sql_t, sqltypes.Uuid): + # we represent UUID as text by default, see default_table_adapter + col["data_type"] = "text" + if isinstance(sql_t, sqltypes.Numeric): + # check for Numeric type first and integer later, some numeric types (ie. Oracle) + # derive from both + # all Numeric types that are returned as floats will assume "double" type + # and returned as decimals will assume "decimal" type + if sql_t.asdecimal is False: + col["data_type"] = "double" + else: + col["data_type"] = "decimal" + if sql_t.precision is not None: + col["precision"] = sql_t.precision + # must have a precision for any meaningful scale + if sql_t.scale is not None: + col["scale"] = sql_t.scale + elif sql_t.decimal_return_scale is not None: + col["scale"] = sql_t.decimal_return_scale + elif isinstance(sql_t, sqltypes.SmallInteger): + col["data_type"] = "bigint" + if add_precision: + col["precision"] = 32 + elif isinstance(sql_t, sqltypes.Integer): + col["data_type"] = "bigint" + elif isinstance(sql_t, sqltypes.String): + col["data_type"] = "text" + if add_precision and sql_t.length: + col["precision"] = sql_t.length + elif isinstance(sql_t, sqltypes._Binary): + col["data_type"] = "binary" + if add_precision and sql_t.length: + col["precision"] = sql_t.length + elif isinstance(sql_t, sqltypes.DateTime): + col["data_type"] = "timestamp" + if add_precision: + col["timezone"] = sql_t.timezone + elif isinstance(sql_t, sqltypes.Date): + col["data_type"] = "date" + elif isinstance(sql_t, sqltypes.Time): + col["data_type"] = "time" + elif isinstance(sql_t, sqltypes.JSON): + col["data_type"] = "complex" + elif isinstance(sql_t, sqltypes.Boolean): + col["data_type"] = "bool" + else: + logger.warning( + f"A column with name {sql_col.name} contains unknown data type {sql_t} which cannot be" + " mapped to `dlt` data type. When using sqlalchemy backend such data will be passed to" + " the normalizer. In case of `pyarrow` and `pandas` backend, data types are detected" + " from numpy ndarrays. In case of other backends, the behavior is backend-specific." 
+ ) + return {key: value for key, value in col.items() if value is not None} # type: ignore[return-value] + + +def get_primary_key(table: Table) -> Optional[List[str]]: + """Create primary key or return None if no key defined""" + primary_key = [c.name for c in table.primary_key] + return primary_key if len(primary_key) > 0 else None + + +def table_to_columns( + table: Table, + reflection_level: ReflectionLevel = "full", + type_conversion_fallback: Optional[TTypeAdapter] = None, + skip_complex_columns_on_minimal: bool = False, +) -> TTableSchemaColumns: + """Convert an sqlalchemy table to a dlt table schema.""" + return { + col["name"]: col + for col in ( + sqla_col_to_column_schema( + c, reflection_level, type_conversion_fallback, skip_complex_columns_on_minimal + ) + for c in table.columns + ) + if col is not None + } diff --git a/dlt/sources/sql_database_pipeline.py b/dlt/sources/sql_database_pipeline.py new file mode 100644 index 0000000000..4b82997fd7 --- /dev/null +++ b/dlt/sources/sql_database_pipeline.py @@ -0,0 +1,361 @@ +# flake8: noqa +import humanize +from typing import Any +import os + +import dlt +from dlt.common import pendulum +from dlt.sources.credentials import ConnectionStringCredentials + +from dlt.sources.sql_database import sql_database, sql_table, Table + +from sqlalchemy.sql.sqltypes import TypeEngine +import sqlalchemy as sa + + +def load_select_tables_from_database() -> None: + """Use the sql_database source to reflect an entire database schema and load select tables from it. + + This example sources data from the public Rfam MySQL database. + """ + # Create a pipeline + pipeline = dlt.pipeline(pipeline_name="rfam", destination="duckdb", dataset_name="rfam_data") + + # Credentials for the sample database. + # Note: It is recommended to configure credentials in `.dlt/secrets.toml` under `sources.sql_database.credentials` + credentials = ConnectionStringCredentials( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam" + ) + # To pass the credentials from `secrets.toml`, comment out the above credentials. + # And the credentials will be automatically read from `secrets.toml`. + + # Configure the source to load a few select tables incrementally + source_1 = sql_database(credentials).with_resources("family", "clan") + + # Add incremental config to the resources. "updated" is a timestamp column in these tables that gets used as a cursor + source_1.family.apply_hints(incremental=dlt.sources.incremental("updated")) + source_1.clan.apply_hints(incremental=dlt.sources.incremental("updated")) + + # Run the pipeline. The merge write disposition merges existing rows in the destination by primary key + info = pipeline.run(source_1, write_disposition="merge") + print(info) + + # Load some other tables with replace write disposition. 
This overwrites the existing tables in destination + source_2 = sql_database(credentials).with_resources("features", "author") + info = pipeline.run(source_2, write_disposition="replace") + print(info) + + # Load a table incrementally with append write disposition + # this is good when a table only has new rows inserted, but not updated + source_3 = sql_database(credentials).with_resources("genome") + source_3.genome.apply_hints(incremental=dlt.sources.incremental("created")) + + info = pipeline.run(source_3, write_disposition="append") + print(info) + + +def load_entire_database() -> None: + """Use the sql_database source to completely load all tables in a database""" + pipeline = dlt.pipeline(pipeline_name="rfam", destination="duckdb", dataset_name="rfam_data") + + # By default the sql_database source reflects all tables in the schema + # The database credentials are sourced from the `.dlt/secrets.toml` configuration + source = sql_database() + + # Run the pipeline. For a large db this may take a while + info = pipeline.run(source, write_disposition="replace") + print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at)) + print(info) + + +def load_standalone_table_resource() -> None: + """Load a few known tables with the standalone sql_table resource, request full schema and deferred + table reflection""" + pipeline = dlt.pipeline( + pipeline_name="rfam_database", + destination="duckdb", + dataset_name="rfam_data", + full_refresh=True, + ) + + # Load a table incrementally starting at a given date + # Adding incremental via argument like this makes extraction more efficient + # as only rows newer than the start date are fetched from the table + # we also use `detect_precision_hints` to get detailed column schema + # and defer_table_reflect to reflect schema only during execution + family = sql_table( + credentials=ConnectionStringCredentials( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam" + ), + table="family", + incremental=dlt.sources.incremental( + "updated", + ), + reflection_level="full_with_precision", + defer_table_reflect=True, + ) + # columns will be empty here due to defer_table_reflect set to True + print(family.compute_table_schema()) + + # Load all data from another table + genome = sql_table( + credentials="mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam", + table="genome", + reflection_level="full_with_precision", + defer_table_reflect=True, + ) + + # Run the resources together + info = pipeline.extract([family, genome], write_disposition="merge") + print(info) + # Show inferred columns + print(pipeline.default_schema.to_pretty_yaml()) + + +def select_columns() -> None: + """Uses table adapter callback to modify list of columns to be selected""" + pipeline = dlt.pipeline( + pipeline_name="rfam_database", + destination="duckdb", + dataset_name="rfam_data_cols", + full_refresh=True, + ) + + def table_adapter(table: Table) -> None: + print(table.name) + if table.name == "family": + # this is SqlAlchemy table. 
_columns are writable + # let's drop updated column + table._columns.remove(table.columns["updated"]) # type: ignore + + family = sql_table( + credentials="mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam", + table="family", + chunk_size=10, + reflection_level="full_with_precision", + table_adapter_callback=table_adapter, + ) + + # also we do not want the whole table, so we add limit to get just one chunk (10 records) + pipeline.run(family.add_limit(1)) + # only 10 rows + print(pipeline.last_trace.last_normalize_info) + # no "updated" column in "family" table + print(pipeline.default_schema.to_pretty_yaml()) + + +def select_with_end_value_and_row_order() -> None: + """Gets data from a table withing a specified range and sorts rows descending""" + pipeline = dlt.pipeline( + pipeline_name="rfam_database", + destination="duckdb", + dataset_name="rfam_data", + full_refresh=True, + ) + + # gets data from this range + start_date = pendulum.now().subtract(years=1) + end_date = pendulum.now() + + family = sql_table( + credentials="mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam", + table="family", + incremental=dlt.sources.incremental( # declares desc row order + "updated", initial_value=start_date, end_value=end_date, row_order="desc" + ), + chunk_size=10, + ) + # also we do not want the whole table, so we add limit to get just one chunk (10 records) + pipeline.run(family.add_limit(1)) + # only 10 rows + print(pipeline.last_trace.last_normalize_info) + + +def my_sql_via_pyarrow() -> None: + """Uses pyarrow backend to load tables from mysql""" + + # uncomment line below to get load_id into your data (slows pyarrow loading down) + # dlt.config["normalize.parquet_normalizer.add_dlt_load_id"] = True + + # Create a pipeline + pipeline = dlt.pipeline( + pipeline_name="rfam_cx", + destination="duckdb", + dataset_name="rfam_data_arrow_4", + ) + + def _double_as_decimal_adapter(table: sa.Table) -> None: + """Return double as double, not decimals, only works if you are using sqlalchemy 2.0""" + for column in table.columns.values(): + if hasattr(sa, "Double") and isinstance(column.type, sa.Double): + column.type.asdecimal = False + + sql_alchemy_source = sql_database( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true", + backend="pyarrow", + table_adapter_callback=_double_as_decimal_adapter, + ).with_resources("family", "genome") + + info = pipeline.run(sql_alchemy_source) + print(info) + + +def create_unsw_flow() -> None: + """Uploads UNSW_Flow dataset to postgres via csv stream skipping dlt normalizer. 
+ You need to download the dataset from https://github.com/rdpahalavan/nids-datasets + """ + from pyarrow.parquet import ParquetFile + + # from dlt.destinations import postgres + + # use those config to get 3x speedup on parallelism + # [sources.data_writer] + # file_max_bytes=3000000 + # buffer_max_items=200000 + + # [normalize] + # workers=3 + + data_iter = ParquetFile("UNSW-NB15/Network-Flows/UNSW_Flow.parquet").iter_batches( + batch_size=128 * 1024 + ) + + pipeline = dlt.pipeline( + pipeline_name="unsw_upload", + # destination=postgres("postgres://loader:loader@localhost:5432/dlt_data"), + destination="postgres", + progress="log", + ) + pipeline.run( + data_iter, + dataset_name="speed_test", + table_name="unsw_flow_7", + loader_file_format="csv", + ) + + +def test_connectorx_speed() -> None: + """Uses unsw_flow dataset (~2mln rows, 25+ columns) to test connectorx speed""" + import os + + # from dlt.destinations import filesystem + + unsw_table = sql_table( + "postgresql://loader:loader@localhost:5432/dlt_data", + "unsw_flow_7", + "speed_test", + # this is ignored by connectorx + chunk_size=100000, + backend="connectorx", + # keep source data types + reflection_level="full_with_precision", + # just to demonstrate how to setup a separate connection string for connectorx + backend_kwargs={"conn": "postgresql://loader:loader@localhost:5432/dlt_data"}, + ) + + pipeline = dlt.pipeline( + pipeline_name="unsw_download", + destination="filesystem", + # destination=filesystem(os.path.abspath("../_storage/unsw")), + progress="log", + full_refresh=True, + ) + + info = pipeline.run( + unsw_table, + dataset_name="speed_test", + table_name="unsw_flow", + loader_file_format="parquet", + ) + print(info) + + +def test_pandas_backend_verbatim_decimals() -> None: + pipeline = dlt.pipeline( + pipeline_name="rfam_cx", + destination="duckdb", + dataset_name="rfam_data_pandas_2", + ) + + def _double_as_decimal_adapter(table: sa.Table) -> None: + """Emits decimals instead of floats.""" + for column in table.columns.values(): + if isinstance(column.type, sa.Float): + column.type.asdecimal = True + + sql_alchemy_source = sql_database( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true", + backend="pandas", + table_adapter_callback=_double_as_decimal_adapter, + chunk_size=100000, + # set coerce_float to False to represent them as string + backend_kwargs={"coerce_float": False, "dtype_backend": "numpy_nullable"}, + # preserve full typing info. 
this will parse + reflection_level="full_with_precision", + ).with_resources("family", "genome") + + info = pipeline.run(sql_alchemy_source) + print(info) + + +def use_type_adapter() -> None: + """Example use of type adapter to coerce unknown data types""" + pipeline = dlt.pipeline( + pipeline_name="dummy", + destination="postgres", + dataset_name="dummy", + ) + + def type_adapter(sql_type: Any) -> Any: + if isinstance(sql_type, sa.ARRAY): + return sa.JSON() # Load arrays as JSON + return sql_type + + sql_alchemy_source = sql_database( + "postgresql://loader:loader@localhost:5432/dlt_data", + backend="pyarrow", + type_adapter_callback=type_adapter, + reflection_level="full_with_precision", + ).with_resources("table_with_array_column") + + info = pipeline.run(sql_alchemy_source) + print(info) + + +def specify_columns_to_load() -> None: + """Run the SQL database source with a subset of table columns loaded""" + pipeline = dlt.pipeline( + pipeline_name="dummy", + destination="duckdb", + dataset_name="dummy", + ) + + # Columns can be specified per table in env var (json array) or in `.dlt/config.toml` + os.environ["SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"] = '["rfam_acc", "description"]' + + sql_alchemy_source = sql_database( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true", + backend="pyarrow", + reflection_level="full_with_precision", + ).with_resources("family", "genome") + + info = pipeline.run(sql_alchemy_source) + print(info) + + +if __name__ == "__main__": + # Load selected tables with different settings + # load_select_tables_from_database() + + # load a table and select columns + # select_columns() + + # load_entire_database() + # select_with_end_value_and_row_order() + + # Load tables with the standalone table resource + load_standalone_table_resource() + + # Load all tables from the database. + # Warning: The sample database is very large + # load_entire_database() diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md index e1cd9ce88e..7eea6d9aff 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md @@ -102,7 +102,7 @@ The GitHub API [requires an access token](https://docs.github.com/en/rest/authen After you get the token, add it to the `secrets.toml` file: ```toml -[sources.rest_api.github] +[sources.rest_api_pipeline.github_source] github_token = "your_github_token" ``` diff --git a/docs/website/docs/reference/command-line-interface.md b/docs/website/docs/reference/command-line-interface.md index 8e816fb622..693c068a4f 100644 --- a/docs/website/docs/reference/command-line-interface.md +++ b/docs/website/docs/reference/command-line-interface.md @@ -23,9 +23,9 @@ version if run again with existing `source` name. You are warned if files will b ### Specify your own "verified sources" repository. You can use `--location ` option to specify your own repository with sources. Typically you would [fork ours](https://github.com/dlt-hub/verified-sources) and start customizing and adding sources ie. to use them for your team or organization. You can also specify a branch with `--branch ` ie. to test a version being developed. -### List all verified sources +### List all sources ```sh -dlt init --list-verified-sources +dlt init --list-sources ``` Shows all available verified sources and their short descriptions. 
For each source, checks if your local `dlt` version requires update and prints the relevant warning. diff --git a/docs/website/docs/walkthroughs/add-a-verified-source.md b/docs/website/docs/walkthroughs/add-a-verified-source.md index d7cd24b544..144b805974 100644 --- a/docs/website/docs/walkthroughs/add-a-verified-source.md +++ b/docs/website/docs/walkthroughs/add-a-verified-source.md @@ -21,10 +21,10 @@ mkdir various_pipelines cd various_pipelines ``` -List available verified sources to see their names and descriptions: +List available sources to see their names and descriptions: ```sh -dlt init --list-verified-sources +dlt init --list-sources ``` Now pick one of the source names, for example `pipedrive` and a destination i.e. `bigquery`: diff --git a/poetry.lock b/poetry.lock index 1bfdb776a2..0bb8ec1fb3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "about-time" @@ -1941,27 +1941,27 @@ testing = ["flake8", "pytest", "pytest-cov", "pytest-virtualenv", "pytest-xdist" [[package]] name = "connectorx" -version = "0.3.2" +version = "0.3.3" description = "" optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "connectorx-0.3.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:98274242c64a2831a8b1c86e0fa2c46a557dd8cbcf00c3adcf5a602455fb02d7"}, - {file = "connectorx-0.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e2b11ba49efd330a7348bef3ce09c98218eea21d92a12dd75cd8f0ade5c99ffc"}, - {file = "connectorx-0.3.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:3f6431a30304271f9137bd7854d2850231041f95164c6b749d9ede4c0d92d10c"}, - {file = "connectorx-0.3.2-cp310-none-win_amd64.whl", hash = "sha256:b370ebe8f44d2049254dd506f17c62322cc2db1b782a57f22cce01ddcdcc8fed"}, - {file = "connectorx-0.3.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:d5277fc936a80da3d1dcf889020e45da3493179070d9be8a47500c7001fab967"}, - {file = "connectorx-0.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8cc6c963237c3d3b02f7dcd47e1be9fc6e8b93ef0aeed8694f65c62b3c4688a1"}, - {file = "connectorx-0.3.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:9403902685b3423cba786db01a36f36efef90ae3d429e45b74dadb4ae9e328dc"}, - {file = "connectorx-0.3.2-cp311-none-win_amd64.whl", hash = "sha256:6b5f518194a2cf12d5ad031d488ded4e4678eff3b63551856f2a6f1a83197bb8"}, - {file = "connectorx-0.3.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:a5602ae0531e55c58af8cfca92b8e9454fc1ccd82c801cff8ee0f17c728b4988"}, - {file = "connectorx-0.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c5959bfb4a049bb8ce1f590b5824cd1105460b6552ffec336c4bd740eebd5bd"}, - {file = "connectorx-0.3.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c4387bb27ba3acde0ab6921fdafa3811e09fce0db3d1f1ede8547d9de3aab685"}, - {file = "connectorx-0.3.2-cp38-none-win_amd64.whl", hash = "sha256:4b1920c191be9a372629c31c92d5f71fc63f49f283e5adfc4111169de40427d9"}, - {file = "connectorx-0.3.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:4473fc06ac3618c673cea63a7050e721fe536782d5c1b6e433589c37a63de704"}, - {file = "connectorx-0.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4009b16399457340326137a223921a24e3e166b45db4dbf3ef637b9981914dc2"}, - {file = "connectorx-0.3.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:74f5b93535663cf47f9fc3d7964f93e652c07003fa71c38d7a68f42167f54bba"}, - {file 
= "connectorx-0.3.2-cp39-none-win_amd64.whl", hash = "sha256:0b80acca13326856c14ee726b47699011ab1baa10897180240c8783423ca5e8c"}, + {file = "connectorx-0.3.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:4c0e61e44a62eaee2ffe89bf938c7431b8f3d2d3ecdf09e8abb2d159f09138f0"}, + {file = "connectorx-0.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:da1970ec09ad7a65e25936a6d613f15ad2ce916f97f17c64180415dc58493881"}, + {file = "connectorx-0.3.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b43b0abcfb954c497981bcf8f2b5339dcf7986399a401b9470f0bf8055a58562"}, + {file = "connectorx-0.3.3-cp310-none-win_amd64.whl", hash = "sha256:dff9e04396a76d3f2ca9ab1abed0df52497f19666b222c512d7b10f1699636c8"}, + {file = "connectorx-0.3.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d1d0cbb1b97643337fb7f3e30fa2b44f63d8629eadff55afffcdf10b2afeaf9c"}, + {file = "connectorx-0.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4010b466cafd728ec80adf387e53cc10668e2bc1a8c52c42a0604bea5149c412"}, + {file = "connectorx-0.3.3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:f430c359e7977818f90ac8cce3bb7ba340469dcabee13e4ac7926f80e34e8c4d"}, + {file = "connectorx-0.3.3-cp311-none-win_amd64.whl", hash = "sha256:6e6495cab5f23e638456622a880c774c4bcfc17ee9ed7009d4217756a7e9e2c8"}, + {file = "connectorx-0.3.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:dfefa3c55601b1a229dd27359a61c18977921455eae0c5068ec15d79900a096c"}, + {file = "connectorx-0.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b62f6cac84a7c41c4f61746262da059dd8af06d10de64ebde2d59c73e28c22b"}, + {file = "connectorx-0.3.3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2eaca398a5dae6da595c8c521d2a27050100a94e4d5778776b914b919e54ab1e"}, + {file = "connectorx-0.3.3-cp312-none-win_amd64.whl", hash = "sha256:a37762f26ced286e9c06528f0179877148ea83f24263ac53b906c33c430af323"}, + {file = "connectorx-0.3.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:9267431fa88b00c60c6113d9deabe86a2ad739c8be56ee4b57164d3ed983b5dc"}, + {file = "connectorx-0.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:823170c06b61c7744fc668e6525b26a11ca462c1c809354aa2d482bd5a92bb0e"}, + {file = "connectorx-0.3.3-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:9b001b78406dd7a1b8b7d61330bbcb73ea68f478589fc439fbda001ed875e8ea"}, + {file = "connectorx-0.3.3-cp39-none-win_amd64.whl", hash = "sha256:e1e16404e353f348120d393586c58cad8a4ebf81e07f3f1dff580b551dbc863d"}, ] [[package]] @@ -3724,106 +3724,6 @@ files = [ {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f4d4f0823e8b2f6952a145295b1ff25245ce9bb136aff6fe86452e507d4c1dd"}, {file = "google_re2-1.1-4-cp39-cp39-win32.whl", hash = "sha256:1afae56b2a07bb48cfcfefaa15ed85bae26a68f5dc7f9e128e6e6ea36914e847"}, {file = "google_re2-1.1-4-cp39-cp39-win_amd64.whl", hash = "sha256:aa7d6d05911ab9c8adbf3c225a7a120ab50fd2784ac48f2f0d140c0b7afc2b55"}, - {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:222fc2ee0e40522de0b21ad3bc90ab8983be3bf3cec3d349c80d76c8bb1a4beb"}, - {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d4763b0b9195b72132a4e7de8e5a9bf1f05542f442a9115aa27cfc2a8004f581"}, - {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:209649da10c9d4a93d8a4d100ecbf9cc3b0252169426bec3e8b4ad7e57d600cf"}, - {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:68813aa333c1604a2df4a495b2a6ed065d7c8aebf26cc7e7abb5a6835d08353c"}, - {file = 
"google_re2-1.1-5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:370a23ec775ad14e9d1e71474d56f381224dcf3e72b15d8ca7b4ad7dd9cd5853"}, - {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:14664a66a3ddf6bc9e56f401bf029db2d169982c53eff3f5876399104df0e9a6"}, - {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ea3722cc4932cbcebd553b69dce1b4a73572823cff4e6a244f1c855da21d511"}, - {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e14bb264c40fd7c627ef5678e295370cd6ba95ca71d835798b6e37502fc4c690"}, - {file = "google_re2-1.1-5-cp310-cp310-win32.whl", hash = "sha256:39512cd0151ea4b3969c992579c79b423018b464624ae955be685fc07d94556c"}, - {file = "google_re2-1.1-5-cp310-cp310-win_amd64.whl", hash = "sha256:ac66537aa3bc5504320d922b73156909e3c2b6da19739c866502f7827b3f9fdf"}, - {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5b5ea68d54890c9edb1b930dcb2658819354e5d3f2201f811798bbc0a142c2b4"}, - {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:33443511b6b83c35242370908efe2e8e1e7cae749c766b2b247bf30e8616066c"}, - {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:413d77bdd5ba0bfcada428b4c146e87707452ec50a4091ec8e8ba1413d7e0619"}, - {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:5171686e43304996a34baa2abcee6f28b169806d0e583c16d55e5656b092a414"}, - {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b284db130283771558e31a02d8eb8fb756156ab98ce80035ae2e9e3a5f307c4"}, - {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:296e6aed0b169648dc4b870ff47bd34c702a32600adb9926154569ef51033f47"}, - {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:38d50e68ead374160b1e656bbb5d101f0b95fb4cc57f4a5c12100155001480c5"}, - {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a0416a35921e5041758948bcb882456916f22845f66a93bc25070ef7262b72a"}, - {file = "google_re2-1.1-5-cp311-cp311-win32.whl", hash = "sha256:a1d59568bbb5de5dd56dd6cdc79907db26cce63eb4429260300c65f43469e3e7"}, - {file = "google_re2-1.1-5-cp311-cp311-win_amd64.whl", hash = "sha256:72f5a2f179648b8358737b2b493549370debd7d389884a54d331619b285514e3"}, - {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:cbc72c45937b1dc5acac3560eb1720007dccca7c9879138ff874c7f6baf96005"}, - {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5fadd1417fbef7235fa9453dba4eb102e6e7d94b1e4c99d5fa3dd4e288d0d2ae"}, - {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:040f85c63cc02696485b59b187a5ef044abe2f99b92b4fb399de40b7d2904ccc"}, - {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:64e3b975ee6d9bbb2420494e41f929c1a0de4bcc16d86619ab7a87f6ea80d6bd"}, - {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8ee370413e00f4d828eaed0e83b8af84d7a72e8ee4f4bd5d3078bc741dfc430a"}, - {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:5b89383001079323f693ba592d7aad789d7a02e75adb5d3368d92b300f5963fd"}, - {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63cb4fdfbbda16ae31b41a6388ea621510db82feb8217a74bf36552ecfcd50ad"}, - {file = 
"google_re2-1.1-5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ebedd84ae8be10b7a71a16162376fd67a2386fe6361ef88c622dcf7fd679daf"}, - {file = "google_re2-1.1-5-cp312-cp312-win32.whl", hash = "sha256:c8e22d1692bc2c81173330c721aff53e47ffd3c4403ff0cd9d91adfd255dd150"}, - {file = "google_re2-1.1-5-cp312-cp312-win_amd64.whl", hash = "sha256:5197a6af438bb8c4abda0bbe9c4fbd6c27c159855b211098b29d51b73e4cbcf6"}, - {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b6727e0b98417e114b92688ad2aa256102ece51f29b743db3d831df53faf1ce3"}, - {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:711e2b6417eb579c61a4951029d844f6b95b9b373b213232efd413659889a363"}, - {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:71ae8b3df22c5c154c8af0f0e99d234a450ef1644393bc2d7f53fc8c0a1e111c"}, - {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:94a04e214bc521a3807c217d50cf099bbdd0c0a80d2d996c0741dbb995b5f49f"}, - {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:a770f75358508a9110c81a1257721f70c15d9bb592a2fb5c25ecbd13566e52a5"}, - {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:07c9133357f7e0b17c6694d5dcb82e0371f695d7c25faef2ff8117ef375343ff"}, - {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:204ca6b1cf2021548f4a9c29ac015e0a4ab0a7b6582bf2183d838132b60c8fda"}, - {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0b95857c2c654f419ca684ec38c9c3325c24e6ba7d11910a5110775a557bb18"}, - {file = "google_re2-1.1-5-cp38-cp38-win32.whl", hash = "sha256:347ac770e091a0364e822220f8d26ab53e6fdcdeaec635052000845c5a3fb869"}, - {file = "google_re2-1.1-5-cp38-cp38-win_amd64.whl", hash = "sha256:ec32bb6de7ffb112a07d210cf9f797b7600645c2d5910703fa07f456dd2150e0"}, - {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb5adf89060f81c5ff26c28e261e6b4997530a923a6093c9726b8dec02a9a326"}, - {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a22630c9dd9ceb41ca4316bccba2643a8b1d5c198f21c00ed5b50a94313aaf10"}, - {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:544dc17fcc2d43ec05f317366375796351dec44058e1164e03c3f7d050284d58"}, - {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:19710af5ea88751c7768575b23765ce0dfef7324d2539de576f75cdc319d6654"}, - {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:f82995a205e08ad896f4bd5ce4847c834fab877e1772a44e5f262a647d8a1dec"}, - {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:63533c4d58da9dc4bc040250f1f52b089911699f0368e0e6e15f996387a984ed"}, - {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79e00fcf0cb04ea35a22b9014712d448725ce4ddc9f08cc818322566176ca4b0"}, - {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bc41afcefee2da6c4ed883a93d7f527c4b960cd1d26bbb0020a7b8c2d341a60a"}, - {file = "google_re2-1.1-5-cp39-cp39-win32.whl", hash = "sha256:486730b5e1f1c31b0abc6d80abe174ce4f1188fe17d1b50698f2bf79dc6e44be"}, - {file = "google_re2-1.1-5-cp39-cp39-win_amd64.whl", hash = "sha256:4de637ca328f1d23209e80967d1b987d6b352cd01b3a52a84b4d742c69c3da6c"}, - {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_arm64.whl", hash = 
"sha256:621e9c199d1ff0fdb2a068ad450111a84b3bf14f96dfe5a8a7a0deae5f3f4cce"}, - {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:220acd31e7dde95373f97c3d1f3b3bd2532b38936af28b1917ee265d25bebbf4"}, - {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:db34e1098d164f76251a6ece30e8f0ddfd65bb658619f48613ce71acb3f9cbdb"}, - {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:5152bac41d8073977582f06257219541d0fc46ad99b0bbf30e8f60198a43b08c"}, - {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:6191294799e373ee1735af91f55abd23b786bdfd270768a690d9d55af9ea1b0d"}, - {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:070cbafbb4fecbb02e98feb28a1eb292fb880f434d531f38cc33ee314b521f1f"}, - {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8437d078b405a59a576cbed544490fe041140f64411f2d91012e8ec05ab8bf86"}, - {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f00f9a9af8896040e37896d9b9fc409ad4979f1ddd85bb188694a7d95ddd1164"}, - {file = "google_re2-1.1-6-cp310-cp310-win32.whl", hash = "sha256:df26345f229a898b4fd3cafd5f82259869388cee6268fc35af16a8e2293dd4e5"}, - {file = "google_re2-1.1-6-cp310-cp310-win_amd64.whl", hash = "sha256:3665d08262c57c9b28a5bdeb88632ad792c4e5f417e5645901695ab2624f5059"}, - {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b26b869d8aa1d8fe67c42836bf3416bb72f444528ee2431cfb59c0d3e02c6ce3"}, - {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:41fd4486c57dea4f222a6bb7f1ff79accf76676a73bdb8da0fcbd5ba73f8da71"}, - {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:0ee378e2e74e25960070c338c28192377c4dd41e7f4608f2688064bd2badc41e"}, - {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:a00cdbf662693367b36d075b29feb649fd7ee1b617cf84f85f2deebeda25fc64"}, - {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:4c09455014217a41499432b8c8f792f25f3df0ea2982203c3a8c8ca0e7895e69"}, - {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6501717909185327935c7945e23bb5aa8fc7b6f237b45fe3647fa36148662158"}, - {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3510b04790355f199e7861c29234081900e1e1cbf2d1484da48aa0ba6d7356ab"}, - {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8c0e64c187ca406764f9e9ad6e750d62e69ed8f75bf2e865d0bfbc03b642361c"}, - {file = "google_re2-1.1-6-cp311-cp311-win32.whl", hash = "sha256:2a199132350542b0de0f31acbb3ca87c3a90895d1d6e5235f7792bb0af02e523"}, - {file = "google_re2-1.1-6-cp311-cp311-win_amd64.whl", hash = "sha256:83bdac8ceaece8a6db082ea3a8ba6a99a2a1ee7e9f01a9d6d50f79c6f251a01d"}, - {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:81985ff894cd45ab5a73025922ac28c0707759db8171dd2f2cc7a0e856b6b5ad"}, - {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5635af26065e6b45456ccbea08674ae2ab62494008d9202df628df3b267bc095"}, - {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:813b6f04de79f4a8fdfe05e2cb33e0ccb40fe75d30ba441d519168f9d958bd54"}, - {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:5ec2f5332ad4fd232c3f2d6748c2c7845ccb66156a87df73abcc07f895d62ead"}, - {file = 
"google_re2-1.1-6-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5a687b3b32a6cbb731647393b7c4e3fde244aa557f647df124ff83fb9b93e170"}, - {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:39a62f9b3db5d3021a09a47f5b91708b64a0580193e5352751eb0c689e4ad3d7"}, - {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ca0f0b45d4a1709cbf5d21f355e5809ac238f1ee594625a1e5ffa9ff7a09eb2b"}, - {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64b3796a7a616c7861247bd061c9a836b5caf0d5963e5ea8022125601cf7b09"}, - {file = "google_re2-1.1-6-cp312-cp312-win32.whl", hash = "sha256:32783b9cb88469ba4cd9472d459fe4865280a6b1acdad4480a7b5081144c4eb7"}, - {file = "google_re2-1.1-6-cp312-cp312-win_amd64.whl", hash = "sha256:259ff3fd2d39035b9cbcbf375995f83fa5d9e6a0c5b94406ff1cc168ed41d6c6"}, - {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e4711bcffe190acd29104d8ecfea0c0e42b754837de3fb8aad96e6cc3c613cdc"}, - {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:4d081cce43f39c2e813fe5990e1e378cbdb579d3f66ded5bade96130269ffd75"}, - {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:4f123b54d48450d2d6b14d8fad38e930fb65b5b84f1b022c10f2913bd956f5b5"}, - {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:e1928b304a2b591a28eb3175f9db7f17c40c12cf2d4ec2a85fdf1cc9c073ff91"}, - {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:3a69f76146166aec1173003c1f547931bdf288c6b135fda0020468492ac4149f"}, - {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:fc08c388f4ebbbca345e84a0c56362180d33d11cbe9ccfae663e4db88e13751e"}, - {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b057adf38ce4e616486922f2f47fc7d19c827ba0a7f69d540a3664eba2269325"}, - {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4138c0b933ab099e96f5d8defce4486f7dfd480ecaf7f221f2409f28022ccbc5"}, - {file = "google_re2-1.1-6-cp38-cp38-win32.whl", hash = "sha256:9693e45b37b504634b1abbf1ee979471ac6a70a0035954592af616306ab05dd6"}, - {file = "google_re2-1.1-6-cp38-cp38-win_amd64.whl", hash = "sha256:5674d437baba0ea287a5a7f8f81f24265d6ae8f8c09384e2ef7b6f84b40a7826"}, - {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7783137cb2e04f458a530c6d0ee9ef114815c1d48b9102f023998c371a3b060e"}, - {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a49b7153935e7a303675f4deb5f5d02ab1305adefc436071348706d147c889e0"}, - {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:a96a8bb309182090704593c60bdb369a2756b38fe358bbf0d40ddeb99c71769f"}, - {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:dff3d4be9f27ef8ec3705eed54f19ef4ab096f5876c15fe011628c69ba3b561c"}, - {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:40f818b0b39e26811fa677978112a8108269977fdab2ba0453ac4363c35d9e66"}, - {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:8a7e53538cdb40ef4296017acfbb05cab0c19998be7552db1cfb85ba40b171b9"}, - {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ee18e7569fb714e5bb8c42809bf8160738637a5e71ed5a4797757a1fb4dc4de"}, - {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:1cda4f6d1a7d5b43ea92bc395f23853fba0caf8b1e1efa6e8c48685f912fcb89"}, - {file = "google_re2-1.1-6-cp39-cp39-win32.whl", hash = "sha256:6a9cdbdc36a2bf24f897be6a6c85125876dc26fea9eb4247234aec0decbdccfd"}, - {file = "google_re2-1.1-6-cp39-cp39-win_amd64.whl", hash = "sha256:73f646cecfad7cc5b4330b4192c25f2e29730a3b8408e089ffd2078094208196"}, ] [[package]] @@ -3875,6 +3775,17 @@ files = [ [package.extras] test = ["pytest", "sphinx", "sphinx-autobuild", "twine", "wheel"] +[[package]] +name = "graphlib-backport" +version = "1.1.0" +description = "Backport of the Python 3.9 graphlib module for Python 3.6+" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "graphlib_backport-1.1.0-py3-none-any.whl", hash = "sha256:eccacf9f2126cdf89ce32a6018c88e1ecd3e4898a07568add6e1907a439055ba"}, + {file = "graphlib_backport-1.1.0.tar.gz", hash = "sha256:00a7888b21e5393064a133209cb5d3b3ef0a2096cf023914c9d778dff5644125"}, +] + [[package]] name = "greenlet" version = "3.0.3" @@ -5186,6 +5097,17 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mimesis" +version = "7.1.0" +description = "Mimesis: Fake Data Generator." +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "mimesis-7.1.0-py3-none-any.whl", hash = "sha256:da65bea6d6d5d5d87d5c008e6b23ef5f96a49cce436d9f8708dabb5152da0290"}, + {file = "mimesis-7.1.0.tar.gz", hash = "sha256:c83b55d35536d7e9b9700a596b7ccfb639a740e3e1fb5e08062e8ab2a67dcb37"}, +] + [[package]] name = "minimal-snowplow-tracker" version = "0.0.2" @@ -9716,10 +9638,11 @@ qdrant = ["qdrant-client"] redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] +sql-database = ["sqlalchemy"] synapse = ["adlfs", "pyarrow", "pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "888e1760984e867fde690a1cca90330e255d69a8775c81020d003650def7ab4c" +content-hash = "ae02db22861b419596adea95c7ddff27317ae91579c6e9138f777489fe20c05a" diff --git a/pyproject.toml b/pyproject.toml index 1bdaf77b86..28d6056f60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.5.4" +version = "0.9.9a0" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. 
"] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] @@ -18,7 +18,7 @@ classifiers = [ "Operating System :: POSIX :: Linux", "Operating System :: Microsoft :: Windows",] keywords = [ "etl" ] -include = [ "LICENSE.txt", "README.md"] +include = [ "LICENSE.txt", "README.md", "dlt/sources/pipeline_templates/.gitignore", "dlt/sources/pipeline_templates/.dlt/config.toml" ] packages = [ { include = "dlt" }, ] @@ -51,6 +51,7 @@ jsonpath-ng = ">=1.5.3" fsspec = ">=2022.4.0" packaging = ">=21.1" win-precise-time = {version = ">=1.4.2", markers="os_name == 'nt'"} +graphlib-backport = {version = "*", python = "<3.9"} psycopg2-binary = {version = ">=2.9.1", optional = true} # use this dependency as the current version of psycopg2cffi does not have sql module @@ -82,6 +83,7 @@ clickhouse-connect = { version = ">=0.7.7", optional = true } lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'", allow-prereleases = true } tantivy = { version = ">= 0.22.0", optional = true } deltalake = { version = ">=0.19.0", optional = true } +sqlalchemy = { version = ">=1.4", optional = true } [tool.poetry.extras] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] @@ -108,6 +110,7 @@ clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs dremio = ["pyarrow"] lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] +sql_database = ["sqlalchemy"] [tool.poetry.scripts] @@ -156,6 +159,7 @@ pyjwt = "^2.8.0" pytest-mock = "^3.14.0" types-regex = "^2024.5.15.20240519" flake8-print = "^5.0.0" +mimesis = "^7.0.0" [tool.poetry.group.pipeline] optional = true @@ -213,7 +217,6 @@ SQLAlchemy = ">=1.4.0" pymysql = "^1.1.0" pypdf2 = "^3.0.1" pydoc-markdown = "^4.8.2" -connectorx = "0.3.2" dbt-core = ">=1.2.0" dbt-duckdb = ">=1.2.0" pymongo = ">=4.3.3" @@ -223,6 +226,7 @@ pyarrow = ">=14.0.0" psycopg2-binary = ">=2.9" lancedb = { version = ">=0.8.2", markers = "python_version >= '3.9'", allow-prereleases = true } openai = ">=1.35" +connectorx = { version = ">=0.3.2" } [tool.black] # https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file line-length = 100 @@ -237,4 +241,4 @@ multi_line_output = 3 [build-system] requires = ["poetry-core>=1.0.8"] -build-backend = "poetry.core.masonry.api" \ No newline at end of file +build-backend = "poetry.core.masonry.api" diff --git a/tests/.example.env b/tests/.example.env index 50eee33bd5..175544218c 100644 --- a/tests/.example.env +++ b/tests/.example.env @@ -19,6 +19,6 @@ DESTINATION__REDSHIFT__CREDENTIALS__USERNAME=loader DESTINATION__REDSHIFT__CREDENTIALS__HOST=3.73.90.3 DESTINATION__REDSHIFT__CREDENTIALS__PASSWORD=set-me-up -DESTINATION__POSTGRES__CREDENTIALS=postgres://loader:loader@localhost:5432/dlt_data +DESTINATION__POSTGRES__CREDENTIALS=postgresql://loader:loader@localhost:5432/dlt_data DESTINATION__DUCKDB__CREDENTIALS=duckdb:///_storage/test_quack.duckdb -RUNTIME__SENTRY_DSN=https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 \ No newline at end of file +RUNTIME__SENTRY_DSN=https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 diff --git a/tests/cli/common/test_cli_invoke.py b/tests/cli/common/test_cli_invoke.py index f856162479..0c6be1ea24 100644 --- a/tests/cli/common/test_cli_invoke.py +++ b/tests/cli/common/test_cli_invoke.py @@ -106,9 +106,9 @@ def test_invoke_init_chess_and_template(script_runner: ScriptRunner) -> None: assert 
result.returncode == 0 -def test_invoke_list_verified_sources(script_runner: ScriptRunner) -> None: +def test_invoke_list_sources(script_runner: ScriptRunner) -> None: known_sources = ["chess", "sql_database", "google_sheets", "pipedrive"] - result = script_runner.run(["dlt", "init", "--list-verified-sources"]) + result = script_runner.run(["dlt", "init", "--list-sources"]) assert result.returncode == 0 for known_source in known_sources: assert known_source in result.stdout diff --git a/tests/cli/common/test_telemetry_command.py b/tests/cli/common/test_telemetry_command.py index d2ccc81ebe..21f44b3e88 100644 --- a/tests/cli/common/test_telemetry_command.py +++ b/tests/cli/common/test_telemetry_command.py @@ -132,7 +132,7 @@ def instrument_raises_2(in_raises_2: bool) -> int: def test_instrumentation_wrappers() -> None: from dlt.cli._dlt import ( init_command_wrapper, - list_verified_sources_command_wrapper, + list_sources_command_wrapper, DEFAULT_VERIFIED_SOURCES_REPO, pipeline_command_wrapper, deploy_command_wrapper, @@ -145,7 +145,7 @@ def test_instrumentation_wrappers() -> None: SENT_ITEMS.clear() with io.StringIO() as buf, contextlib.redirect_stderr(buf): - init_command_wrapper("instrumented_source", "", False, None, None) + init_command_wrapper("instrumented_source", "", None, None) output = buf.getvalue() assert "is not one of the standard dlt destinations" in output msg = SENT_ITEMS[0] @@ -155,7 +155,7 @@ def test_instrumentation_wrappers() -> None: assert msg["properties"]["success"] is False SENT_ITEMS.clear() - list_verified_sources_command_wrapper(DEFAULT_VERIFIED_SOURCES_REPO, None) + list_sources_command_wrapper(DEFAULT_VERIFIED_SOURCES_REPO, None) msg = SENT_ITEMS[0] assert msg["event"] == "command_list_sources" diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index 03eded9da0..e85c4593f6 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -10,6 +10,7 @@ from unittest import mock import re from packaging.requirements import Requirement +from typing import Dict # import that because O3 modules cannot be unloaded import cryptography.hazmat.bindings._rust @@ -29,9 +30,14 @@ from dlt.cli import init_command, echo from dlt.cli.init_command import ( SOURCES_MODULE_NAME, + DEFAULT_VERIFIED_SOURCES_REPO, + SourceConfiguration, utils as cli_utils, files_ops, _select_source_files, + _list_core_sources, + _list_template_sources, + _list_verified_sources, ) from dlt.cli.exceptions import CliCommandException from dlt.cli.requirements import SourceRequirements @@ -49,37 +55,61 @@ from tests.common.utils import modify_and_commit_file from tests.utils import IMPLEMENTED_DESTINATIONS, clean_test_storage +# we hardcode the core sources here so we can check that the init script picks +# up the right source +CORE_SOURCES = ["filesystem", "rest_api", "sql_database"] + +# we also hardcode all the templates here for testing +TEMPLATES = ["debug", "default", "arrow", "requests", "dataframe", "intro"] + +# a few verified sources we know to exist +SOME_KNOWN_VERIFIED_SOURCES = ["chess", "sql_database", "google_sheets", "pipedrive"] + def get_verified_source_candidates(repo_dir: str) -> List[str]: sources_storage = FileStorage(os.path.join(repo_dir, SOURCES_MODULE_NAME)) # enumerate all candidate verified sources - return files_ops.get_verified_source_names(sources_storage) + return files_ops.get_sources_names(sources_storage, source_type="verified") def test_init_command_pipeline_template(repo_dir: str, project_files: FileStorage) -> None: - 
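
For context on the "--list-sources" rename exercised in tests/cli/common/test_cli_invoke.py above, here is a minimal sketch of calling the installed "dlt" entry point directly instead of through the ScriptRunner fixture (this assumes the dlt console script is on PATH; the asserted source name is just one of the known sources listed in that test):

import subprocess

# list core sources, pipeline templates and verified sources known to `dlt init`
result = subprocess.run(
    ["dlt", "init", "--list-sources"],
    capture_output=True,
    text=True,
    check=True,
)
assert "sql_database" in result.stdout
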
init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) + init_command.init_command("debug", "bigquery", repo_dir) visitor = assert_init_files(project_files, "debug_pipeline", "bigquery") # single resource assert len(visitor.known_resource_calls) == 1 -def test_init_command_pipeline_generic(repo_dir: str, project_files: FileStorage) -> None: - init_command.init_command("generic_pipeline", "redshift", True, repo_dir) - visitor = assert_init_files(project_files, "generic_pipeline", "redshift") +def test_init_command_pipeline_default_template(repo_dir: str, project_files: FileStorage) -> None: + init_command.init_command("some_random_name", "redshift", repo_dir) + visitor = assert_init_files(project_files, "some_random_name_pipeline", "redshift") # multiple resources assert len(visitor.known_resource_calls) > 1 +def test_default_source_file_selection() -> None: + templates_storage = init_command._get_templates_storage() + + # try a known source, it will take the known pipeline script + tconf = files_ops.get_template_configuration(templates_storage, "debug") + assert tconf.dest_pipeline_script == "debug_pipeline.py" + assert tconf.src_pipeline_script == "debug_pipeline.py" + + # random name will select the default script + tconf = files_ops.get_template_configuration(templates_storage, "very_nice_name") + assert tconf.dest_pipeline_script == "very_nice_name_pipeline.py" + assert tconf.src_pipeline_script == "default_pipeline.py" + + def test_init_command_new_pipeline_same_name(repo_dir: str, project_files: FileStorage) -> None: - init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) + init_command.init_command("debug_pipeline", "bigquery", repo_dir) with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) + init_command.init_command("debug_pipeline", "bigquery", repo_dir) _out = buf.getvalue() - assert "already exist, exiting" in _out + assert "already exists, exiting" in _out def test_init_command_chess_verified_source(repo_dir: str, project_files: FileStorage) -> None: - init_command.init_command("chess", "duckdb", False, repo_dir) + init_command.init_command("chess", "duckdb", repo_dir) assert_source_files(project_files, "chess", "duckdb", has_source_section=True) assert_requirements_txt(project_files, "duckdb") # check files hashes @@ -110,25 +140,40 @@ def test_init_command_chess_verified_source(repo_dir: str, project_files: FileSt raise -def test_init_list_verified_pipelines(repo_dir: str, project_files: FileStorage) -> None: - sources = init_command._list_verified_sources(repo_dir) - # a few known sources must be there - known_sources = ["chess", "sql_database", "google_sheets", "pipedrive"] - assert set(known_sources).issubset(set(sources.keys())) - # check docstrings - for k_p in known_sources: - assert sources[k_p].doc - # run the command - init_command.list_verified_sources_command(repo_dir) +def test_list_sources(repo_dir: str) -> None: + def check_results(items: Dict[str, SourceConfiguration]) -> None: + for name, source in items.items(): + assert source.doc, f"{name} missing docstring" + core_sources = _list_core_sources() + assert set(core_sources) == set(CORE_SOURCES) + check_results(core_sources) -def test_init_list_verified_pipelines_update_warning( - repo_dir: str, project_files: FileStorage -) -> None: + verified_sources = _list_verified_sources(DEFAULT_VERIFIED_SOURCES_REPO) + assert set(SOME_KNOWN_VERIFIED_SOURCES).issubset(verified_sources) + 
check_results(verified_sources) + assert len(verified_sources.keys()) > 10 + + templates = _list_template_sources() + assert set(templates) == set(TEMPLATES) + check_results(templates) + + +def test_init_list_sources(repo_dir: str) -> None: + # run the command and check all the sources are there + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + init_command.list_sources_command(repo_dir) + _out = buf.getvalue() + + for source in SOME_KNOWN_VERIFIED_SOURCES + TEMPLATES + CORE_SOURCES: + assert source in _out + + +def test_init_list_sources_update_warning(repo_dir: str, project_files: FileStorage) -> None: """Sources listed include a warning if a different dlt version is required""" with mock.patch.object(SourceRequirements, "current_dlt_version", return_value="0.0.1"): with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.list_verified_sources_command(repo_dir) + init_command.list_sources_command(repo_dir) _out = buf.getvalue() # Check one listed source @@ -143,17 +188,18 @@ def test_init_list_verified_pipelines_update_warning( assert "0.0.1" not in parsed_requirement.specifier -def test_init_all_verified_sources_together(repo_dir: str, project_files: FileStorage) -> None: - source_candidates = get_verified_source_candidates(repo_dir) +def test_init_all_sources_together(repo_dir: str, project_files: FileStorage) -> None: + source_candidates = [*get_verified_source_candidates(repo_dir), *CORE_SOURCES, *TEMPLATES] + # source_candidates = [source_name for source_name in source_candidates if source_name == "salesforce"] for source_name in source_candidates: # all must install correctly - init_command.init_command(source_name, "bigquery", False, repo_dir) + init_command.init_command(source_name, "bigquery", repo_dir) # verify files _, secrets = assert_source_files(project_files, source_name, "bigquery") # requirements.txt is created from the first source and not overwritten afterwards - assert_index_version_constraint(project_files, source_candidates[0]) + assert_index_version_constraint(project_files, list(source_candidates)[0]) # secrets should contain sections for all sources for source_name in source_candidates: assert secrets.get_value(source_name, type, None, "sources") is not None @@ -163,44 +209,66 @@ def test_init_all_verified_sources_together(repo_dir: str, project_files: FileSt for destination_name in ["bigquery", "postgres", "redshift"]: assert secrets.get_value(destination_name, type, None, "destination") is not None - # create pipeline template on top - init_command.init_command("debug_pipeline", "postgres", False, repo_dir) - assert_init_files(project_files, "debug_pipeline", "postgres", "bigquery") - # clear the resources otherwise sources not belonging to generic_pipeline will be found - _SOURCES.clear() - init_command.init_command("generic_pipeline", "redshift", True, repo_dir) - assert_init_files(project_files, "generic_pipeline", "redshift", "bigquery") - -def test_init_all_verified_sources_isolated(cloned_init_repo: FileStorage) -> None: +def test_init_all_sources_isolated(cloned_init_repo: FileStorage) -> None: repo_dir = get_repo_dir(cloned_init_repo) - for candidate in get_verified_source_candidates(repo_dir): + # ensure we test both sources form verified sources and core sources + source_candidates = ( + set(get_verified_source_candidates(repo_dir)).union(set(CORE_SOURCES)).union(set(TEMPLATES)) + ) + for candidate in source_candidates: clean_test_storage() repo_dir = get_repo_dir(cloned_init_repo) - files = get_project_files() + files 
= get_project_files(clear_all_sources=False) with set_working_dir(files.storage_path): - init_command.init_command(candidate, "bigquery", False, repo_dir) + init_command.init_command(candidate, "bigquery", repo_dir) assert_source_files(files, candidate, "bigquery") assert_requirements_txt(files, "bigquery") - assert_index_version_constraint(files, candidate) + if candidate not in CORE_SOURCES + TEMPLATES: + assert_index_version_constraint(files, candidate) @pytest.mark.parametrize("destination_name", IMPLEMENTED_DESTINATIONS) def test_init_all_destinations( destination_name: str, project_files: FileStorage, repo_dir: str ) -> None: - if destination_name == "destination": - pytest.skip("Init for generic destination not implemented yet") - pipeline_name = f"generic_{destination_name}_pipeline" - init_command.init_command(pipeline_name, destination_name, True, repo_dir) - assert_init_files(project_files, pipeline_name, destination_name) + source_name = "generic" + init_command.init_command(source_name, destination_name, repo_dir) + assert_init_files(project_files, source_name + "_pipeline", destination_name) + + +def test_custom_destination_note(repo_dir: str, project_files: FileStorage): + source_name = "generic" + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + init_command.init_command(source_name, "destination", repo_dir) + _out = buf.getvalue() + assert "to add a destination function that will consume your data" in _out + + +@pytest.mark.parametrize("omit", [True, False]) +# this will break if we have new core sources that are not in verified sources anymore +@pytest.mark.parametrize("source", CORE_SOURCES) +def test_omit_core_sources( + source: str, omit: bool, project_files: FileStorage, repo_dir: str +) -> None: + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + init_command.init_command(source, "destination", repo_dir, omit_core_sources=omit) + _out = buf.getvalue() + + # check messaging + assert ("Omitting dlt core sources" in _out) == omit + assert ("will no longer be copied from the" in _out) == (not omit) + + # if we omit core sources, there will be a folder with the name of the source from the verified sources repo + assert project_files.has_folder(source) == omit + assert (f"dlt.sources.{source}" in project_files.load(f"{source}_pipeline.py")) == (not omit) def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) -> None: sources_storage = FileStorage(os.path.join(repo_dir, SOURCES_MODULE_NAME)) new_content = '"""New docstrings"""' new_content_hash = hashlib.sha3_256(bytes(new_content, encoding="ascii")).hexdigest() - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) # modify existing file, no commit mod_file_path = os.path.join("pipedrive", "__init__.py") @@ -211,7 +279,7 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) sources_storage.save(new_file_path, new_content) sources_storage.delete(del_file_path) - source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") + source_files = files_ops.get_verified_source_configuration(sources_storage, "pipedrive") remote_index = files_ops.get_remote_source_index( sources_storage.storage_path, source_files.files, ">=0.3.5" ) @@ -257,7 +325,7 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) mod_file_path_2 = os.path.join("pipedrive", "new_munger_X.py") sources_storage.save(mod_file_path_2, local_content) local_index = 
files_ops.load_verified_sources_local_index("pipedrive") - source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") + source_files = files_ops.get_verified_source_configuration(sources_storage, "pipedrive") remote_index = files_ops.get_remote_source_index( sources_storage.storage_path, source_files.files, ">=0.3.5" ) @@ -300,7 +368,7 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) sources_storage.save(new_file_path, local_content) sources_storage.save(mod_file_path, local_content) project_files.delete(del_file_path) - source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") + source_files = files_ops.get_verified_source_configuration(sources_storage, "pipedrive") remote_index = files_ops.get_remote_source_index( sources_storage.storage_path, source_files.files, ">=0.3.5" ) @@ -313,7 +381,7 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) # generate a conflict by deleting file locally that is modified on remote project_files.delete(mod_file_path) - source_files = files_ops.get_verified_source_files(sources_storage, "pipedrive") + source_files = files_ops.get_verified_source_configuration(sources_storage, "pipedrive") remote_index = files_ops.get_remote_source_index( sources_storage.storage_path, source_files.files, ">=0.3.5" ) @@ -325,7 +393,7 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) -> None: - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) with git.get_repo(repo_dir) as repo: assert git.is_clean_and_synced(repo) is True @@ -341,7 +409,7 @@ def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) assert project_files.has_file(mod_local_path) _, commit = modify_and_commit_file(repo_dir, mod_remote_path, content=new_content) # update without conflict - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) # was file copied assert project_files.load(mod_local_path) == new_content with git.get_repo(repo_dir) as repo: @@ -368,14 +436,14 @@ def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) # repeat the same: no files to update with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) _out = buf.getvalue() assert "No files to update, exiting" in _out # delete file repo_storage = FileStorage(repo_dir) repo_storage.delete(mod_remote_path) - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) # file should be deleted assert not project_files.has_file(mod_local_path) @@ -383,14 +451,14 @@ def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) new_local_path = os.path.join("pipedrive", "__init__X.py") new_remote_path = os.path.join(SOURCES_MODULE_NAME, new_local_path) repo_storage.save(new_remote_path, new_content) - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) # was file copied assert project_files.load(new_local_path) == new_content # deleting the source folder will fully reload project_files.delete_folder("pipedrive", recursively=True) with 
io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) _out = buf.getvalue() # source was added anew assert "was added to your project!" in _out @@ -403,7 +471,7 @@ def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) def test_init_code_update_conflict( repo_dir: str, project_files: FileStorage, resolution: str ) -> None: - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) repo_storage = FileStorage(repo_dir) mod_local_path = os.path.join("pipedrive", "__init__.py") mod_remote_path = os.path.join(SOURCES_MODULE_NAME, mod_local_path) @@ -417,7 +485,7 @@ def test_init_code_update_conflict( with echo.always_choose(False, resolution): with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command("pipedrive", "duckdb", False, repo_dir) + init_command.init_command("pipedrive", "duckdb", repo_dir) _out = buf.getvalue() if resolution == "s": @@ -441,7 +509,7 @@ def test_init_pyproject_toml(repo_dir: str, project_files: FileStorage) -> None: # add pyproject.toml to trigger dependency system project_files.save(cli_utils.PYPROJECT_TOML, "# toml") with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command("google_sheets", "bigquery", False, repo_dir) + init_command.init_command("google_sheets", "bigquery", repo_dir) _out = buf.getvalue() assert "pyproject.toml" in _out assert "google-api-python-client" in _out @@ -452,20 +520,21 @@ def test_init_requirements_text(repo_dir: str, project_files: FileStorage) -> No # add pyproject.toml to trigger dependency system project_files.save(cli_utils.REQUIREMENTS_TXT, "# requirements") with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command("google_sheets", "bigquery", False, repo_dir) + init_command.init_command("google_sheets", "bigquery", repo_dir) _out = buf.getvalue() assert "requirements.txt" in _out assert "google-api-python-client" in _out assert "pip3 install" in _out +@pytest.mark.skip("Why is this not working??") def test_pipeline_template_sources_in_single_file( repo_dir: str, project_files: FileStorage ) -> None: - init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) + init_command.init_command("debug", "bigquery", repo_dir) # _SOURCES now contains the sources from pipeline.py which simulates loading from two places with pytest.raises(CliCommandException) as cli_ex: - init_command.init_command("generic_pipeline", "redshift", True, repo_dir) + init_command.init_command("arrow", "redshift", repo_dir) assert "In init scripts you must declare all sources and resources in single file." 
in str( cli_ex.value ) @@ -474,7 +543,7 @@ def test_pipeline_template_sources_in_single_file( def test_incompatible_dlt_version_warning(repo_dir: str, project_files: FileStorage) -> None: with mock.patch.object(SourceRequirements, "current_dlt_version", return_value="0.1.1"): with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command("facebook_ads", "bigquery", False, repo_dir) + init_command.init_command("facebook_ads", "bigquery", repo_dir) _out = buf.getvalue() assert ( @@ -530,7 +599,7 @@ def assert_source_files( visitor, secrets = assert_common_files( project_files, source_name + "_pipeline.py", destination_name ) - assert project_files.has_folder(source_name) + assert project_files.has_folder(source_name) == (source_name not in [*CORE_SOURCES, *TEMPLATES]) source_secrets = secrets.get_value(source_name, type, None, source_name) if has_source_section: assert source_secrets is not None diff --git a/tests/cli/test_pipeline_command.py b/tests/cli/test_pipeline_command.py index 82d74299f8..664646e2e5 100644 --- a/tests/cli/test_pipeline_command.py +++ b/tests/cli/test_pipeline_command.py @@ -22,7 +22,7 @@ def test_pipeline_command_operations(repo_dir: str, project_files: FileStorage) -> None: - init_command.init_command("chess", "duckdb", False, repo_dir) + init_command.init_command("chess", "duckdb", repo_dir) try: pipeline = dlt.attach(pipeline_name="chess_pipeline") @@ -160,7 +160,7 @@ def test_pipeline_command_operations(repo_dir: str, project_files: FileStorage) def test_pipeline_command_failed_jobs(repo_dir: str, project_files: FileStorage) -> None: - init_command.init_command("chess", "dummy", False, repo_dir) + init_command.init_command("chess", "dummy", repo_dir) try: pipeline = dlt.attach(pipeline_name="chess_pipeline") @@ -195,7 +195,7 @@ def test_pipeline_command_failed_jobs(repo_dir: str, project_files: FileStorage) def test_pipeline_command_drop_partial_loads(repo_dir: str, project_files: FileStorage) -> None: - init_command.init_command("chess", "dummy", False, repo_dir) + init_command.init_command("chess", "dummy", repo_dir) os.environ["EXCEPTION_PROB"] = "1.0" try: diff --git a/tests/cli/utils.py b/tests/cli/utils.py index 56c614e3ae..998885375f 100644 --- a/tests/cli/utils.py +++ b/tests/cli/utils.py @@ -56,7 +56,16 @@ def get_repo_dir(cloned_init_repo: FileStorage) -> str: return repo_dir -def get_project_files() -> FileStorage: - _SOURCES.clear() +def get_project_files(clear_all_sources: bool = True) -> FileStorage: + # we only remove sources registered outside of dlt core + for name, source in _SOURCES.copy().items(): + if not source.module.__name__.startswith( + "dlt.sources" + ) and not source.module.__name__.startswith("default_pipeline"): + _SOURCES.pop(name) + + if clear_all_sources: + _SOURCES.clear() + # project dir return FileStorage(PROJECT_DIR, makedirs=True) diff --git a/tests/common/storages/custom/freshman_kgs.xlsx b/tests/common/storages/custom/freshman_kgs.xlsx new file mode 100644 index 0000000000..2c3d0fbf9a Binary files /dev/null and b/tests/common/storages/custom/freshman_kgs.xlsx differ diff --git a/tests/load/clickhouse/test_clickhouse_configuration.py b/tests/load/clickhouse/test_clickhouse_configuration.py index a4e8abc8dd..2b74922c34 100644 --- a/tests/load/clickhouse/test_clickhouse_configuration.py +++ b/tests/load/clickhouse/test_clickhouse_configuration.py @@ -3,7 +3,7 @@ import pytest from dlt.common.configuration.resolve import resolve_configuration -from dlt.common.libs.sql_alchemy import make_url +from 
dlt.common.libs.sql_alchemy_shims import make_url from dlt.common.utils import digest128 from dlt.destinations.impl.clickhouse.clickhouse import ClickHouseClient from dlt.destinations.impl.clickhouse.configuration import ( diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index 10d93d104c..f692b7ae92 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -8,7 +8,7 @@ pytest.importorskip("snowflake") -from dlt.common.libs.sql_alchemy import make_url +from dlt.common.libs.sql_alchemy_shims import make_url from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.utils import digest128 diff --git a/tests/load/sources/__init__.py b/tests/load/sources/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/sources/filesystem/__init__.py b/tests/load/sources/filesystem/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/sources/filesystem/cases.py b/tests/load/sources/filesystem/cases.py new file mode 100644 index 0000000000..52f49686a9 --- /dev/null +++ b/tests/load/sources/filesystem/cases.py @@ -0,0 +1,69 @@ +import os + +from tests.load.utils import WITH_GDRIVE_BUCKETS + +TESTS_BUCKET_URLS = [ + os.path.join(bucket_url, "standard_source/samples") + for bucket_url in WITH_GDRIVE_BUCKETS + if not bucket_url.startswith("memory") +] + +GLOB_RESULTS = [ + { + "glob": None, + "relative_paths": ["sample.txt"], + }, + { + "glob": "*/*", + "relative_paths": [ + "csv/freshman_kgs.csv", + "csv/freshman_lbs.csv", + "csv/mlb_players.csv", + "csv/mlb_teams_2012.csv", + "gzip/taxi.csv.gz", + "jsonl/mlb_players.jsonl", + "parquet/mlb_players.parquet", + ], + }, + { + "glob": "**/*.csv", + "relative_paths": [ + "csv/freshman_kgs.csv", + "csv/freshman_lbs.csv", + "csv/mlb_players.csv", + "csv/mlb_teams_2012.csv", + "met_csv/A801/A881_20230920.csv", + "met_csv/A803/A803_20230919.csv", + "met_csv/A803/A803_20230920.csv", + ], + }, + { + "glob": "*/*.csv", + "relative_paths": [ + "csv/freshman_kgs.csv", + "csv/freshman_lbs.csv", + "csv/mlb_players.csv", + "csv/mlb_teams_2012.csv", + ], + }, + { + "glob": "csv/*", + "relative_paths": [ + "csv/freshman_kgs.csv", + "csv/freshman_lbs.csv", + "csv/mlb_players.csv", + "csv/mlb_teams_2012.csv", + ], + }, + { + "glob": "csv/mlb*", + "relative_paths": [ + "csv/mlb_players.csv", + "csv/mlb_teams_2012.csv", + ], + }, + { + "glob": "*", + "relative_paths": ["sample.txt"], + }, +] diff --git a/tests/load/sources/filesystem/test_filesystem_source.py b/tests/load/sources/filesystem/test_filesystem_source.py new file mode 100644 index 0000000000..947e7e9e1c --- /dev/null +++ b/tests/load/sources/filesystem/test_filesystem_source.py @@ -0,0 +1,266 @@ +import os +from typing import Any, Dict, List + +import dlt +import pytest +from dlt.common import pendulum + +from dlt.common.storages import fsspec_filesystem +from dlt.sources.filesystem import filesystem, readers, FileItem, FileItemDict, read_csv +from dlt.sources.filesystem.helpers import fsspec_from_resource + +from tests.common.storages.utils import TEST_SAMPLE_FILES +from tests.load.utils import DestinationTestConfiguration, destinations_configs +from tests.pipeline.utils import ( + assert_load_info, + load_table_counts, + assert_query_data, +) +from tests.utils import TEST_STORAGE_ROOT + +from 
tests.load.sources.filesystem.cases import GLOB_RESULTS, TESTS_BUCKET_URLS + + +@pytest.fixture(autouse=True) +def glob_test_setup() -> None: + file_fs, _ = fsspec_filesystem("file") + file_path = os.path.join(TEST_STORAGE_ROOT, "standard_source") + if not file_fs.isdir(file_path): + file_fs.mkdirs(file_path) + file_fs.upload(TEST_SAMPLE_FILES, file_path, recursive=True) + + +@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS) +@pytest.mark.parametrize("glob_params", GLOB_RESULTS) +def test_file_list(bucket_url: str, glob_params: Dict[str, Any]) -> None: + @dlt.transformer + def bypass(items) -> str: + return items + + # we just pass the glob parameter to the resource if it is not None + if file_glob := glob_params["glob"]: + filesystem_res = filesystem(bucket_url=bucket_url, file_glob=file_glob) | bypass + else: + filesystem_res = filesystem(bucket_url=bucket_url) | bypass + + all_files = list(filesystem_res) + file_count = len(all_files) + relative_paths = [item["relative_path"] for item in all_files] + assert file_count == len(glob_params["relative_paths"]) + assert set(relative_paths) == set(glob_params["relative_paths"]) + + +@pytest.mark.parametrize("extract_content", [True, False]) +@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS) +def test_load_content_resources(bucket_url: str, extract_content: bool) -> None: + @dlt.transformer + def assert_sample_content(items: List[FileItemDict]): + # expect just one file + for item in items: + assert item["file_name"] == "sample.txt" + content = item.read_bytes() + assert content == b"dlthub content" + assert item["size_in_bytes"] == 14 + assert item["file_url"].endswith("/samples/sample.txt") + assert item["mime_type"] == "text/plain" + assert isinstance(item["modification_date"], pendulum.DateTime) + + yield items + + # use transformer to test files + sample_file = ( + filesystem( + bucket_url=bucket_url, + file_glob="sample.txt", + extract_content=extract_content, + ) + | assert_sample_content + ) + # just execute iterator + files = list(sample_file) + assert len(files) == 1 + + # take file from nested dir + # use map function to assert + def assert_csv_file(item: FileItem): + # on windows when checking out, git will convert lf into cr+lf so we have more bytes (+ number of lines: 25) + assert item["size_in_bytes"] in (742, 767) + assert item["relative_path"] == "met_csv/A801/A881_20230920.csv" + assert item["file_url"].endswith("/samples/met_csv/A801/A881_20230920.csv") + assert item["mime_type"] == "text/csv" + # print(item) + return item + + nested_file = filesystem(bucket_url, file_glob="met_csv/A801/A881_20230920.csv") + + assert len(list(nested_file | assert_csv_file)) == 1 + + +@pytest.mark.skip("Needs secrets toml to work..") +def test_fsspec_as_credentials(): + # get gs filesystem + gs_resource = filesystem("gs://ci-test-bucket") + # get authenticated client + fs_client = fsspec_from_resource(gs_resource) + print(fs_client.ls("ci-test-bucket/standard_source/samples")) + # use to create resource instead of credentials + gs_resource = filesystem("gs://ci-test-bucket/standard_source/samples", credentials=fs_client) + print(list(gs_resource)) + + +@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) +def test_csv_transformers( + bucket_url: str, destination_config: DestinationTestConfiguration +) -> None: + pipeline = 
destination_config.setup_pipeline("test_csv_transformers", dev_mode=True) + # load all csvs merging data on a date column + met_files = filesystem(bucket_url=bucket_url, file_glob="met_csv/A801/*.csv") | read_csv() + met_files.apply_hints(write_disposition="merge", merge_key="date") + load_info = pipeline.run(met_files.with_name("met_csv")) + assert_load_info(load_info) + + # print(pipeline.last_trace.last_normalize_info) + # must contain 24 rows of A881 + if not destination_config.destination == "filesystem": + # TODO: comment out when filesystem destination supports queries (data pond PR) + assert_query_data(pipeline, "SELECT code FROM met_csv", ["A881"] * 24) + + # load the other folder that contains data for the same day + one other day + # the previous data will be replaced + met_files = filesystem(bucket_url=bucket_url, file_glob="met_csv/A803/*.csv") | read_csv() + met_files.apply_hints(write_disposition="merge", merge_key="date") + load_info = pipeline.run(met_files.with_name("met_csv")) + assert_load_info(load_info) + # print(pipeline.last_trace.last_normalize_info) + # must contain 48 rows of A803 + if not destination_config.destination == "filesystem": + # TODO: comment out when filesystem destination supports queries (data pond PR) + assert_query_data(pipeline, "SELECT code FROM met_csv", ["A803"] * 48) + # and 48 rows in total -> A881 got replaced + # print(pipeline.default_schema.to_pretty_yaml()) + assert load_table_counts(pipeline, "met_csv") == {"met_csv": 48} + + +@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) +def test_standard_readers( + bucket_url: str, destination_config: DestinationTestConfiguration +) -> None: + # extract pipes with standard readers + jsonl_reader = readers(bucket_url, file_glob="**/*.jsonl").read_jsonl() + parquet_reader = readers(bucket_url, file_glob="**/*.parquet").read_parquet() + # also read zipped csvs + csv_reader = readers(bucket_url, file_glob="**/*.csv*").read_csv(float_precision="high") + csv_duckdb_reader = readers(bucket_url, file_glob="**/*.csv*").read_csv_duckdb() + + # a step that copies files into test storage + def _copy(item: FileItemDict): + # instantiate fsspec and copy file + dest_file = os.path.join(TEST_STORAGE_ROOT, item["relative_path"]) + # create dest folder + os.makedirs(os.path.dirname(dest_file), exist_ok=True) + # download file + item.fsspec.download(item["file_url"], dest_file) + # return file item unchanged + return item + + downloader = filesystem(bucket_url, file_glob="**").add_map(_copy) + + # load in single pipeline + pipeline = destination_config.setup_pipeline("test_standard_readers", dev_mode=True) + load_info = pipeline.run( + [ + jsonl_reader.with_name("jsonl_example"), + parquet_reader.with_name("parquet_example"), + downloader.with_name("listing"), + csv_reader.with_name("csv_example"), + csv_duckdb_reader.with_name("csv_duckdb_example"), + ] + ) + # pandas incorrectly guesses that taxi dataset has headers so it skips one row + # so we have 1 less row in csv_example than in csv_duckdb_example + assert_load_info(load_info) + assert load_table_counts( + pipeline, + "jsonl_example", + "parquet_example", + "listing", + "csv_example", + "csv_duckdb_example", + ) == { + "jsonl_example": 1034, + "parquet_example": 1034, + "listing": 11, + "csv_example": 1279, + "csv_duckdb_example": 1281, # TODO: i changed this from 1280, what is going 
on? :) + } + # print(pipeline.last_trace.last_normalize_info) + # print(pipeline.default_schema.to_pretty_yaml()) + + +@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + ids=lambda x: x.name, +) +def test_incremental_load( + bucket_url: str, destination_config: DestinationTestConfiguration +) -> None: + @dlt.transformer + def bypass(items) -> str: + return items + + pipeline = destination_config.setup_pipeline("test_incremental_load", dev_mode=True) + + # Load all files + all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*") + # add incremental on modification time + all_files.apply_hints(incremental=dlt.sources.incremental("modification_date")) + load_info = pipeline.run((all_files | bypass).with_name("csv_files")) + assert_load_info(load_info) + assert pipeline.last_trace.last_normalize_info.row_counts["csv_files"] == 4 + + table_counts = load_table_counts(pipeline, "csv_files") + assert table_counts["csv_files"] == 4 + + # load again + all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*") + all_files.apply_hints(incremental=dlt.sources.incremental("modification_date")) + load_info = pipeline.run((all_files | bypass).with_name("csv_files")) + # nothing into csv_files + assert "csv_files" not in pipeline.last_trace.last_normalize_info.row_counts + table_counts = load_table_counts(pipeline, "csv_files") + assert table_counts["csv_files"] == 4 + + # load again into different table + all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*") + all_files.apply_hints(incremental=dlt.sources.incremental("modification_date")) + load_info = pipeline.run((all_files | bypass).with_name("csv_files_2")) + assert_load_info(load_info) + assert pipeline.last_trace.last_normalize_info.row_counts["csv_files_2"] == 4 + + +def test_file_chunking() -> None: + resource = filesystem( + bucket_url=TESTS_BUCKET_URLS[0], + file_glob="*/*.csv", + files_per_page=2, + ) + + from dlt.extract.pipe_iterator import PipeIterator + + # use pipe iterator to get items as they go through pipe + for pipe_item in PipeIterator.from_pipe(resource._pipe): + assert len(pipe_item.item) == 2 + # no need to test more chunks + break diff --git a/tests/load/sources/rest_api/__init__.py b/tests/load/sources/rest_api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/sources/rest_api/test_rest_api_source.py b/tests/load/sources/rest_api/test_rest_api_source.py new file mode 100644 index 0000000000..25a9952ba4 --- /dev/null +++ b/tests/load/sources/rest_api/test_rest_api_source.py @@ -0,0 +1,128 @@ +from typing import Any +import dlt +import pytest +from dlt.sources.rest_api.typing import RESTAPIConfig +from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator + +from dlt.sources.rest_api import rest_api_source +from tests.pipeline.utils import assert_load_info, load_table_counts +from tests.load.utils import ( + destinations_configs, + DestinationTestConfiguration, +) + + +def _make_pipeline(destination_name: str): + return dlt.pipeline( + pipeline_name="rest_api", + destination=destination_name, + dataset_name="rest_api_data", + full_refresh=True, + ) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, +) +def test_rest_api_source(destination_config: DestinationTestConfiguration, request: Any) -> None: + 
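
As a usage-level counterpart to the filesystem reader tests above, a minimal sketch of the readers helper from dlt.sources.filesystem (the bucket URL, pipeline name and dataset name are illustrative placeholders, not values taken from this change set):

import dlt
from dlt.sources.filesystem import readers

# read every CSV under the sample bucket and load it into a local duckdb database
csv_files = readers("file://_storage/standard_source/samples", file_glob="**/*.csv").read_csv()
pipeline = dlt.pipeline(
    pipeline_name="filesystem_csv_demo",
    destination="duckdb",
    dataset_name="csv_data",
)
load_info = pipeline.run(csv_files.with_name("csv_example"))
print(load_info)
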
config: RESTAPIConfig = { + "client": { + "base_url": "https://pokeapi.co/api/v2/", + }, + "resource_defaults": { + "endpoint": { + "params": { + "limit": 1000, + }, + } + }, + "resources": [ + { + "name": "pokemon_list", + "endpoint": "pokemon", + }, + "berry", + "location", + ], + } + data = rest_api_source(config) + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(data) + assert_load_info(load_info) + table_names = [t["name"] for t in pipeline.default_schema.data_tables()] + table_counts = load_table_counts(pipeline, *table_names) + + assert table_counts.keys() == {"pokemon_list", "berry", "location"} + + assert table_counts["pokemon_list"] == 1302 + assert table_counts["berry"] == 64 + assert table_counts["location"] == 1036 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, +) +def test_dependent_resource(destination_config: DestinationTestConfiguration, request: Any) -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://pokeapi.co/api/v2/", + }, + "resource_defaults": { + "endpoint": { + "params": { + "limit": 1000, + }, + } + }, + "resources": [ + { + "name": "pokemon_list", + "endpoint": { + "path": "pokemon", + "paginator": SinglePagePaginator(), + "data_selector": "results", + "params": { + "limit": 2, + }, + }, + "selected": False, + }, + { + "name": "pokemon", + "endpoint": { + "path": "pokemon/{name}", + "params": { + "name": { + "type": "resolve", + "resource": "pokemon_list", + "field": "name", + }, + }, + }, + }, + ], + } + + data = rest_api_source(config) + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(data) + assert_load_info(load_info) + table_names = [t["name"] for t in pipeline.default_schema.data_tables()] + table_counts = load_table_counts(pipeline, *table_names) + + assert set(table_counts.keys()) == { + "pokemon", + "pokemon__types", + "pokemon__stats", + "pokemon__moves__version_group_details", + "pokemon__moves", + "pokemon__game_indices", + "pokemon__forms", + "pokemon__abilities", + } + + assert table_counts["pokemon"] == 2 diff --git a/tests/load/sources/sql_database/__init__.py b/tests/load/sources/sql_database/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/sources/sql_database/conftest.py b/tests/load/sources/sql_database/conftest.py new file mode 100644 index 0000000000..e70467e714 --- /dev/null +++ b/tests/load/sources/sql_database/conftest.py @@ -0,0 +1,40 @@ +from typing import Iterator, Any + +import pytest + +import dlt +from dlt.sources.credentials import ConnectionStringCredentials + +try: + from tests.load.sources.sql_database.sql_source import SQLAlchemySourceDB +except ModuleNotFoundError: + SQLAlchemySourceDB = Any # type: ignore + + +def _create_db(**kwargs) -> Iterator[SQLAlchemySourceDB]: + # TODO: parametrize the fixture so it takes the credentials for all destinations + credentials = dlt.secrets.get( + "destination.postgres.credentials", expected_type=ConnectionStringCredentials + ) + + db = SQLAlchemySourceDB(credentials, **kwargs) + db.create_schema() + try: + db.create_tables() + db.insert_data() + yield db + finally: + db.drop_schema() + + +@pytest.fixture(scope="package") +def sql_source_db(request: pytest.FixtureRequest) -> Iterator[SQLAlchemySourceDB]: + # Without unsupported types so we can test full schema load with connector-x + yield from 
_create_db(with_unsupported_types=False) + + +@pytest.fixture(scope="package") +def sql_source_db_unsupported_types( + request: pytest.FixtureRequest, +) -> Iterator[SQLAlchemySourceDB]: + yield from _create_db(with_unsupported_types=True) diff --git a/tests/load/sources/sql_database/sql_source.py b/tests/load/sources/sql_database/sql_source.py new file mode 100644 index 0000000000..43ce5406d2 --- /dev/null +++ b/tests/load/sources/sql_database/sql_source.py @@ -0,0 +1,373 @@ +import random +from copy import deepcopy +from typing import Dict, List, TypedDict +from uuid import uuid4 + +import mimesis + + +from sqlalchemy import ( + ARRAY, + BigInteger, + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + MetaData, + Numeric, + SmallInteger, + String, + Table, + Text, + Time, + create_engine, + func, + text, +) + +try: + from sqlalchemy import Uuid # type: ignore[attr-defined] +except ImportError: + # sql alchemy 1.4 + Uuid = String + +from sqlalchemy import ( + schema as sqla_schema, +) + +from sqlalchemy.dialects.postgresql import DATERANGE, JSONB + +from dlt.common.pendulum import pendulum, timedelta +from dlt.common.utils import chunks, uniq_id +from dlt.sources.credentials import ConnectionStringCredentials + + +class SQLAlchemySourceDB: + def __init__( + self, + credentials: ConnectionStringCredentials, + schema: str = None, + with_unsupported_types: bool = False, + ) -> None: + self.credentials = credentials + self.database_url = credentials.to_native_representation() + self.schema = schema or "my_dlt_source" + uniq_id() + self.engine = create_engine(self.database_url) + self.metadata = MetaData(schema=self.schema) + self.table_infos: Dict[str, TableInfo] = {} + self.with_unsupported_types = with_unsupported_types + + def create_schema(self) -> None: + with self.engine.begin() as conn: + conn.execute(sqla_schema.CreateSchema(self.schema, if_not_exists=True)) + + def drop_schema(self) -> None: + with self.engine.begin() as conn: + conn.execute(sqla_schema.DropSchema(self.schema, cascade=True, if_exists=True)) + + def get_table(self, name: str) -> Table: + return self.metadata.tables[f"{self.schema}.{name}"] + + def create_tables(self) -> None: + Table( + "app_user", + self.metadata, + Column("id", Integer(), primary_key=True, autoincrement=True), + Column("email", Text(), nullable=False, unique=True), + Column("display_name", Text(), nullable=False), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + Column( + "updated_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + Table( + "chat_channel", + self.metadata, + Column("id", Integer(), primary_key=True, autoincrement=True), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + Column("name", Text(), nullable=False), + Column("active", Boolean(), nullable=False, server_default=text("true")), + Column( + "updated_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + Table( + "chat_message", + self.metadata, + Column("id", Integer(), primary_key=True, autoincrement=True), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + Column("content", Text(), nullable=False), + Column( + "user_id", + Integer(), + ForeignKey("app_user.id"), + nullable=False, + index=True, + ), + Column( + "channel_id", + Integer(), + ForeignKey("chat_channel.id"), + nullable=False, + index=True, + ), + 
Column( + "updated_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + Table( + "has_composite_key", + self.metadata, + Column("a", Integer(), primary_key=True), + Column("b", Integer(), primary_key=True), + Column("c", Integer(), primary_key=True), + ) + + def _make_precision_table(table_name: str, nullable: bool) -> None: + Table( + table_name, + self.metadata, + Column("int_col", Integer(), nullable=nullable), + Column("bigint_col", BigInteger(), nullable=nullable), + Column("smallint_col", SmallInteger(), nullable=nullable), + Column("numeric_col", Numeric(precision=10, scale=2), nullable=nullable), + Column("numeric_default_col", Numeric(), nullable=nullable), + Column("string_col", String(length=10), nullable=nullable), + Column("string_default_col", String(), nullable=nullable), + Column("datetime_tz_col", DateTime(timezone=True), nullable=nullable), + Column("datetime_ntz_col", DateTime(timezone=False), nullable=nullable), + Column("date_col", Date, nullable=nullable), + Column("time_col", Time, nullable=nullable), + Column("float_col", Float, nullable=nullable), + Column("json_col", JSONB, nullable=nullable), + Column("bool_col", Boolean, nullable=nullable), + ) + + _make_precision_table("has_precision", False) + _make_precision_table("has_precision_nullable", True) + + if self.with_unsupported_types: + Table( + "has_unsupported_types", + self.metadata, + # Column("unsupported_daterange_1", DATERANGE, nullable=False), + Column("supported_text", Text, nullable=False), + Column("supported_int", Integer, nullable=False), + Column("unsupported_array_1", ARRAY(Integer), nullable=False), + # Column("supported_datetime", DateTime(timezone=True), nullable=False), + ) + + self.metadata.create_all(bind=self.engine) + + # Create a view + q = f""" + CREATE VIEW {self.schema}.chat_message_view AS + SELECT + cm.id, + cm.content, + cm.created_at as _created_at, + cm.updated_at as _updated_at, + au.email as user_email, + au.display_name as user_display_name, + cc.name as channel_name, + CAST(NULL as TIMESTAMP) as _null_ts + FROM {self.schema}.chat_message cm + JOIN {self.schema}.app_user au ON cm.user_id = au.id + JOIN {self.schema}.chat_channel cc ON cm.channel_id = cc.id + """ + with self.engine.begin() as conn: + conn.execute(text(q)) + + def _fake_users(self, n: int = 8594) -> List[int]: + person = mimesis.Person() + user_ids: List[int] = [] + table = self.metadata.tables[f"{self.schema}.app_user"] + info = self.table_infos.setdefault( + "app_user", + dict(row_count=0, ids=[], created_at=IncrementingDate(), is_view=False), + ) + dt = info["created_at"] + for chunk in chunks(range(n), 5000): + rows = [ + dict( + email=person.email(unique=True), + display_name=person.name(), + created_at=next(dt), + updated_at=next(dt), + ) + for i in chunk + ] + with self.engine.begin() as conn: + result = conn.execute(table.insert().values(rows).returning(table.c.id)) + user_ids.extend(result.scalars()) + info["row_count"] += n + info["ids"] += user_ids + return user_ids + + def _fake_channels(self, n: int = 500) -> List[int]: + _text = mimesis.Text() + dev = mimesis.Development() + table = self.metadata.tables[f"{self.schema}.chat_channel"] + channel_ids: List[int] = [] + info = self.table_infos.setdefault( + "chat_channel", + dict(row_count=0, ids=[], created_at=IncrementingDate(), is_view=False), + ) + dt = info["created_at"] + for chunk in chunks(range(n), 5000): + rows = [ + dict( + name=" ".join(_text.words()), + active=dev.boolean(), + created_at=next(dt), + 
updated_at=next(dt), + ) + for i in chunk + ] + with self.engine.begin() as conn: + result = conn.execute(table.insert().values(rows).returning(table.c.id)) + channel_ids.extend(result.scalars()) + info["row_count"] += n + info["ids"] += channel_ids + return channel_ids + + def fake_messages(self, n: int = 9402) -> List[int]: + user_ids = self.table_infos["app_user"]["ids"] + channel_ids = self.table_infos["chat_channel"]["ids"] + _text = mimesis.Text() + choice = mimesis.Choice() + table = self.metadata.tables[f"{self.schema}.chat_message"] + message_ids: List[int] = [] + info = self.table_infos.setdefault( + "chat_message", + dict(row_count=0, ids=[], created_at=IncrementingDate(), is_view=False), + ) + dt = info["created_at"] + for chunk in chunks(range(n), 5000): + rows = [ + dict( + content=_text.random.choice(_text.extract(["questions"])), + user_id=choice(user_ids), + channel_id=choice(channel_ids), + created_at=next(dt), + updated_at=next(dt), + ) + for i in chunk + ] + with self.engine.begin() as conn: + result = conn.execute(table.insert().values(rows).returning(table.c.id)) + message_ids.extend(result.scalars()) + info["row_count"] += len(message_ids) + info["ids"].extend(message_ids) + # View is the same number of rows as the table + view_info = deepcopy(info) + view_info["is_view"] = True + view_info = self.table_infos.setdefault("chat_message_view", view_info) + view_info["row_count"] = info["row_count"] + view_info["ids"] = info["ids"] + return message_ids + + def _fake_precision_data(self, table_name: str, n: int = 100, null_n: int = 0) -> None: + table = self.metadata.tables[f"{self.schema}.{table_name}"] + self.table_infos.setdefault(table_name, dict(row_count=n + null_n, is_view=False)) # type: ignore[call-overload] + rows = [ + dict( + int_col=random.randrange(-2147483648, 2147483647), + bigint_col=random.randrange(-9223372036854775808, 9223372036854775807), + smallint_col=random.randrange(-32768, 32767), + numeric_col=random.randrange(-9999999999, 9999999999) / 100, + numeric_default_col=random.randrange(-9999999999, 9999999999) / 100, + string_col=mimesis.Text().word()[:10], + string_default_col=mimesis.Text().word(), + datetime_tz_col=mimesis.Datetime().datetime(timezone="UTC"), + datetime_ntz_col=mimesis.Datetime().datetime(), # no timezone + date_col=mimesis.Datetime().date(), + time_col=mimesis.Datetime().time(), + float_col=random.random(), + json_col='{"data": [1, 2, 3]}', # NOTE: can we do this? 
+ bool_col=random.randint(0, 1) == 1, + ) + for _ in range(n + null_n) + ] + for row in rows[n:]: + # all fields to None + for field in row: + row[field] = None + with self.engine.begin() as conn: + conn.execute(table.insert().values(rows)) + + def _fake_chat_data(self, n: int = 9402) -> None: + self._fake_users() + self._fake_channels() + self.fake_messages() + + def _fake_unsupported_data(self, n: int = 100) -> None: + table = self.metadata.tables[f"{self.schema}.has_unsupported_types"] + self.table_infos.setdefault("has_unsupported_types", dict(row_count=n, is_view=False)) # type: ignore[call-overload] + rows = [ + dict( + # unsupported_daterange_1="[2020-01-01, 2020-09-01]", + supported_text=mimesis.Text().word(), + supported_int=random.randint(0, 100), + unsupported_array_1=[1, 2, 3], + # supported_datetime="2015-08-12T01:25:22.468126+0100", + ) + for _ in range(n) + ] + with self.engine.begin() as conn: + conn.execute(table.insert().values(rows)) + + def insert_data(self) -> None: + self._fake_chat_data() + self._fake_precision_data("has_precision") + self._fake_precision_data("has_precision_nullable", null_n=10) + if self.with_unsupported_types: + self._fake_unsupported_data() + + +class IncrementingDate: + def __init__(self, start_value: pendulum.DateTime = None) -> None: + self.started = False + self.start_value = start_value or pendulum.now() + self.current_value = self.start_value + + def __next__(self) -> pendulum.DateTime: + if not self.started: + self.started = True + return self.current_value + self.current_value += timedelta(seconds=random.randrange(0, 120)) + return self.current_value + + +class TableInfo(TypedDict): + row_count: int + ids: List[int] + created_at: IncrementingDate + is_view: bool diff --git a/tests/load/sources/sql_database/test_helpers.py b/tests/load/sources/sql_database/test_helpers.py new file mode 100644 index 0000000000..cc88fc0080 --- /dev/null +++ b/tests/load/sources/sql_database/test_helpers.py @@ -0,0 +1,173 @@ +import pytest + +import dlt +from dlt.common.typing import TDataItem + + +from dlt.common.exceptions import MissingDependencyException + +try: + from dlt.sources.sql_database.helpers import TableLoader, TableBackend + from dlt.sources.sql_database.schema_types import table_to_columns + from tests.load.sources.sql_database.sql_source import SQLAlchemySourceDB +except MissingDependencyException: + pytest.skip("Tests require sql alchemy", allow_module_level=True) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_cursor_or_unique_column_not_in_table( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + table = sql_source_db.get_table("chat_message") + + with pytest.raises(KeyError): + TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=dlt.sources.incremental("not_a_column"), + ) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_max( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + """Verify query is generated according to incremental settings""" + + class MockIncremental: + last_value = dlt.common.pendulum.now() + last_value_func = max + cursor_path = "created_at" + row_order = "asc" + end_value = None + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = 
loader.make_query() + expected = ( + table.select() + .order_by(table.c.created_at.asc()) + .where(table.c.created_at >= MockIncremental.last_value) + ) + + assert query.compare(expected) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_min( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + class MockIncremental: + last_value = dlt.common.pendulum.now() + last_value_func = min + cursor_path = "created_at" + row_order = "desc" + end_value = None + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + expected = ( + table.select() + .order_by(table.c.created_at.asc()) # `min` func swaps order + .where(table.c.created_at <= MockIncremental.last_value) + ) + + assert query.compare(expected) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_end_value( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + now = dlt.common.pendulum.now() + + class MockIncremental: + last_value = now + last_value_func = min + cursor_path = "created_at" + end_value = now.add(hours=1) + row_order = None + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + expected = ( + table.select() + .where(table.c.created_at <= MockIncremental.last_value) + .where(table.c.created_at > MockIncremental.end_value) + ) + + assert query.compare(expected) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_any_fun( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + class MockIncremental: + last_value = dlt.common.pendulum.now() + last_value_func = lambda x: x[-1] + cursor_path = "created_at" + row_order = "asc" + end_value = dlt.common.pendulum.now() + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + expected = table.select() + + assert query.compare(expected) + + +def mock_json_column(field: str) -> TDataItem: + """""" + import pyarrow as pa + import pandas as pd + + json_mock_str = '{"data": [1, 2, 3]}' + + def _unwrap(table: TDataItem) -> TDataItem: + if isinstance(table, pd.DataFrame): + table[field] = [None if s is None else json_mock_str for s in table[field]] + return table + else: + col_index = table.column_names.index(field) + json_str_array = pa.array([None if s is None else json_mock_str for s in table[field]]) + return table.set_column( + col_index, + pa.field(field, pa.string(), nullable=table.schema.field(field).nullable), + json_str_array, + ) + + return _unwrap diff --git a/tests/load/sources/sql_database/test_sql_database_source.py b/tests/load/sources/sql_database/test_sql_database_source.py new file mode 100644 index 0000000000..58382877ee --- /dev/null +++ b/tests/load/sources/sql_database/test_sql_database_source.py @@ -0,0 +1,1188 @@ +import os +import re +from copy import deepcopy +from datetime import datetime # noqa: I251 +from typing import Any, Callable, cast, List, 
Optional, Set + +import pytest + +import dlt +from dlt.common import json +from dlt.common.configuration.exceptions import ConfigFieldMissingException +from dlt.common.exceptions import MissingDependencyException + +from dlt.common.schema.typing import TColumnSchema, TSortOrder, TTableSchemaColumns +from dlt.common.utils import uniq_id +from dlt.extract.exceptions import ResourceExtractionError + +from dlt.sources import DltResource + +from tests.pipeline.utils import ( + assert_load_info, + assert_schema_on_data, + load_tables_to_dicts, +) +from tests.load.sources.sql_database.test_helpers import mock_json_column +from tests.utils import data_item_length + + +try: + from dlt.sources.sql_database import ( + ReflectionLevel, + TableBackend, + sql_database, + sql_table, + ) + from dlt.sources.sql_database.helpers import unwrap_json_connector_x + from tests.load.sources.sql_database.sql_source import SQLAlchemySourceDB + import sqlalchemy as sa +except MissingDependencyException: + pytest.skip("Tests require sql alchemy", allow_module_level=True) + + +@pytest.fixture(autouse=True) +def dispose_engines(): + yield + import gc + + # will collect and dispose all hanging engines + gc.collect() + + +@pytest.fixture(autouse=True) +def reset_os_environ(): + # Save the current state of os.environ + original_environ = deepcopy(os.environ) + yield + # Restore the original state of os.environ + os.environ.clear() + os.environ.update(original_environ) + + +def make_pipeline(destination_name: str) -> dlt.Pipeline: + return dlt.pipeline( + pipeline_name="sql_database" + uniq_id(), + destination=destination_name, + dataset_name="test_sql_pipeline_" + uniq_id(), + full_refresh=False, + ) + + +def convert_json_to_text(t): + if isinstance(t, sa.JSON): + return sa.Text + return t + + +def default_test_callback( + destination_name: str, backend: TableBackend +) -> Optional[Callable[[sa.types.TypeEngine], sa.types.TypeEngine]]: + if backend == "pyarrow" and destination_name == "bigquery": + return convert_json_to_text + return None + + +def convert_time_to_us(table): + """map transform converting time column to microseconds (ie. 
from nanoseconds)""" + import pyarrow as pa + from pyarrow import compute as pc + + time_ns_column = table["time_col"] + time_us_column = pc.cast(time_ns_column, pa.time64("us"), safe=False) + new_table = table.set_column( + table.column_names.index("time_col"), + "time_col", + time_us_column, + ) + return new_table + + +def test_pass_engine_credentials(sql_source_db: SQLAlchemySourceDB) -> None: + # verify database + database = sql_database( + sql_source_db.engine, schema=sql_source_db.schema, table_names=["chat_message"] + ) + assert len(list(database)) == sql_source_db.table_infos["chat_message"]["row_count"] + + # verify table + table = sql_table(sql_source_db.engine, table="chat_message", schema=sql_source_db.schema) + assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] + + +def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: + # set the credentials per table name + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"] = ( + sql_source_db.engine.url.render_as_string(False) + ) + table = sql_table(table="chat_message", schema=sql_source_db.schema) + assert table.name == "chat_message" + assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] + + with pytest.raises(ConfigFieldMissingException): + sql_table(table="has_composite_key", schema=sql_source_db.schema) + + # set backend + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__BACKEND"] = "pandas" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + # just one frame here + assert len(list(table)) == 1 + + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CHUNK_SIZE"] = "1000" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + # now 10 frames with chunk size of 1000 + assert len(list(table)) == 10 + + # make it fail on cursor + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = "updated_at_x" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + with pytest.raises(ResourceExtractionError) as ext_ex: + len(list(table)) + assert "'updated_at_x'" in str(ext_ex.value) + + +def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None: + # set the credentials per table name + os.environ["SOURCES__SQL_DATABASE__CREDENTIALS"] = sql_source_db.engine.url.render_as_string( + False + ) + # applies to both sql table and sql database + table = sql_table(table="chat_message", schema=sql_source_db.schema) + assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] + database = sql_database(schema=sql_source_db.schema).with_resources("chat_message") + assert len(list(database)) == sql_source_db.table_infos["chat_message"]["row_count"] + + # set backend + os.environ["SOURCES__SQL_DATABASE__BACKEND"] = "pandas" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + # just one frame here + assert len(list(table)) == 1 + database = sql_database(schema=sql_source_db.schema).with_resources("chat_message") + assert len(list(database)) == 1 + + os.environ["SOURCES__SQL_DATABASE__CHUNK_SIZE"] = "1000" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + # now 10 frames with chunk size of 1000 + assert len(list(table)) == 10 + database = sql_database(schema=sql_source_db.schema).with_resources("chat_message") + assert len(list(database)) == 10 + + # make it fail on cursor + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = "updated_at_x" + table = sql_table(table="chat_message", 
schema=sql_source_db.schema) + with pytest.raises(ResourceExtractionError) as ext_ex: + len(list(table)) + assert "'updated_at_x'" in str(ext_ex.value) + with pytest.raises(ResourceExtractionError) as ext_ex: + list(sql_database(schema=sql_source_db.schema).with_resources("chat_message")) + # other resources will be loaded, incremental is selective + assert len(list(sql_database(schema=sql_source_db.schema).with_resources("app_user"))) > 0 + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +@pytest.mark.parametrize("row_order", ["asc", "desc", None]) +@pytest.mark.parametrize("last_value_func", [min, max, lambda x: max(x)]) +def test_load_sql_table_resource_incremental_end_value( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + row_order: TSortOrder, + last_value_func: Any, +) -> None: + start_id = sql_source_db.table_infos["chat_message"]["ids"][0] + end_id = sql_source_db.table_infos["chat_message"]["ids"][-1] // 2 + + if last_value_func is min: + start_id, end_id = end_id, start_id + + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + backend=backend, + incremental=dlt.sources.incremental( + "id", + initial_value=start_id, + end_value=end_id, + row_order=row_order, + last_value_func=last_value_func, + ), + ) + ] + + try: + rows = list(sql_table_source()) + except Exception as exc: + if isinstance(exc.__context__, NotImplementedError): + pytest.skip("Test skipped due to: " + str(exc.__context__)) + raise + # half of the records loaded -1 record. end values is non inclusive + assert data_item_length(rows) == abs(end_id - start_id) + # check first and last id to see if order was applied + if backend == "sqlalchemy": + if row_order == "asc" and last_value_func is max: + assert rows[0]["id"] == start_id + assert rows[-1]["id"] == end_id - 1 # non inclusive + if row_order == "desc" and last_value_func is max: + assert rows[0]["id"] == end_id - 1 # non inclusive + assert rows[-1]["id"] == start_id + if row_order == "asc" and last_value_func is min: + assert rows[0]["id"] == start_id + assert ( + rows[-1]["id"] == end_id + 1 + ) # non inclusive, but + 1 because last value func is min + if row_order == "desc" and last_value_func is min: + assert ( + rows[0]["id"] == end_id + 1 + ) # non inclusive, but + 1 because last value func is min + assert rows[-1]["id"] == start_id + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", (False, True)) +def test_load_sql_table_resource_select_columns( + sql_source_db: SQLAlchemySourceDB, defer_table_reflect: bool, backend: TableBackend +) -> None: + # get chat messages with content column removed + chat_messages = sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + defer_table_reflect=defer_table_reflect, + table_adapter_callback=lambda table: table._columns.remove(table.columns["content"]), # type: ignore[attr-defined] + backend=backend, + ) + pipeline = make_pipeline("duckdb") + load_info = pipeline.run(chat_messages) + assert_load_info(load_info) + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + assert "content" not in pipeline.default_schema.tables["chat_message"]["columns"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", 
(False, True)) +def test_load_sql_table_source_select_columns( + sql_source_db: SQLAlchemySourceDB, defer_table_reflect: bool, backend: TableBackend +) -> None: + mod_tables: Set[str] = set() + + def adapt(table) -> None: + mod_tables.add(table) + if table.name == "chat_message": + table._columns.remove(table.columns["content"]) + + # get chat messages with content column removed + all_tables = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + defer_table_reflect=defer_table_reflect, + table_names=(list(sql_source_db.table_infos.keys()) if defer_table_reflect else None), + table_adapter_callback=adapt, + backend=backend, + ) + pipeline = make_pipeline("duckdb") + load_info = pipeline.run(all_tables) + assert_load_info(load_info) + assert_row_counts(pipeline, sql_source_db) + assert "content" not in pipeline.default_schema.tables["chat_message"]["columns"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("reflection_level", ["full", "full_with_precision"]) +@pytest.mark.parametrize("with_defer", [True, False]) +def test_extract_without_pipeline( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + reflection_level: ReflectionLevel, + with_defer: bool, +) -> None: + # make sure that we can evaluate tables without pipeline + source = sql_database( + credentials=sql_source_db.credentials, + table_names=["has_precision", "app_user", "chat_message", "chat_channel"], + schema=sql_source_db.schema, + reflection_level=reflection_level, + defer_table_reflect=with_defer, + backend=backend, + ) + assert len(list(source)) > 0 + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("reflection_level", ["minimal", "full", "full_with_precision"]) +@pytest.mark.parametrize("with_defer", [False, True]) +@pytest.mark.parametrize("standalone_resource", [True, False]) +def test_reflection_levels( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + reflection_level: ReflectionLevel, + with_defer: bool, + standalone_resource: bool, +) -> None: + """Test all reflection, correct schema is inferred""" + + def prepare_source(): + if standalone_resource: + + @dlt.source + def dummy_source(): + yield sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="has_precision", + backend=backend, + defer_table_reflect=with_defer, + reflection_level=reflection_level, + ) + yield sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="app_user", + backend=backend, + defer_table_reflect=with_defer, + reflection_level=reflection_level, + ) + + return dummy_source() + + return sql_database( + credentials=sql_source_db.credentials, + table_names=["has_precision", "app_user"], + schema=sql_source_db.schema, + reflection_level=reflection_level, + defer_table_reflect=with_defer, + backend=backend, + ) + + source = prepare_source() + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + + schema = pipeline.default_schema + assert "has_precision" in schema.tables + + col_names = [col["name"] for col in schema.tables["has_precision"]["columns"].values()] + expected_col_names = [col["name"] for col in PRECISION_COLUMNS] + + # on sqlalchemy json col is not written to schema if no types are discovered + if backend == "sqlalchemy" and reflection_level == "minimal" and not with_defer: + expected_col_names = [col for col in expected_col_names if col != "json_col"] + + 
assert col_names == expected_col_names
+
+    # Pk col is always reflected
+    pk_col = schema.tables["app_user"]["columns"]["id"]
+    assert pk_col["primary_key"] is True
+
+    if reflection_level == "minimal":
+        resource_cols = source.resources["has_precision"].compute_table_schema()["columns"]
+        schema_cols = pipeline.default_schema.tables["has_precision"]["columns"]
+        # We should have all column names on resource hints after extract but no data type or precision
+        for col, schema_col in zip(resource_cols.values(), schema_cols.values()):
+            assert col.get("data_type") is None
+            assert col.get("precision") is None
+            assert col.get("scale") is None
+            if backend == "sqlalchemy":  # Data types are inferred from pandas/arrow during extract
+                assert schema_col.get("data_type") is None
+
+    pipeline.normalize()
+    # Check with/without precision after normalize
+    schema_cols = pipeline.default_schema.tables["has_precision"]["columns"]
+    if reflection_level == "full":
+        # Columns have data type set
+        assert_no_precision_columns(schema_cols, backend, False)
+
+    elif reflection_level == "full_with_precision":
+        # Columns have data type and precision scale set
+        assert_precision_columns(schema_cols, backend, False)
+
+
+@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
+@pytest.mark.parametrize("standalone_resource", [True, False])
+def test_type_adapter_callback(
+    sql_source_db: SQLAlchemySourceDB, backend: TableBackend, standalone_resource: bool
+) -> None:
+    def conversion_callback(t):
+        if isinstance(t, sa.JSON):
+            return sa.Text
+        elif hasattr(sa, "Double") and isinstance(t, sa.Double):
+            return sa.BIGINT
+        return t
+
+    common_kwargs = dict(
+        credentials=sql_source_db.credentials,
+        schema=sql_source_db.schema,
+        backend=backend,
+        type_adapter_callback=conversion_callback,
+        reflection_level="full",
+    )
+
+    if standalone_resource:
+        source = sql_table(
+            table="has_precision",
+            **common_kwargs,  # type: ignore[arg-type]
+        )
+    else:
+        source = sql_database(  # type: ignore[assignment]
+            table_names=["has_precision"],
+            **common_kwargs,  # type: ignore[arg-type]
+        )
+
+    pipeline = make_pipeline("duckdb")
+    pipeline.extract(source)
+
+    schema = pipeline.default_schema
+    table = schema.tables["has_precision"]
+    assert table["columns"]["json_col"]["data_type"] == "text"
+    # sa.Double exists only on sqlalchemy 2; there the callback maps it to BIGINT, otherwise float stays double
+    assert table["columns"]["float_col"]["data_type"] == (
+        "bigint" if hasattr(sa, "Double") else "double"
+    )
+
+
+@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
+@pytest.mark.parametrize(
+    "table_name,nullable", (("has_precision", False), ("has_precision_nullable", True))
+)
+def test_all_types_with_precision_hints(
+    sql_source_db: SQLAlchemySourceDB,
+    backend: TableBackend,
+    table_name: str,
+    nullable: bool,
+) -> None:
+    source = sql_database(
+        credentials=sql_source_db.credentials,
+        schema=sql_source_db.schema,
+        reflection_level="full_with_precision",
+        backend=backend,
+    )
+
+    pipeline = make_pipeline("duckdb")
+
+    # add JSON unwrap for connectorx
+    if backend == "connectorx":
+        source.resources[table_name].add_map(unwrap_json_connector_x("json_col"))
+    pipeline.extract(source)
+    pipeline.normalize(loader_file_format="parquet")
+    info = pipeline.load()
+    assert_load_info(info)
+
+    schema = pipeline.default_schema
+    table = schema.tables[table_name]
+    assert_precision_columns(table["columns"], backend, nullable)
+    assert_schema_on_data(
+        table,
+        load_tables_to_dicts(pipeline, table_name)[table_name],
+        nullable,
+        backend in ["sqlalchemy", "pyarrow"],
+
) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize( + "table_name,nullable", (("has_precision", False), ("has_precision_nullable", True)) +) +def test_all_types_no_precision_hints( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + table_name: str, + nullable: bool, +) -> None: + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + reflection_level="full", + backend=backend, + ) + + pipeline = make_pipeline("duckdb") + + # add JSON unwrap for connectorx + if backend == "connectorx": + source.resources[table_name].add_map(unwrap_json_connector_x("json_col")) + pipeline.extract(source) + pipeline.normalize(loader_file_format="parquet") + pipeline.load().raise_on_failed_jobs() + + schema = pipeline.default_schema + # print(pipeline.default_schema.to_pretty_yaml()) + table = schema.tables[table_name] + assert_no_precision_columns(table["columns"], backend, nullable) + assert_schema_on_data( + table, + load_tables_to_dicts(pipeline, table_name)[table_name], + nullable, + backend in ["sqlalchemy", "pyarrow"], + ) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_incremental_composite_primary_key_from_table( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, +) -> None: + resource = sql_table( + credentials=sql_source_db.credentials, + table="has_composite_key", + schema=sql_source_db.schema, + backend=backend, + ) + + assert resource.incremental.primary_key == ["a", "b", "c"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("upfront_incremental", (True, False)) +def test_set_primary_key_deferred_incremental( + sql_source_db: SQLAlchemySourceDB, + upfront_incremental: bool, + backend: TableBackend, +) -> None: + # this tests dynamically adds primary key to resource and as consequence to incremental + updated_at = dlt.sources.incremental("updated_at") # type: ignore[var-annotated] + resource = sql_table( + credentials=sql_source_db.credentials, + table="chat_message", + schema=sql_source_db.schema, + defer_table_reflect=True, + incremental=updated_at if upfront_incremental else None, + backend=backend, + ) + + resource.apply_hints(incremental=None if upfront_incremental else updated_at) + + # nothing set for deferred reflect + assert resource.incremental.primary_key is None + + def _assert_incremental(item): + # for all the items, all keys must be present + _r = dlt.current.source().resources[dlt.current.resource_name()] + # assert _r.incremental._incremental is updated_at + if len(item) == 0: + # not yet propagated + assert _r.incremental.primary_key is None + else: + assert _r.incremental.primary_key == ["id"] + assert _r.incremental._incremental.primary_key == ["id"] + assert _r.incremental._incremental._transformers["json"].primary_key == ["id"] + assert _r.incremental._incremental._transformers["arrow"].primary_key == ["id"] + return item + + pipeline = make_pipeline("duckdb") + # must evaluate resource for primary key to be set + pipeline.extract(resource.add_step(_assert_incremental)) # type: ignore[arg-type] + + assert resource.incremental.primary_key == ["id"] + assert resource.incremental._incremental.primary_key == ["id"] + assert resource.incremental._incremental._transformers["json"].primary_key == ["id"] + assert resource.incremental._incremental._transformers["arrow"].primary_key == ["id"] + + +@pytest.mark.parametrize("backend", 
["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_deferred_reflect_in_source( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + source = sql_database( + credentials=sql_source_db.credentials, + table_names=["has_precision", "chat_message"], + schema=sql_source_db.schema, + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + # mock the right json values for backends not supporting it + if backend in ("connectorx", "pandas"): + source.resources["has_precision"].add_map(mock_json_column("json_col")) + + # no columns in both tables + assert source.has_precision.columns == {} + assert source.chat_message.columns == {} + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + # use insert values to convert parquet into INSERT + pipeline.normalize(loader_file_format="insert_values") + pipeline.load().raise_on_failed_jobs() + precision_table = pipeline.default_schema.get_table("has_precision") + assert_precision_columns( + precision_table["columns"], + backend, + nullable=False, + ) + assert_schema_on_data( + precision_table, + load_tables_to_dicts(pipeline, "has_precision")["has_precision"], + True, + backend in ["sqlalchemy", "pyarrow"], + ) + assert len(source.chat_message.columns) > 0 # type: ignore[arg-type] + assert source.chat_message.compute_table_schema()["columns"]["id"]["primary_key"] is True + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_deferred_reflect_no_source_connect(backend: TableBackend) -> None: + source = sql_database( + credentials="mysql+pymysql://test@test/test", + table_names=["has_precision", "chat_message"], + schema="schema", + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + + # no columns in both tables + assert source.has_precision.columns == {} + assert source.chat_message.columns == {} + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_deferred_reflect_in_resource( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + table = sql_table( + credentials=sql_source_db.credentials, + table="has_precision", + schema=sql_source_db.schema, + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + # mock the right json values for backends not supporting it + if backend in ("connectorx", "pandas"): + table.add_map(mock_json_column("json_col")) + + # no columns in both tables + assert table.columns == {} + + pipeline = make_pipeline("duckdb") + pipeline.extract(table) + # use insert values to convert parquet into INSERT + pipeline.normalize(loader_file_format="insert_values") + pipeline.load().raise_on_failed_jobs() + precision_table = pipeline.default_schema.get_table("has_precision") + assert_precision_columns( + precision_table["columns"], + backend, + nullable=False, + ) + assert_schema_on_data( + precision_table, + load_tables_to_dicts(pipeline, "has_precision")["has_precision"], + True, + backend in ["sqlalchemy", "pyarrow"], + ) + + +@pytest.mark.parametrize("backend", ["pyarrow", "pandas", "connectorx"]) +def test_destination_caps_context(sql_source_db: SQLAlchemySourceDB, backend: TableBackend) -> None: + # use athena with timestamp precision == 3 + table = sql_table( + credentials=sql_source_db.credentials, + table="has_precision", + schema=sql_source_db.schema, + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + + # no columns in both 
tables
+    assert table.columns == {}
+
+    pipeline = make_pipeline("athena")
+    pipeline.extract(table)
+    pipeline.normalize()
+    # timestamps are milliseconds
+    columns = pipeline.default_schema.get_table("has_precision")["columns"]
+    assert columns["datetime_tz_col"]["precision"] == columns["datetime_ntz_col"]["precision"] == 3
+    # prevent drop
+    pipeline.destination = None
+
+
+@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
+def test_sql_table_from_view(sql_source_db: SQLAlchemySourceDB, backend: TableBackend) -> None:
+    """A view can be extracted by sql_table without any reflect flags"""
+    table = sql_table(
+        credentials=sql_source_db.credentials,
+        table="chat_message_view",
+        schema=sql_source_db.schema,
+        backend=backend,
+        # use minimal level so we infer types from DATA
+        reflection_level="minimal",
+        incremental=dlt.sources.incremental("_created_at"),
+    )
+
+    pipeline = make_pipeline("duckdb")
+    info = pipeline.run(table)
+    assert_load_info(info)
+
+    assert_row_counts(pipeline, sql_source_db, ["chat_message_view"])
+    assert "content" in pipeline.default_schema.tables["chat_message_view"]["columns"]
+    assert "_created_at" in pipeline.default_schema.tables["chat_message_view"]["columns"]
+    db_data = load_tables_to_dicts(pipeline, "chat_message_view")["chat_message_view"]
+    assert "content" in db_data[0]
+    assert "_created_at" in db_data[0]
+    # make sure the all-NULL column is still reflected in the schema
+    assert "_null_ts" in pipeline.default_schema.tables["chat_message_view"]["columns"]
+
+
+@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
+def test_sql_database_include_views(
+    sql_source_db: SQLAlchemySourceDB, backend: TableBackend
+) -> None:
+    """The include_views flag reflects and extracts views as tables"""
+    source = sql_database(
+        credentials=sql_source_db.credentials,
+        schema=sql_source_db.schema,
+        include_views=True,
+        backend=backend,
+    )
+
+    pipeline = make_pipeline("duckdb")
+    pipeline.run(source)
+
+    assert_row_counts(pipeline, sql_source_db, include_views=True)
+
+
+@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
+def test_sql_database_include_view_in_table_names(
+    sql_source_db: SQLAlchemySourceDB, backend: TableBackend
+) -> None:
+    """Passing a view explicitly in table_names should reflect it, regardless of the include_views flag"""
+    source = sql_database(
+        credentials=sql_source_db.credentials,
+        schema=sql_source_db.schema,
+        table_names=["app_user", "chat_message_view"],
+        include_views=False,
+        backend=backend,
+    )
+
+    pipeline = make_pipeline("duckdb")
+    pipeline.run(source)
+
+    assert_row_counts(pipeline, sql_source_db, ["app_user", "chat_message_view"])
+
+
+@pytest.mark.parametrize("backend", ["pyarrow", "pandas", "sqlalchemy"])
+@pytest.mark.parametrize("standalone_resource", [True, False])
+@pytest.mark.parametrize("reflection_level", ["minimal", "full", "full_with_precision"])
+@pytest.mark.parametrize("type_adapter", [True, False])
+def test_infer_unsupported_types(
+    sql_source_db_unsupported_types: SQLAlchemySourceDB,
+    backend: TableBackend,
+    reflection_level: ReflectionLevel,
+    standalone_resource: bool,
+    type_adapter: bool,
+) -> None:
+    def type_adapter_callback(t):
+        if isinstance(t, sa.ARRAY):
+            return sa.JSON
+        return t
+
+    if backend == "pyarrow" and type_adapter:
+        pytest.skip("Arrow does not support type adapter for arrays")
+
+    common_kwargs = dict(
+        credentials=sql_source_db_unsupported_types.credentials,
+
schema=sql_source_db_unsupported_types.schema, + reflection_level=reflection_level, + backend=backend, + type_adapter_callback=type_adapter_callback if type_adapter else None, + ) + if standalone_resource: + + @dlt.source + def dummy_source(): + yield sql_table( + **common_kwargs, # type: ignore[arg-type] + table="has_unsupported_types", + ) + + source = dummy_source() + source.max_table_nesting = 0 + else: + source = sql_database( + **common_kwargs, # type: ignore[arg-type] + table_names=["has_unsupported_types"], + ) + source.max_table_nesting = 0 + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + + columns = pipeline.default_schema.tables["has_unsupported_types"]["columns"] + + pipeline.normalize() + pipeline.load() + + assert_row_counts(pipeline, sql_source_db_unsupported_types, ["has_unsupported_types"]) + + schema = pipeline.default_schema + assert "has_unsupported_types" in schema.tables + columns = schema.tables["has_unsupported_types"]["columns"] + + rows = load_tables_to_dicts(pipeline, "has_unsupported_types")["has_unsupported_types"] + + if backend == "pyarrow": + # TODO: duckdb writes structs as strings (not json encoded) to json columns + # Just check that it has a value + + assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list) + assert columns["unsupported_array_1"]["data_type"] == "complex" + # Other columns are loaded + assert isinstance(rows[0]["supported_text"], str) + assert isinstance(rows[0]["supported_int"], int) + elif backend == "sqlalchemy": + # sqla value is a dataclass and is inferred as complex + + assert columns["unsupported_array_1"]["data_type"] == "complex" + + elif backend == "pandas": + # pandas parses it as string + if type_adapter and reflection_level != "minimal": + assert columns["unsupported_array_1"]["data_type"] == "complex" + + assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", (False, True)) +def test_sql_database_included_columns( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend, defer_table_reflect: bool +) -> None: + # include only some columns from the table + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCLUDED_COLUMNS"] = json.dumps( + ["id", "created_at"] + ) + + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table_names=["chat_message"], + reflection_level="full", + defer_table_reflect=defer_table_reflect, + backend=backend, + ) + + pipeline = make_pipeline("duckdb") + pipeline.run(source) + + schema = pipeline.default_schema + schema_cols = set( + col + for col in schema.get_table_columns("chat_message", include_incomplete=True) + if not col.startswith("_dlt_") + ) + assert schema_cols == {"id", "created_at"} + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", (False, True)) +def test_sql_table_included_columns( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend, defer_table_reflect: bool +) -> None: + source = sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + reflection_level="full", + defer_table_reflect=defer_table_reflect, + backend=backend, + included_columns=["id", "created_at"], + ) + + pipeline = make_pipeline("duckdb") + pipeline.run(source) + + schema = 
pipeline.default_schema + schema_cols = set( + col + for col in schema.get_table_columns("chat_message", include_incomplete=True) + if not col.startswith("_dlt_") + ) + assert schema_cols == {"id", "created_at"} + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("standalone_resource", [True, False]) +def test_query_adapter_callback( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend, standalone_resource: bool +) -> None: + def query_adapter_callback(query, table): + if table.name == "chat_channel": + # Only select active channels + return query.where(table.c.active.is_(True)) + # Use the original query for other tables + return query + + common_kwargs = dict( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + reflection_level="full", + backend=backend, + query_adapter_callback=query_adapter_callback, + ) + + if standalone_resource: + + @dlt.source + def dummy_source(): + yield sql_table( + **common_kwargs, # type: ignore[arg-type] + table="chat_channel", + ) + + yield sql_table( + **common_kwargs, # type: ignore[arg-type] + table="chat_message", + ) + + source = dummy_source() + else: + source = sql_database( + **common_kwargs, # type: ignore[arg-type] + table_names=["chat_message", "chat_channel"], + ) + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + + pipeline.normalize() + pipeline.load().raise_on_failed_jobs() + + channel_rows = load_tables_to_dicts(pipeline, "chat_channel")["chat_channel"] + assert channel_rows and all(row["active"] for row in channel_rows) + + # unfiltred table loads all rows + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +def assert_row_counts( + pipeline: dlt.Pipeline, + sql_source_db: SQLAlchemySourceDB, + tables: Optional[List[str]] = None, + include_views: bool = False, +) -> None: + with pipeline.sql_client() as c: + if not tables: + tables = [ + tbl_name + for tbl_name, info in sql_source_db.table_infos.items() + if include_views or not info["is_view"] + ] + for table in tables: + info = sql_source_db.table_infos[table] + with c.execute_query(f"SELECT count(*) FROM {table}") as cur: + row = cur.fetchone() + assert row[0] == info["row_count"] + + +def assert_precision_columns( + columns: TTableSchemaColumns, backend: TableBackend, nullable: bool +) -> None: + actual = list(columns.values()) + expected = NULL_PRECISION_COLUMNS if nullable else NOT_NULL_PRECISION_COLUMNS + # always has nullability set and always has hints + expected = cast(List[TColumnSchema], deepcopy(expected)) + if backend == "sqlalchemy": + expected = remove_timestamp_precision(expected) + actual = remove_dlt_columns(actual) + if backend == "pyarrow": + expected = add_default_decimal_precision(expected) + if backend == "pandas": + expected = remove_timestamp_precision(expected, with_timestamps=False) + if backend == "connectorx": + # connector x emits 32 precision which gets merged with sql alchemy schema + del columns["int_col"]["precision"] + assert actual == expected + + +def assert_no_precision_columns( + columns: TTableSchemaColumns, backend: TableBackend, nullable: bool +) -> None: + actual = list(columns.values()) + # we always infer and emit nullability + expected = cast( + List[TColumnSchema], + deepcopy(NULL_NO_PRECISION_COLUMNS if nullable else NOT_NULL_NO_PRECISION_COLUMNS), + ) + if backend == "pyarrow": + expected = cast( + List[TColumnSchema], + deepcopy(NULL_PRECISION_COLUMNS if 
nullable else NOT_NULL_PRECISION_COLUMNS), + ) + # always has nullability set and always has hints + # default precision is not set + expected = remove_default_precision(expected) + expected = add_default_decimal_precision(expected) + elif backend == "sqlalchemy": + # no precision, no nullability, all hints inferred + # remove dlt columns + actual = remove_dlt_columns(actual) + elif backend == "pandas": + # no precision, no nullability, all hints inferred + # pandas destroys decimals + expected = convert_non_pandas_types(expected) + # on one of the timestamps somehow there is timezone info..., we only remove values set to false + # to be sure no bad data is coming in + actual = remove_timezone_info(actual, only_falsy=True) + elif backend == "connectorx": + expected = cast( + List[TColumnSchema], + deepcopy(NULL_PRECISION_COLUMNS if nullable else NOT_NULL_PRECISION_COLUMNS), + ) + expected = convert_connectorx_types(expected) + expected = remove_timezone_info(expected, only_falsy=False) + # on one of the timestamps somehow there is timezone info..., we only remove values set to false + # to be sure no bad data is coming in + actual = remove_timezone_info(actual, only_falsy=True) + + assert actual == expected + + +def convert_non_pandas_types(columns: List[TColumnSchema]) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "timestamp": + column["precision"] = 6 + return columns + + +def remove_dlt_columns(columns: List[TColumnSchema]) -> List[TColumnSchema]: + return [col for col in columns if not col["name"].startswith("_dlt")] + + +def remove_default_precision(columns: List[TColumnSchema]) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "bigint" and column.get("precision") == 32: + del column["precision"] + if column["data_type"] == "text" and column.get("precision"): + del column["precision"] + return remove_timezone_info(columns, only_falsy=False) + + +def remove_timezone_info(columns: List[TColumnSchema], only_falsy: bool) -> List[TColumnSchema]: + for column in columns: + if not only_falsy: + column.pop("timezone", None) + elif column.get("timezone") is False: + column.pop("timezone", None) + return columns + + +def remove_timestamp_precision( + columns: List[TColumnSchema], with_timestamps: bool = True +) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "timestamp" and column["precision"] == 6 and with_timestamps: + del column["precision"] + if column["data_type"] == "time" and column["precision"] == 6: + del column["precision"] + return columns + + +def convert_connectorx_types(columns: List[TColumnSchema]) -> List[TColumnSchema]: + """connector x converts decimals to double, otherwise tries to keep data types and precision + nullability is not kept, string precision is not kept + """ + for column in columns: + if column["data_type"] == "bigint": + if column["name"] == "int_col": + column["precision"] = 32 # only int and bigint in connectorx + if column["data_type"] == "text" and column.get("precision"): + del column["precision"] + return columns + + +def add_default_decimal_precision(columns: List[TColumnSchema]) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "decimal" and not column.get("precision"): + column["precision"] = 38 + column["scale"] = 9 + return columns + + +PRECISION_COLUMNS: List[TColumnSchema] = [ + { + "data_type": "bigint", + "name": "int_col", + }, + { + "data_type": "bigint", + "name": "bigint_col", + }, + { + "data_type": "bigint", + "precision": 32, + 
"name": "smallint_col", + }, + { + "data_type": "decimal", + "precision": 10, + "scale": 2, + "name": "numeric_col", + }, + { + "data_type": "decimal", + "name": "numeric_default_col", + }, + { + "data_type": "text", + "precision": 10, + "name": "string_col", + }, + { + "data_type": "text", + "name": "string_default_col", + }, + {"data_type": "timestamp", "precision": 6, "name": "datetime_tz_col", "timezone": True}, + {"data_type": "timestamp", "precision": 6, "name": "datetime_ntz_col", "timezone": False}, + { + "data_type": "date", + "name": "date_col", + }, + { + "data_type": "time", + "name": "time_col", + "precision": 6, + }, + { + "data_type": "double", + "name": "float_col", + }, + { + "data_type": "complex", + "name": "json_col", + }, + { + "data_type": "bool", + "name": "bool_col", + }, +] + +NOT_NULL_PRECISION_COLUMNS = [{"nullable": False, **column} for column in PRECISION_COLUMNS] +NULL_PRECISION_COLUMNS: List[TColumnSchema] = [ + {"nullable": True, **column} for column in PRECISION_COLUMNS +] + +# but keep decimal precision +NO_PRECISION_COLUMNS: List[TColumnSchema] = [ + ( + {"name": column["name"], "data_type": column["data_type"]} # type: ignore[misc] + if column["data_type"] != "decimal" + else dict(column) + ) + for column in PRECISION_COLUMNS +] + +NOT_NULL_NO_PRECISION_COLUMNS: List[TColumnSchema] = [ + {"nullable": False, **column} for column in NO_PRECISION_COLUMNS +] +NULL_NO_PRECISION_COLUMNS: List[TColumnSchema] = [ + {"nullable": True, **column} for column in NO_PRECISION_COLUMNS +] diff --git a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py new file mode 100644 index 0000000000..7012602b4a --- /dev/null +++ b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py @@ -0,0 +1,347 @@ +import os +from typing import Any, List + +import humanize +import pytest + +import dlt +from dlt.sources import DltResource +from dlt.sources.credentials import ConnectionStringCredentials +from dlt.common.exceptions import MissingDependencyException + +from tests.load.utils import ( + DestinationTestConfiguration, + destinations_configs, +) +from tests.pipeline.utils import ( + assert_load_info, + load_table_counts, +) + +try: + from dlt.sources.sql_database import TableBackend, sql_database, sql_table + from tests.load.sources.sql_database.test_helpers import mock_json_column + from tests.load.sources.sql_database.test_sql_database_source import ( + assert_row_counts, + convert_time_to_us, + default_test_callback, + ) + from tests.load.sources.sql_database.sql_source import SQLAlchemySourceDB + from dlt.common.libs.sql_alchemy import IS_SQL_ALCHEMY_20 +except MissingDependencyException: + pytest.skip("Tests require sql alchemy", allow_module_level=True) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_schema_loads_all_tables( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + backend=backend, + reflection_level="minimal", + type_adapter_callback=default_test_callback(destination_config.destination, 
backend), + ) + + if destination_config.destination == "bigquery" and backend == "connectorx": + # connectorx generates nanoseconds time which bigquery cannot load + source.has_precision.add_map(convert_time_to_us) + source.has_precision_nullable.add_map(convert_time_to_us) + + if backend != "sqlalchemy": + # always use mock json + source.has_precision.add_map(mock_json_column("json_col")) + source.has_precision_nullable.add_map(mock_json_column("json_col")) + + assert "chat_message_view" not in source.resources # Views are not reflected by default + + load_info = pipeline.run(source) + print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at)) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_schema_loads_all_tables_parallel( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + backend=backend, + reflection_level="minimal", + type_adapter_callback=default_test_callback(destination_config.destination, backend), + ).parallelize() + + if destination_config.destination == "bigquery" and backend == "connectorx": + # connectorx generates nanoseconds time which bigquery cannot load + source.has_precision.add_map(convert_time_to_us) + source.has_precision_nullable.add_map(convert_time_to_us) + + if backend != "sqlalchemy": + # always use mock json + source.has_precision.add_map(mock_json_column("json_col")) + source.has_precision_nullable.add_map(mock_json_column("json_col")) + + load_info = pipeline.run(source) + print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at)) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_table_names( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + tables = ["chat_channel", "chat_message"] + load_info = pipeline.run( + sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table_names=tables, + reflection_level="minimal", + backend=backend, + ) + ) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, tables) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_table_incremental( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + """Run pipeline twice. Insert more rows after first run + and ensure only those rows are stored after the second run. 
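+    The "updated_at" cursor is configured via the CURSOR_PATH env var set below.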
+ """ + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = "updated_at" + + if not IS_SQL_ALCHEMY_20 and backend == "connectorx": + pytest.skip("Test will not run on sqlalchemy 1.4 with connectorx") + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + tables = ["chat_message"] + + def make_source(): + return sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table_names=tables, + reflection_level="minimal", + backend=backend, + ) + + load_info = pipeline.run(make_source()) + assert_load_info(load_info) + sql_source_db.fake_messages(n=100) + load_info = pipeline.run(make_source()) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, tables) + + +@pytest.mark.skip(reason="Skipping this test temporarily") +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_mysql_data_load( + destination_config: DestinationTestConfiguration, backend: TableBackend, request: Any +) -> None: + # reflect a database + credentials = ConnectionStringCredentials( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam" + ) + database = sql_database(credentials) + assert "family" in database.resources + + if backend == "connectorx": + # connector-x has different connection string format + backend_kwargs = {"conn": "mysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam"} + else: + backend_kwargs = {} + + # no longer needed: asdecimal used to infer decimal or not + # def _double_as_decimal_adapter(table: sa.Table) -> sa.Table: + # for column in table.columns.values(): + # if isinstance(column.type, sa.Double): + # column.type.asdecimal = False + + # load a single table + family_table = sql_table( + credentials="mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam", + table="family", + backend=backend, + reflection_level="minimal", + backend_kwargs=backend_kwargs, + # table_adapter_callback=_double_as_decimal_adapter, + ) + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(family_table, write_disposition="merge") + assert_load_info(load_info) + counts_1 = load_table_counts(pipeline, "family") + + # load again also with merge + family_table = sql_table( + credentials="mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam", + table="family", + backend=backend, + reflection_level="minimal", + # we also try to remove dialect automatically + backend_kwargs={}, + # table_adapter_callback=_double_as_decimal_adapter, + ) + load_info = pipeline.run(family_table, write_disposition="merge") + assert_load_info(load_info) + counts_2 = load_table_counts(pipeline, "family") + # no duplicates + assert counts_1 == counts_2 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_table_resource_loads_data( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + reflection_level="minimal", + backend=backend, + ) + ] + + 
pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(sql_table_source()) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_load_sql_table_resource_incremental( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + if not IS_SQL_ALCHEMY_20 and backend == "connectorx": + pytest.skip("Test will not run on sqlalchemy 1.4 with connectorx") + + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + incremental=dlt.sources.incremental("updated_at"), + reflection_level="minimal", + backend=backend, + ) + ] + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(sql_table_source()) + assert_load_info(load_info) + sql_source_db.fake_messages(n=100) + load_info = pipeline.run(sql_table_source()) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_load_sql_table_resource_incremental_initial_value( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + if not IS_SQL_ALCHEMY_20 and backend == "connectorx": + pytest.skip("Test will not run on sqlalchemy 1.4 with connectorx") + + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + incremental=dlt.sources.incremental( + "updated_at", + sql_source_db.table_infos["chat_message"]["created_at"].start_value, + ), + reflection_level="minimal", + backend=backend, + ) + ] + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(sql_table_source()) + assert_load_info(load_info) + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 918f9beab9..535d5d28e4 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -2732,7 +2732,7 @@ def assert_imported_file( def test_duckdb_column_invalid_timestamp() -> None: - # DuckDB does not have timestamps with timezone and precision + # DuckDB does not have timestamps with timezone and precision, will default to timezone @dlt.resource( columns={"event_tstamp": {"data_type": "timestamp", "timezone": True, "precision": 3}}, primary_key="event_id", @@ -2741,6 +2741,4 @@ def events(): yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] pipeline = dlt.pipeline(destination="duckdb") - - with pytest.raises((TerminalValueError, PipelineStepFailed)): - pipeline.run(events()) + pipeline.run(events()) diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index dfb5f3f82d..d605fa9893 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -1,4 +1,4 @@ -from 
typing import Any, Dict, List, Callable, Sequence +from typing import Any, Dict, List, Set, Callable, Sequence import pytest import random from os import environ @@ -6,16 +6,16 @@ import dlt from dlt.common import json, sleep -from dlt.common.destination.exceptions import DestinationUndefinedEntity +from dlt.common.data_types import py_type_to_sc_type from dlt.common.pipeline import LoadInfo from dlt.common.schema.utils import get_table_format from dlt.common.typing import DictStrAny from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.destinations.fs_client import FSClientBase -from dlt.pipeline.exceptions import SqlClientNotAvailable -from dlt.common.storages import FileStorage from dlt.destinations.exceptions import DatabaseUndefinedRelation +from dlt.common.schema.typing import TTableSchema + PIPELINE_TEST_CASES_PATH = "./tests/pipeline/cases/" @@ -420,3 +420,68 @@ def assert_query_data( # the second is load id if info: assert row[1] in info.loads_ids + + +def assert_schema_on_data( + table_schema: TTableSchema, + rows: List[Dict[str, Any]], + requires_nulls: bool, + check_complex: bool, +) -> None: + """Asserts that `rows` conform to `table_schema`. Fields and their order must conform to columns. Null values and + python data types are checked. + """ + table_columns = table_schema["columns"] + columns_with_nulls: Set[str] = set() + for row in rows: + # check columns + assert set(table_schema["columns"].keys()) == set(row.keys()) + # check column order + assert list(table_schema["columns"].keys()) == list(row.keys()) + # check data types + for key, value in row.items(): + print(key) + print(value) + if value is None: + assert table_columns[key][ + "nullable" + ], f"column {key} must be nullable: value is None" + # next value. we cannot validate data type + columns_with_nulls.add(key) + continue + expected_dt = table_columns[key]["data_type"] + # allow complex strings + if expected_dt == "complex": + if check_complex: + # NOTE: we expect a dict or a list here. 
simple types of null will fail the test + value = json.loads(value) + else: + # skip checking complex types + continue + actual_dt = py_type_to_sc_type(type(value)) + assert actual_dt == expected_dt + + if requires_nulls: + # make sure that all nullable columns in table received nulls + assert ( + set(col["name"] for col in table_columns.values() if col["nullable"]) + == columns_with_nulls + ), "Some columns didn't receive NULLs which is required" + + +def load_table_distinct_counts( + p: dlt.Pipeline, distinct_column: str, *table_names: str +) -> DictStrAny: + """Returns counts of distinct values for column `distinct_column` for `table_names` as dict""" + with p.sql_client() as c: + query = "\nUNION ALL\n".join( + [ + f"SELECT '{name}' as name, COUNT(DISTINCT {distinct_column}) as c FROM" + f" {c.make_qualified_table_name(name)}" + for name in table_names + ] + ) + + with c.execute_query(query) as cur: + rows = list(cur.fetchall()) + return {r[0]: r[1] for r in rows} diff --git a/tests/sources/conftest.py b/tests/sources/conftest.py new file mode 100644 index 0000000000..89f7cdffed --- /dev/null +++ b/tests/sources/conftest.py @@ -0,0 +1,7 @@ +from tests.utils import ( + preserve_environ, + autouse_test_storage, + patch_home_dir, + wipe_pipeline, + duckdb_pipeline_location, +) diff --git a/tests/sources/filesystem/__init__.py b/tests/sources/filesystem/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/filesystem/test_filesystem_pipeline_template.py b/tests/sources/filesystem/test_filesystem_pipeline_template.py new file mode 100644 index 0000000000..38c51c110c --- /dev/null +++ b/tests/sources/filesystem/test_filesystem_pipeline_template.py @@ -0,0 +1,22 @@ +import pytest + +from tests.common.storages.utils import TEST_SAMPLE_FILES + + +@pytest.mark.parametrize( + "example_name", + ( + "read_custom_file_type_excel", + "stream_and_merge_csv", + "read_csv_with_duckdb", + "read_csv_duckdb_compressed", + "read_parquet_and_jsonl_chunked", + "read_files_incrementally_mtime", + ), +) +def test_all_examples(example_name: str) -> None: + from dlt.sources import filesystem_pipeline + + filesystem_pipeline.TESTS_BUCKET_URL = TEST_SAMPLE_FILES + + getattr(filesystem_pipeline, example_name)() diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index 10dd23877d..d59df3a4bb 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -1,258 +1 @@ -import base64 -from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode - -import pytest -import requests_mock - -from dlt.sources.helpers.rest_client import RESTClient - -from .api_router import APIRouter -from .paginators import PageNumberPaginator, OffsetPaginator, CursorPaginator - - -MOCK_BASE_URL = "https://api.example.com" -DEFAULT_PAGE_SIZE = 5 -DEFAULT_TOTAL_PAGES = 5 -DEFAULT_LIMIT = 10 - - -router = APIRouter(MOCK_BASE_URL) - - -def generate_posts(count=DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES): - return [{"id": i, "title": f"Post {i}"} for i in range(count)] - - -def generate_comments(post_id, count=50): - return [{"id": i, "body": f"Comment {i} for post {post_id}"} for i in range(count)] - - -def get_page_number(qs, key="page", default=1): - return int(qs.get(key, [default])[0]) - - -def create_next_page_url(request, paginator, use_absolute_url=True): - scheme, netloc, path, _, _ = urlsplit(request.url) - query = urlencode(paginator.next_page_url_params) - if use_absolute_url: - return 
urlunsplit([scheme, netloc, path, query, ""]) - else: - return f"{path}?{query}" - - -def paginate_by_page_number( - request, records, records_key="data", use_absolute_url=True, index_base=1 -): - page_number = get_page_number(request.qs, default=index_base) - paginator = PageNumberPaginator(records, page_number, index_base=index_base) - - response = { - records_key: paginator.page_records, - **paginator.metadata, - } - - if paginator.next_page_url_params: - response["next_page"] = create_next_page_url(request, paginator, use_absolute_url) - - return response - - -@pytest.fixture(scope="module") -def mock_api_server(): - with requests_mock.Mocker() as m: - - @router.get(r"/posts(\?page=\d+)?$") - def posts(request, context): - return paginate_by_page_number(request, generate_posts()) - - @router.get(r"/posts_zero_based(\?page=\d+)?$") - def posts_zero_based(request, context): - return paginate_by_page_number(request, generate_posts(), index_base=0) - - @router.get(r"/posts_header_link(\?page=\d+)?$") - def posts_header_link(request, context): - records = generate_posts() - page_number = get_page_number(request.qs) - paginator = PageNumberPaginator(records, page_number) - - response = paginator.page_records - - if paginator.next_page_url_params: - next_page_url = create_next_page_url(request, paginator) - context.headers["Link"] = f'<{next_page_url}>; rel="next"' - - return response - - @router.get(r"/posts_relative_next_url(\?page=\d+)?$") - def posts_relative_next_url(request, context): - return paginate_by_page_number(request, generate_posts(), use_absolute_url=False) - - @router.get(r"/posts_offset_limit(\?offset=\d+&limit=\d+)?$") - def posts_offset_limit(request, context): - records = generate_posts() - offset = int(request.qs.get("offset", [0])[0]) - limit = int(request.qs.get("limit", [DEFAULT_LIMIT])[0]) - paginator = OffsetPaginator(records, offset, limit) - - return { - "data": paginator.page_records, - **paginator.metadata, - } - - @router.get(r"/posts_cursor(\?cursor=\d+)?$") - def posts_cursor(request, context): - records = generate_posts() - cursor = int(request.qs.get("cursor", [0])[0]) - paginator = CursorPaginator(records, cursor) - - return { - "data": paginator.page_records, - **paginator.metadata, - } - - @router.get(r"/posts/(\d+)/comments") - def post_comments(request, context): - post_id = int(request.url.split("/")[-2]) - return paginate_by_page_number(request, generate_comments(post_id)) - - @router.get(r"/posts/\d+$") - def post_detail(request, context): - post_id = request.url.split("/")[-1] - return {"id": post_id, "body": f"Post body {post_id}"} - - @router.get(r"/posts/\d+/some_details_404") - def post_detail_404(request, context): - """Return 404 for post with id > 0. 
Used to test ignoring 404 errors.""" - post_id = int(request.url.split("/")[-2]) - if post_id < 1: - return {"id": post_id, "body": f"Post body {post_id}"} - else: - context.status_code = 404 - return {"error": "Post not found"} - - @router.get(r"/posts_under_a_different_key$") - def posts_with_results_key(request, context): - return paginate_by_page_number(request, generate_posts(), records_key="many-results") - - @router.post(r"/posts/search$") - def search_posts(request, context): - body = request.json() - page_size = body.get("page_size", DEFAULT_PAGE_SIZE) - page_number = body.get("page", 1) - - # Simulate a search with filtering - records = generate_posts() - ids_greater_than = body.get("ids_greater_than", 0) - records = [r for r in records if r["id"] > ids_greater_than] - - total_records = len(records) - total_pages = (total_records + page_size - 1) // page_size - start_index = (page_number - 1) * page_size - end_index = start_index + page_size - records_slice = records[start_index:end_index] - - return { - "data": records_slice, - "next_page": page_number + 1 if page_number < total_pages else None, - } - - @router.get("/protected/posts/basic-auth") - def protected_basic_auth(request, context): - auth = request.headers.get("Authorization") - creds = "user:password" - creds_base64 = base64.b64encode(creds.encode()).decode() - if auth == f"Basic {creds_base64}": - return paginate_by_page_number(request, generate_posts()) - context.status_code = 401 - return {"error": "Unauthorized"} - - @router.get("/protected/posts/bearer-token") - def protected_bearer_token(request, context): - auth = request.headers.get("Authorization") - if auth == "Bearer test-token": - return paginate_by_page_number(request, generate_posts()) - context.status_code = 401 - return {"error": "Unauthorized"} - - @router.get("/protected/posts/bearer-token-plain-text-error") - def protected_bearer_token_plain_text_erorr(request, context): - auth = request.headers.get("Authorization") - if auth == "Bearer test-token": - return paginate_by_page_number(request, generate_posts()) - context.status_code = 401 - return "Unauthorized" - - @router.get("/protected/posts/api-key") - def protected_api_key(request, context): - api_key = request.headers.get("x-api-key") - if api_key == "test-api-key": - return paginate_by_page_number(request, generate_posts()) - context.status_code = 401 - return {"error": "Unauthorized"} - - @router.post("/oauth/token") - def oauth_token(request, context): - if oauth_authorize(request): - return {"access_token": "test-token", "expires_in": 3600} - context.status_code = 401 - return {"error": "Unauthorized"} - - @router.post("/oauth/token-expires-now") - def oauth_token_expires_now(request, context): - if oauth_authorize(request): - return {"access_token": "test-token", "expires_in": 0} - context.status_code = 401 - return {"error": "Unauthorized"} - - @router.post("/auth/refresh") - def refresh_token(request, context): - body = request.json() - if body.get("refresh_token") == "valid-refresh-token": - return {"access_token": "new-valid-token"} - context.status_code = 401 - return {"error": "Invalid refresh token"} - - @router.post("/custom-oauth/token") - def custom_oauth_token(request, context): - qs = parse_qs(request.text) - if ( - qs.get("grant_type")[0] == "account_credentials" - and qs.get("account_id")[0] == "test-account-id" - and request.headers["Authorization"] - == "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==" - ): - return {"access_token": "test-token", "expires_in": 3600} - 
context.status_code = 401 - return {"error": "Unauthorized"} - - router.register_routes(m) - - yield m - - -@pytest.fixture -def rest_client() -> RESTClient: - return RESTClient( - base_url="https://api.example.com", - headers={"Accept": "application/json"}, - ) - - -def oauth_authorize(request): - qs = parse_qs(request.text) - grant_type = qs.get("grant_type")[0] - if "jwt-bearer" in grant_type: - return True - if "client_credentials" in grant_type: - return ( - qs["client_secret"][0] == "test-client-secret" - and qs["client_id"][0] == "test-client-id" - ) - - -def assert_pagination(pages, page_size=DEFAULT_PAGE_SIZE, total_pages=DEFAULT_TOTAL_PAGES): - assert len(pages) == total_pages - for i, page in enumerate(pages): - assert page == [ - {"id": i, "title": f"Post {i}"} for i in range(i * page_size, (i + 1) * page_size) - ] +from tests.sources.rest_api.conftest import * # noqa: F403 diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index af914bf89d..5ec48e2972 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -77,7 +77,7 @@ class TestRESTClient: def test_get_single_resource(self, rest_client): response = rest_client.get("/posts/1") assert response.status_code == 200 - assert response.json() == {"id": "1", "body": "Post body 1"} + assert response.json() == {"id": 1, "body": "Post body 1"} def test_pagination(self, rest_client: RESTClient): pages_iter = rest_client.paginate( @@ -412,7 +412,7 @@ def update_request(self, request): page_generator = rest_client.paginate( path="/posts/search", method="POST", - json={"ids_greater_than": posts_skip - 1}, + json={"ids_greater_than": posts_skip - 1, "page_size": 5, "page_count": 5}, paginator=JSONBodyPageCursorPaginator(), ) result = [post for page in list(page_generator) for post in page] diff --git a/tests/sources/rest_api/__init__.py b/tests/sources/rest_api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/rest_api/configurations/__init__.py b/tests/sources/rest_api/configurations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/rest_api/configurations/source_configs.py b/tests/sources/rest_api/configurations/source_configs.py new file mode 100644 index 0000000000..334bfdd230 --- /dev/null +++ b/tests/sources/rest_api/configurations/source_configs.py @@ -0,0 +1,335 @@ +from collections import namedtuple +from typing import cast, List + +import dlt +from dlt.common.typing import TSecretStrValue +from dlt.common.exceptions import DictValidationException +from dlt.common.configuration.specs import configspec +from dlt.sources.helpers.rest_client.paginators import HeaderLinkPaginator +from dlt.sources.helpers.rest_client.auth import OAuth2AuthBase + +from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator +from dlt.sources.helpers.rest_client.auth import HttpBasicAuth + +from dlt.sources.rest_api.typing import RESTAPIConfig + + +ConfigTest = namedtuple("ConfigTest", ["expected_message", "exception", "config"]) + +INVALID_CONFIGS = [ + ConfigTest( + expected_message="following required fields are missing {'resources'}", + exception=DictValidationException, + config={"client": {"base_url": ""}}, + ), + ConfigTest( + expected_message="following required fields are missing {'client'}", + exception=DictValidationException, + config={"resources": []}, + ), + ConfigTest( + expected_message="In path ./client: following fields 
are unexpected {'invalid_key'}", + exception=DictValidationException, + config={ + "client": { + "base_url": "https://api.example.com", + "invalid_key": "value", + }, + "resources": ["posts"], + }, + ), + ConfigTest( + expected_message="field 'paginator' with value invalid_paginator is not one of:", + exception=DictValidationException, + config={ + "client": { + "base_url": "https://api.example.com", + "paginator": "invalid_paginator", + }, + "resources": ["posts"], + }, + ), + ConfigTest( + expected_message="issuess", + exception=ValueError, + config={ + "client": {"base_url": "https://github.com/api/v2"}, + "resources": [ + "issues", + { + "name": "comments", + "endpoint": { + "path": "issues/{id}/comments", + "params": { + "id": { + "type": "resolve", + "resource": "issuess", + "field": "id", + }, + }, + }, + }, + ], + }, + ), + ConfigTest( + expected_message="{org}/{repo}/issues/", + exception=ValueError, + config={ + "client": {"base_url": "https://github.com/api/v2"}, + "resources": [ + {"name": "issues", "endpoint": {"path": "{org}/{repo}/issues/"}}, + { + "name": "comments", + "endpoint": { + "path": "{org}/{repo}/issues/{id}/comments", + "params": { + "id": { + "type": "resolve", + "resource": "issues", + "field": "id", + }, + }, + }, + }, + ], + }, + ), +] + + +class CustomPaginator(HeaderLinkPaginator): + def __init__(self) -> None: + super().__init__(links_next_key="prev") + + +@configspec +class CustomOAuthAuth(OAuth2AuthBase): + pass + + +VALID_CONFIGS: List[RESTAPIConfig] = [ + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "posts", + { + "name": "post_comments", + "endpoint": { + "path": "posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + }, + }, + }, + }, + ], + }, + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts", + "params": { + "limit": 100, + }, + "paginator": "json_link", + }, + }, + ], + }, + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts", + "params": { + "limit": 1, + }, + "paginator": SinglePagePaginator(), + }, + }, + ], + }, + { + "client": { + "base_url": "https://example.com", + "auth": {"type": "bearer", "token": "X"}, + }, + "resources": ["users"], + }, + { + "client": { + "base_url": "https://example.com", + "auth": {"token": "X"}, + }, + "resources": ["users"], + }, + { + "client": { + "base_url": "https://example.com", + "paginator": CustomPaginator(), + "auth": CustomOAuthAuth(access_token=cast(TSecretStrValue, "X")), + }, + "resource_defaults": { + "table_name": lambda event: event["type"], + "endpoint": { + "paginator": CustomPaginator(), + "params": {"since": dlt.sources.incremental[str]("user_id")}, + }, + }, + "resources": [ + { + "name": "users", + "endpoint": { + "paginator": CustomPaginator(), + "params": {"since": dlt.sources.incremental[str]("user_id")}, + }, + } + ], + }, + { + "client": { + "base_url": "https://example.com", + "paginator": "header_link", + "auth": HttpBasicAuth("my-secret", cast(TSecretStrValue, "")), + }, + "resources": ["users"], + }, + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts", + "params": { + "limit": 100, + "since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-25T11:21:28Z", + }, + }, + "paginator": "json_link", + }, + }, + ], 
+ }, + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts", + "params": { + "limit": 100, + }, + "paginator": "json_link", + "incremental": { + "start_param": "since", + "end_param": "until", + "cursor_path": "updated_at", + "initial_value": "2024-01-25T11:21:28Z", + }, + }, + }, + ], + }, + { + "client": { + "base_url": "https://api.example.com", + "headers": { + "X-Test-Header": "test42", + }, + }, + "resources": [ + "users", + {"name": "users_2"}, + {"name": "users_list", "endpoint": "users_list"}, + ], + }, + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "posts", + { + "name": "post_comments", + "table_name": lambda item: item["type"], + "endpoint": { + "path": "posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + }, + }, + }, + }, + ], + }, + { + "client": {"base_url": "https://github.com/api/v2"}, + "resources": [ + { + "name": "issues", + "endpoint": { + "path": "{org}/{repo}/issues/", + "params": {"org": "dlt-hub", "repo": "dlt"}, + }, + }, + { + "name": "comments", + "endpoint": { + "path": "{org}/{repo}/issues/{id}/comments", + "params": { + "org": "dlt-hub", + "repo": "dlt", + "id": { + "type": "resolve", + "resource": "issues", + "field": "id", + }, + }, + }, + }, + ], + }, +] + + +# NOTE: leaves some parameters as defaults to test if they are set correctly +PAGINATOR_TYPE_CONFIGS = [ + {"type": "auto"}, + {"type": "single_page"}, + {"type": "page_number", "page": 10, "base_page": 1, "total_path": "response.pages"}, + {"type": "offset", "limit": 100, "maximum_offset": 1000}, + {"type": "header_link", "links_next_key": "next_page"}, + {"type": "json_link", "next_url_path": "response.nex_page_link"}, + {"type": "cursor", "cursor_param": "cursor"}, +] + + +# NOTE: leaves some required parameters to inject them from config +AUTH_TYPE_CONFIGS = [ + {"type": "bearer", "token": "token"}, + {"type": "api_key", "location": "cookie"}, + {"type": "http_basic", "username": "username"}, + { + "type": "oauth2_client_credentials", + "access_token_url": "https://example.com/oauth/token", + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "access_token_request_data": {"foo": "bar"}, + "default_token_expiration": 60, + }, +] diff --git a/tests/sources/rest_api/configurations/test_auth_config.py b/tests/sources/rest_api/configurations/test_auth_config.py new file mode 100644 index 0000000000..4c925c05b1 --- /dev/null +++ b/tests/sources/rest_api/configurations/test_auth_config.py @@ -0,0 +1,311 @@ +import re +from typing import Any, Dict, List, Literal, NamedTuple, Optional, Union, cast, get_args + +import pytest +from requests.auth import AuthBase + +import dlt +import dlt.common +import dlt.common.exceptions +import dlt.extract +from dlt.common.configuration import inject_section +from dlt.common.configuration.specs import ConfigSectionContext +from dlt.common.typing import TSecretStrValue +from dlt.common.utils import custom_environ +from dlt.sources.rest_api import ( + _mask_secrets, + rest_api_source, +) +from dlt.sources.rest_api.config_setup import ( + AUTH_MAP, + create_auth, +) +from dlt.sources.rest_api.typing import ( + AuthConfigBase, + AuthType, + AuthTypeConfig, + RESTAPIConfig, +) + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + pass + + +from dlt.sources.helpers.rest_client.auth import ( + APIKeyAuth, + BearerTokenAuth, + 
HttpBasicAuth, + OAuth2ClientCredentials, +) + +from .source_configs import ( + AUTH_TYPE_CONFIGS, +) + + +@pytest.mark.parametrize("auth_type", get_args(AuthType)) +@pytest.mark.parametrize( + "section", ("SOURCES__REST_API__CREDENTIALS", "SOURCES__CREDENTIALS", "CREDENTIALS") +) +def test_auth_shorthands(auth_type: AuthType, section: str) -> None: + # TODO: remove when changes in rest_client/auth.py are released + if auth_type == "oauth2_client_credentials": + pytest.skip("Waiting for release of changes in rest_client/auth.py") + + # mock all required envs + with custom_environ( + { + f"{section}__TOKEN": "token", + f"{section}__API_KEY": "api_key", + f"{section}__USERNAME": "username", + f"{section}__PASSWORD": "password", + # TODO: uncomment when changes in rest_client/auth.py are released + # f"{section}__ACCESS_TOKEN_URL": "https://example.com/oauth/token", + # f"{section}__CLIENT_ID": "a_client_id", + # f"{section}__CLIENT_SECRET": "a_client_secret", + } + ): + # shorthands need to instantiate from config + with inject_section( + ConfigSectionContext(sections=("sources", "rest_api")), merge_existing=False + ): + import os + + print(os.environ) + auth = create_auth(auth_type) + assert isinstance(auth, AUTH_MAP[auth_type]) + if isinstance(auth, BearerTokenAuth): + assert auth.token == "token" + if isinstance(auth, APIKeyAuth): + assert auth.api_key == "api_key" + assert auth.location == "header" + assert auth.name == "Authorization" + if isinstance(auth, HttpBasicAuth): + assert auth.username == "username" + assert auth.password == "password" + # TODO: uncomment when changes in rest_client/auth.py are released + # if isinstance(auth, OAuth2ClientCredentials): + # assert auth.access_token_url == "https://example.com/oauth/token" + # assert auth.client_id == "a_client_id" + # assert auth.client_secret == "a_client_secret" + # assert auth.default_token_expiration == 3600 + + +@pytest.mark.parametrize("auth_type_config", AUTH_TYPE_CONFIGS) +@pytest.mark.parametrize( + "section", ("SOURCES__REST_API__CREDENTIALS", "SOURCES__CREDENTIALS", "CREDENTIALS") +) +def test_auth_type_configs(auth_type_config: AuthTypeConfig, section: str) -> None: + # mock all required envs + with custom_environ( + { + f"{section}__API_KEY": "api_key", + f"{section}__NAME": "session-cookie", + f"{section}__PASSWORD": "password", + } + ): + # shorthands need to instantiate from config + with inject_section( + ConfigSectionContext(sections=("sources", "rest_api")), merge_existing=False + ): + auth = create_auth(auth_type_config) # type: ignore + assert isinstance(auth, AUTH_MAP[auth_type_config["type"]]) + if isinstance(auth, BearerTokenAuth): + # from typed dict + assert auth.token == "token" + if isinstance(auth, APIKeyAuth): + assert auth.location == "cookie" + # injected + assert auth.api_key == "api_key" + assert auth.name == "session-cookie" + if isinstance(auth, HttpBasicAuth): + # typed dict + assert auth.username == "username" + # injected + assert auth.password == "password" + if isinstance(auth, OAuth2ClientCredentials): + assert auth.access_token_url == "https://example.com/oauth/token" + assert auth.client_id == "a_client_id" + assert auth.client_secret == "a_client_secret" + assert auth.default_token_expiration == 60 + + +@pytest.mark.parametrize( + "section", ("SOURCES__REST_API__CREDENTIALS", "SOURCES__CREDENTIALS", "CREDENTIALS") +) +def test_auth_instance_config(section: str) -> None: + auth = APIKeyAuth(location="param", name="token") + with custom_environ( + { + f"{section}__API_KEY": "api_key", + 
f"{section}__NAME": "session-cookie", + } + ): + # shorthands need to instantiate from config + with inject_section( + ConfigSectionContext(sections=("sources", "rest_api")), merge_existing=False + ): + # this also resolved configuration + resolved_auth = create_auth(auth) + assert resolved_auth is auth + # explicit + assert auth.location == "param" + # injected + assert auth.api_key == "api_key" + # config overrides explicit (TODO: reverse) + assert auth.name == "session-cookie" + + +def test_bearer_token_fallback() -> None: + auth = create_auth({"token": "secret"}) + assert isinstance(auth, BearerTokenAuth) + assert auth.token == "secret" + + +def test_error_message_invalid_auth_type() -> None: + with pytest.raises(ValueError) as e: + create_auth("non_existing_method") # type: ignore + assert ( + str(e.value) + == "Invalid authentication: non_existing_method." + " Available options: bearer, api_key, http_basic, oauth2_client_credentials." + ) + + +class AuthConfigTest(NamedTuple): + secret_keys: List[Literal["token", "api_key", "password", "username"]] + config: Union[Dict[str, Any], AuthConfigBase] + masked_secrets: Optional[List[str]] = ["s*****t"] + + +AUTH_CONFIGS = [ + AuthConfigTest( + secret_keys=["token"], + config={ + "type": "bearer", + "token": "sensitive-secret", + }, + ), + AuthConfigTest( + secret_keys=["api_key"], + config={ + "type": "api_key", + "api_key": "sensitive-secret", + }, + ), + AuthConfigTest( + secret_keys=["username", "password"], + config={ + "type": "http_basic", + "username": "sensitive-secret", + "password": "sensitive-secret", + }, + masked_secrets=["s*****t", "s*****t"], + ), + AuthConfigTest( + secret_keys=["username", "password"], + config={ + "type": "http_basic", + "username": "", + "password": "sensitive-secret", + }, + masked_secrets=["*****", "s*****t"], + ), + AuthConfigTest( + secret_keys=["username", "password"], + config={ + "type": "http_basic", + "username": "sensitive-secret", + "password": "", + }, + masked_secrets=["s*****t", "*****"], + ), + AuthConfigTest( + secret_keys=["token"], + config=BearerTokenAuth(token=cast(TSecretStrValue, "sensitive-secret")), + ), + AuthConfigTest( + secret_keys=["api_key"], + config=APIKeyAuth(api_key=cast(TSecretStrValue, "sensitive-secret")), + ), + AuthConfigTest( + secret_keys=["username", "password"], + config=HttpBasicAuth("sensitive-secret", cast(TSecretStrValue, "sensitive-secret")), + masked_secrets=["s*****t", "s*****t"], + ), + AuthConfigTest( + secret_keys=["username", "password"], + config=HttpBasicAuth("sensitive-secret", cast(TSecretStrValue, "")), + masked_secrets=["s*****t", "*****"], + ), + AuthConfigTest( + secret_keys=["username", "password"], + config=HttpBasicAuth("", cast(TSecretStrValue, "sensitive-secret")), + masked_secrets=["*****", "s*****t"], + ), +] + + +@pytest.mark.parametrize("secret_keys, config, masked_secrets", AUTH_CONFIGS) +def test_secret_masking_auth_config(secret_keys, config, masked_secrets): + masked = _mask_secrets(config) + for key, mask in zip(secret_keys, masked_secrets): + assert masked[key] == mask # type: ignore[literal-required] + + +def test_secret_masking_oauth() -> None: + config = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, ""), + client_id=cast(TSecretStrValue, "sensitive-secret"), + client_secret=cast(TSecretStrValue, "sensitive-secret"), + ) + + obj = _mask_secrets(config) + assert "sensitive-secret" not in str(obj) + + # TODO + # assert masked.access_token == "None" + # assert masked.client_id == "s*****t" + # assert 
masked.client_secret == "s*****t" + + +def test_secret_masking_custom_auth() -> None: + class CustomAuthConfigBase(AuthConfigBase): + def __init__(self, token: str = "sensitive-secret"): + self.token = token + + class CustomAuthBase(AuthBase): + def __init__(self, token: str = "sensitive-secret"): + self.token = token + + auth = _mask_secrets(CustomAuthConfigBase()) + assert "s*****t" not in str(auth) + # TODO + # assert auth.token == "s*****t" + + auth_2 = _mask_secrets(CustomAuthBase()) # type: ignore[arg-type] + assert "s*****t" not in str(auth_2) + # TODO + # assert auth_2.token == "s*****t" + + +def test_validation_masks_auth_secrets() -> None: + incorrect_config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + "auth": { # type: ignore[typeddict-item] + "type": "bearer", + "location": "header", + "token": "sensitive-secret", + }, + }, + "resources": ["posts"], + } + with pytest.raises(dlt.common.exceptions.DictValidationException) as e: + rest_api_source(incorrect_config) + assert ( + re.search("sensitive-secret", str(e.value)) is None + ), "unexpectedly printed 'sensitive-secret'" + assert e.match(re.escape("'{'type': 'bearer', 'location': 'header', 'token': 's*****t'}'")) diff --git a/tests/sources/rest_api/configurations/test_configuration.py b/tests/sources/rest_api/configurations/test_configuration.py new file mode 100644 index 0000000000..0167ea1eb8 --- /dev/null +++ b/tests/sources/rest_api/configurations/test_configuration.py @@ -0,0 +1,403 @@ +from copy import copy +from typing import cast +from unittest.mock import patch + +import pytest + +import dlt +import dlt.common +import dlt.common.exceptions +import dlt.extract +from dlt.common.utils import update_dict_nested +from dlt.sources.helpers.rest_client.paginators import ( + HeaderLinkPaginator, + SinglePagePaginator, +) +from dlt.sources.rest_api import ( + rest_api_resources, + rest_api_source, +) +from dlt.sources.rest_api.config_setup import ( + _make_endpoint_resource, + _merge_resource_endpoints, + _setup_single_entity_endpoint, +) +from dlt.sources.rest_api.typing import ( + Endpoint, + EndpointResource, + EndpointResourceBase, + RESTAPIConfig, +) + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + pass + + +from .source_configs import ( + INVALID_CONFIGS, + VALID_CONFIGS, +) + + +@pytest.mark.parametrize("expected_message, exception, invalid_config", INVALID_CONFIGS) +def test_invalid_configurations(expected_message, exception, invalid_config): + with pytest.raises(exception, match=expected_message): + rest_api_source(invalid_config) + + +@pytest.mark.parametrize("valid_config", VALID_CONFIGS) +def test_valid_configurations(valid_config): + rest_api_source(valid_config) + + +@pytest.mark.parametrize("config", VALID_CONFIGS) +def test_configurations_dict_is_not_modified_in_place(config): + # deep clone dicts but do not touch instances of classes so ids still compare + config_copy = update_dict_nested({}, config) + rest_api_source(config) + assert config_copy == config + + +def test_resource_expand() -> None: + # convert str into name / path + assert _make_endpoint_resource("path", {}) == { + "name": "path", + "endpoint": {"path": "path"}, + } + # expand endpoint str into path + assert _make_endpoint_resource({"name": "resource", "endpoint": "path"}, {}) == { + "name": "resource", + "endpoint": {"path": "path"}, + } + # expand name into path with optional endpoint + assert _make_endpoint_resource({"name": "resource"}, {}) == { + "name": 
"resource", + "endpoint": {"path": "resource"}, + } + # endpoint path is optional + assert _make_endpoint_resource({"name": "resource", "endpoint": {}}, {}) == { + "name": "resource", + "endpoint": {"path": "resource"}, + } + + +def test_resource_endpoint_deep_merge() -> None: + # columns deep merged + resource = _make_endpoint_resource( + { + "name": "resources", + "columns": [ + {"name": "col_a", "data_type": "bigint"}, + {"name": "col_b"}, + ], + }, + { + "columns": [ + {"name": "col_a", "data_type": "text", "primary_key": True}, + {"name": "col_c", "data_type": "timestamp", "partition": True}, + ] + }, + ) + assert resource["columns"] == { + # data_type and primary_key merged + "col_a": {"name": "col_a", "data_type": "bigint", "primary_key": True}, + # from defaults + "col_c": {"name": "col_c", "data_type": "timestamp", "partition": True}, + # from resource (partial column moved to the end) + "col_b": {"name": "col_b"}, + } + # json and params deep merged + resource = _make_endpoint_resource( + { + "name": "resources", + "endpoint": { + "json": {"param1": "A", "param2": "B"}, + "params": {"param1": "A", "param2": "B"}, + }, + }, + { + "endpoint": { + "json": {"param1": "X", "param3": "Y"}, + "params": {"param1": "X", "param3": "Y"}, + } + }, + ) + assert resource["endpoint"] == { + "json": {"param1": "A", "param3": "Y", "param2": "B"}, + "params": {"param1": "A", "param3": "Y", "param2": "B"}, + "path": "resources", + } + + +def test_resource_endpoint_shallow_merge() -> None: + # merge paginators and other typed dicts as whole + resource_config: EndpointResource = { + "name": "resources", + "max_table_nesting": 5, + "write_disposition": {"disposition": "merge", "strategy": "scd2"}, + "schema_contract": {"tables": "freeze"}, + "endpoint": { + "paginator": {"type": "cursor", "cursor_param": "cursor"}, + "incremental": {"cursor_path": "$", "start_param": "since"}, + }, + } + + resource = _make_endpoint_resource( + resource_config, + { + "max_table_nesting": 1, + "parallelized": True, + "write_disposition": { + "disposition": "replace", + }, + "schema_contract": {"columns": "freeze"}, + "endpoint": { + "paginator": { + "type": "header_link", + }, + "incremental": { + "cursor_path": "response.id", + "start_param": "since", + "end_param": "before", + }, + }, + }, + ) + # resource should keep all values, just parallel is added + expected_resource = copy(resource_config) + expected_resource["parallelized"] = True + assert resource == expected_resource + + +def test_resource_merge_with_objects() -> None: + paginator = SinglePagePaginator() + incremental = dlt.sources.incremental[int]("id", row_order="asc") + resource = _make_endpoint_resource( + { + "name": "resource", + "endpoint": { + "path": "path/to", + "paginator": paginator, + "params": {"since": incremental}, + }, + }, + { + "table_name": lambda item: item["type"], + "endpoint": { + "paginator": HeaderLinkPaginator(), + "params": {"since": dlt.sources.incremental[int]("id", row_order="desc")}, + }, + }, + ) + # objects are as is, not cloned + assert resource["endpoint"]["paginator"] is paginator # type: ignore[index] + assert resource["endpoint"]["params"]["since"] is incremental # type: ignore[index] + # callable coming from default + assert callable(resource["table_name"]) + + +def test_resource_merge_with_none() -> None: + endpoint_config: EndpointResource = { + "name": "resource", + "endpoint": {"path": "user/{id}", "paginator": None, "data_selector": None}, + } + # None should be able to reset the default + resource = 
_make_endpoint_resource( + endpoint_config, + {"endpoint": {"paginator": SinglePagePaginator(), "data_selector": "data"}}, + ) + # nones will overwrite defaults + assert resource == endpoint_config + + +def test_setup_for_single_item_endpoint() -> None: + # single item should revert to single page validator + endpoint = _setup_single_entity_endpoint({"path": "user/{id}"}) + assert endpoint["data_selector"] == "$" + assert isinstance(endpoint["paginator"], SinglePagePaginator) + + # this is not single page + endpoint = _setup_single_entity_endpoint({"path": "user/{id}/messages"}) + assert "data_selector" not in endpoint + + # simulate using None to remove defaults + endpoint_config: EndpointResource = { + "name": "resource", + "endpoint": {"path": "user/{id}", "paginator": None, "data_selector": None}, + } + # None should be able to reset the default + resource = _make_endpoint_resource( + endpoint_config, + {"endpoint": {"paginator": HeaderLinkPaginator(), "data_selector": "data"}}, + ) + + endpoint = _setup_single_entity_endpoint(cast(Endpoint, resource["endpoint"])) + assert endpoint["data_selector"] == "$" + assert isinstance(endpoint["paginator"], SinglePagePaginator) + + +def test_resource_schema() -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + "users", + { + "name": "user", + "endpoint": { + "path": "user/{id}", + "paginator": None, + "data_selector": None, + "params": { + "id": { + "type": "resolve", + "field": "id", + "resource": "users", + }, + }, + }, + }, + ], + } + resources = rest_api_resources(config) + assert len(resources) == 2 + resource = resources[0] + assert resource.name == "users" + assert resources[1].name == "user" + + +def test_resource_hints_are_passed_to_resource_constructor() -> None: + config: RESTAPIConfig = { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "params": { + "limit": 100, + }, + }, + "table_name": "a_table", + "max_table_nesting": 2, + "write_disposition": "merge", + "columns": {"a_text": {"name": "a_text", "data_type": "text"}}, + "primary_key": "a_pk", + "merge_key": "a_merge_key", + "schema_contract": {"tables": "evolve"}, + "table_format": "iceberg", + "selected": False, + }, + ], + } + + with patch.object(dlt, "resource", wraps=dlt.resource) as mock_resource_constructor: + rest_api_resources(config) + mock_resource_constructor.assert_called_once() + expected_kwargs = { + "table_name": "a_table", + "max_table_nesting": 2, + "write_disposition": "merge", + "columns": {"a_text": {"name": "a_text", "data_type": "text"}}, + "primary_key": "a_pk", + "merge_key": "a_merge_key", + "schema_contract": {"tables": "evolve"}, + "table_format": "iceberg", + "selected": False, + } + for arg in expected_kwargs.items(): + _, kwargs = mock_resource_constructor.call_args_list[0] + assert arg in kwargs.items() + + +def test_resource_defaults_params_get_merged() -> None: + resource_defaults: EndpointResourceBase = { + "primary_key": "id", + "write_disposition": "merge", + "endpoint": { + "params": { + "per_page": 30, + }, + }, + } + + resource: EndpointResource = { + "endpoint": { + "path": "issues", + "params": { + "sort": "updated", + "direction": "desc", + "state": "open", + }, + }, + } + merged_resource = _merge_resource_endpoints(resource_defaults, resource) + assert merged_resource["endpoint"]["params"]["per_page"] == 30 # type: ignore[index] + + +def test_resource_defaults_params_get_overwritten() -> None: + resource_defaults: 
EndpointResourceBase = { + "primary_key": "id", + "write_disposition": "merge", + "endpoint": { + "params": { + "per_page": 30, + }, + }, + } + + resource: EndpointResource = { + "endpoint": { + "path": "issues", + "params": { + "per_page": 50, + "sort": "updated", + }, + }, + } + merged_resource = _merge_resource_endpoints(resource_defaults, resource) + assert merged_resource["endpoint"]["params"]["per_page"] == 50 # type: ignore[index] + + +def test_resource_defaults_params_no_resource_params() -> None: + resource_defaults: EndpointResourceBase = { + "primary_key": "id", + "write_disposition": "merge", + "endpoint": { + "params": { + "per_page": 30, + }, + }, + } + + resource: EndpointResource = { + "endpoint": { + "path": "issues", + }, + } + merged_resource = _merge_resource_endpoints(resource_defaults, resource) + assert merged_resource["endpoint"]["params"]["per_page"] == 30 # type: ignore[index] + + +def test_resource_defaults_no_params() -> None: + resource_defaults: EndpointResourceBase = { + "primary_key": "id", + "write_disposition": "merge", + } + + resource: EndpointResource = { + "endpoint": { + "path": "issues", + "params": { + "per_page": 50, + "sort": "updated", + }, + }, + } + merged_resource = _merge_resource_endpoints(resource_defaults, resource) + assert merged_resource["endpoint"]["params"] == { # type: ignore[index] + "per_page": 50, + "sort": "updated", + } diff --git a/tests/sources/rest_api/configurations/test_custom_auth_config.py b/tests/sources/rest_api/configurations/test_custom_auth_config.py new file mode 100644 index 0000000000..1a5a2e58a3 --- /dev/null +++ b/tests/sources/rest_api/configurations/test_custom_auth_config.py @@ -0,0 +1,79 @@ +from base64 import b64encode +from typing import Any, Dict, cast + +import pytest + +from dlt.sources import rest_api +from dlt.sources.helpers.rest_client.auth import APIKeyAuth, OAuth2ClientCredentials +from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig + + +class CustomOAuth2(OAuth2ClientCredentials): + def build_access_token_request(self) -> Dict[str, Any]: + """Used e.g. 
by Zoom Video Communications, Inc."""
+        authentication: str = b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode()
+        return {
+            "headers": {
+                "Authorization": f"Basic {authentication}",
+                "Content-Type": "application/x-www-form-urlencoded",
+            },
+            "data": self.access_token_request_data,
+        }
+
+
+class TestCustomAuth:
+    @pytest.fixture
+    def custom_auth_config(self) -> AuthConfig:
+        config: AuthConfig = {
+            "type": "custom_oauth_2",  # type: ignore
+            "access_token_url": "https://example.com/oauth/token",
+            "client_id": "test_client_id",
+            "client_secret": "test_client_secret",
+            "access_token_request_data": {
+                "grant_type": "account_credentials",
+                "account_id": "test_account_id",
+            },
+        }
+        return config
+
+    def test_creates_builtin_auth_without_registering(self) -> None:
+        config: ApiKeyAuthConfig = {
+            "type": "api_key",
+            "api_key": "test-secret",
+            "location": "header",
+        }
+        auth = cast(APIKeyAuth, rest_api.config_setup.create_auth(config))
+        assert auth.api_key == "test-secret"
+
+    def test_not_registering_throws_error(self, custom_auth_config: AuthConfig) -> None:
+        with pytest.raises(ValueError) as e:
+            rest_api.config_setup.create_auth(custom_auth_config)
+
+        assert e.match("Invalid authentication: custom_oauth_2.")
+
+    def test_registering_adds_to_AUTH_MAP(self, custom_auth_config: AuthConfig) -> None:
+        rest_api.config_setup.register_auth("custom_oauth_2", CustomOAuth2)
+        cls = rest_api.config_setup.get_auth_class("custom_oauth_2")
+        assert cls is CustomOAuth2
+
+        # teardown test
+        del rest_api.config_setup.AUTH_MAP["custom_oauth_2"]
+
+    def test_registering_allows_usage(self, custom_auth_config: AuthConfig) -> None:
+        rest_api.config_setup.register_auth("custom_oauth_2", CustomOAuth2)
+        auth = cast(CustomOAuth2, rest_api.config_setup.create_auth(custom_auth_config))
+        request = auth.build_access_token_request()
+        assert request["data"]["account_id"] == "test_account_id"
+
+        # teardown test
+        del rest_api.config_setup.AUTH_MAP["custom_oauth_2"]
+
+    def test_registering_not_auth_config_base_throws_error(self) -> None:
+        class NotAuthConfigBase:
+            pass
+
+        with pytest.raises(ValueError) as e:
+            rest_api.config_setup.register_auth(
+                "not_an_auth_config_base", NotAuthConfigBase  # type: ignore
+            )
+        assert e.match("Invalid auth: NotAuthConfigBase.")
diff --git a/tests/sources/rest_api/configurations/test_custom_paginator_config.py b/tests/sources/rest_api/configurations/test_custom_paginator_config.py
new file mode 100644
index 0000000000..f8ac060218
--- /dev/null
+++ b/tests/sources/rest_api/configurations/test_custom_paginator_config.py
@@ -0,0 +1,69 @@
+from typing import cast
+
+import pytest
+
+from dlt.sources import rest_api
+from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator
+from dlt.sources.rest_api.typing import PaginatorConfig
+
+
+class CustomPaginator(JSONLinkPaginator):
+    """A paginator that uses a specific key in the JSON response to find
+    the next page URL.
+ """ + + def __init__( + self, + next_url_path="$['@odata.nextLink']", + ): + super().__init__(next_url_path=next_url_path) + + +class TestCustomPaginator: + @pytest.fixture + def custom_paginator_config(self) -> PaginatorConfig: + config: PaginatorConfig = { + "type": "custom_paginator", # type: ignore + "next_url_path": "response.next_page_link", + } + return config + + def teardown_method(self, method): + try: + del rest_api.config_setup.PAGINATOR_MAP["custom_paginator"] + except KeyError: + pass + + def test_creates_builtin_paginator_without_registering(self) -> None: + config: PaginatorConfig = { + "type": "json_response", + "next_url_path": "response.next_page_link", + } + paginator = rest_api.config_setup.create_paginator(config) + assert paginator.has_next_page is True + + def test_not_registering_throws_error(self, custom_paginator_config) -> None: + with pytest.raises(ValueError) as e: + rest_api.config_setup.create_paginator(custom_paginator_config) + + assert e.match("Invalid paginator: custom_paginator.") + + def test_registering_adds_to_PAGINATOR_MAP(self, custom_paginator_config) -> None: + rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator) + cls = rest_api.config_setup.get_paginator_class("custom_paginator") + assert cls is CustomPaginator + + def test_registering_allows_usage(self, custom_paginator_config) -> None: + rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator) + paginator = rest_api.config_setup.create_paginator(custom_paginator_config) + paginator = cast(CustomPaginator, paginator) + assert paginator.has_next_page is True + assert str(paginator.next_url_path) == "response.next_page_link" + + def test_registering_not_base_paginator_throws_error(self) -> None: + class NotAPaginator: + pass + + with pytest.raises(ValueError) as e: + rest_api.config_setup.register_paginator("not_a_paginator", NotAPaginator) # type: ignore[arg-type] + assert e.match("Invalid paginator: NotAPaginator.") diff --git a/tests/sources/rest_api/configurations/test_incremental_config.py b/tests/sources/rest_api/configurations/test_incremental_config.py new file mode 100644 index 0000000000..a374b644df --- /dev/null +++ b/tests/sources/rest_api/configurations/test_incremental_config.py @@ -0,0 +1,352 @@ +import re +import dlt.common +import dlt.common.exceptions +from dlt.common import pendulum + +import dlt.extract +import pytest +from typing import cast + + +import dlt + +from dlt.extract.incremental import Incremental + +from dlt.sources.rest_api import ( + _validate_param_type, + _set_incremental_params, +) + +from dlt.sources.rest_api.config_setup import ( + IncrementalParam, + setup_incremental_object, +) +from dlt.sources.rest_api.typing import ( + IncrementalConfig, +) + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + pass + + +@pytest.fixture() +def incremental_with_init_and_end() -> Incremental[str]: + return dlt.sources.incremental( + cursor_path="updated_at", + initial_value="2024-01-01T00:00:00Z", + end_value="2024-06-30T00:00:00Z", + ) + + +@pytest.fixture() +def incremental_with_init() -> Incremental[str]: + return dlt.sources.incremental( + cursor_path="updated_at", + initial_value="2024-01-01T00:00:00Z", + ) + + +def test_invalid_incremental_type_is_not_accepted() -> None: + request_params = { + "foo": "bar", + "since": { + "type": "no_incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + }, + } + with pytest.raises(ValueError) as e: 
+ _validate_param_type(request_params) + + assert e.match("Invalid param type: no_incremental.") + + +def test_one_resource_cannot_have_many_incrementals() -> None: + request_params = { + "foo": "bar", + "first_incremental": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + }, + "second_incremental": { + "type": "incremental", + "cursor_path": "created_at", + "initial_value": "2024-01-01T00:00:00Z", + }, + } + with pytest.raises(ValueError) as e: + setup_incremental_object(request_params) + error_message = re.escape( + "Only a single incremental parameter is allower per endpoint. Found: ['first_incremental'," + " 'second_incremental']" + ) + assert e.match(error_message) + + +def test_one_resource_cannot_have_many_incrementals_2(incremental_with_init) -> None: + request_params = { + "foo": "bar", + "first_incremental": { + "type": "incremental", + "cursor_path": "created_at", + "initial_value": "2024-02-02T00:00:00Z", + }, + "second_incremental": incremental_with_init, + } + with pytest.raises(ValueError) as e: + setup_incremental_object(request_params) + error_message = re.escape( + "Only a single incremental parameter is allower per endpoint. Found: ['first_incremental'," + " 'second_incremental']" + ) + assert e.match(error_message) + + +def test_constructs_incremental_from_request_param() -> None: + request_params = { + "foo": "bar", + "since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + }, + } + (incremental_config, incremental_param, _) = setup_incremental_object(request_params) + assert incremental_config == dlt.sources.incremental( + cursor_path="updated_at", initial_value="2024-01-01T00:00:00Z" + ) + assert incremental_param == IncrementalParam(start="since", end=None) + + +def test_constructs_incremental_from_request_param_with_incremental_object( + incremental_with_init, +) -> None: + request_params = { + "foo": "bar", + "since": dlt.sources.incremental( + cursor_path="updated_at", initial_value="2024-01-01T00:00:00Z" + ), + } + (incremental_obj, incremental_param, _) = setup_incremental_object(request_params) + assert incremental_param == IncrementalParam(start="since", end=None) + + assert incremental_with_init == incremental_obj + + +def test_constructs_incremental_from_request_param_with_convert( + incremental_with_init, +) -> None: + def epoch_to_datetime(epoch: str): + return pendulum.from_timestamp(int(epoch)) + + param_config = { + "since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + "convert": epoch_to_datetime, + } + } + + (incremental_obj, incremental_param, convert) = setup_incremental_object(param_config, None) + assert incremental_param == IncrementalParam(start="since", end=None) + assert convert == epoch_to_datetime + + assert incremental_with_init == incremental_obj + + +def test_does_not_construct_incremental_from_request_param_with_unsupported_incremental( + incremental_with_init_and_end, +) -> None: + param_config = { + "since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + "end_value": "2024-06-30T00:00:00Z", # This is ignored + } + } + + with pytest.raises(ValueError) as e: + setup_incremental_object(param_config) + + assert e.match( + "Only start_param and initial_value are allowed in the configuration of param: since." 
+ ) + + param_config_2 = { + "since_2": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + "end_param": "2024-06-30T00:00:00Z", # This is ignored + } + } + + with pytest.raises(ValueError) as e: + setup_incremental_object(param_config_2) + + assert e.match( + "Only start_param and initial_value are allowed in the configuration of param: since_2." + ) + + param_config_3 = {"since_3": incremental_with_init_and_end} + + with pytest.raises(ValueError) as e: + setup_incremental_object(param_config_3) + + assert e.match("Only initial_value is allowed in the configuration of param: since_3.") + + +def test_constructs_incremental_from_endpoint_config_incremental( + incremental_with_init, +) -> None: + config = { + "incremental": { + "start_param": "since", + "end_param": "until", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + } + } + incremental_config = cast(IncrementalConfig, config.get("incremental")) + (incremental_obj, incremental_param, _) = setup_incremental_object( + {}, + incremental_config, + ) + assert incremental_param == IncrementalParam(start="since", end="until") + + assert incremental_with_init == incremental_obj + + +def test_constructs_incremental_from_endpoint_config_incremental_with_convert( + incremental_with_init_and_end, +) -> None: + def epoch_to_datetime(epoch): + return pendulum.from_timestamp(int(epoch)) + + resource_config_incremental: IncrementalConfig = { + "start_param": "since", + "end_param": "until", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + "end_value": "2024-06-30T00:00:00Z", + "convert": epoch_to_datetime, + } + + (incremental_obj, incremental_param, convert) = setup_incremental_object( + {}, resource_config_incremental + ) + assert incremental_param == IncrementalParam(start="since", end="until") + assert convert == epoch_to_datetime + assert incremental_with_init_and_end == incremental_obj + + +def test_calls_convert_from_endpoint_config_incremental(mocker) -> None: + def epoch_to_date(epoch: str): + return pendulum.from_timestamp(int(epoch)).to_date_string() + + callback = mocker.Mock(side_effect=epoch_to_date) + incremental_obj = mocker.Mock() + incremental_obj.last_value = "1" + + incremental_param = IncrementalParam(start="since", end=None) + created_param = _set_incremental_params({}, incremental_obj, incremental_param, callback) + assert created_param == {"since": "1970-01-01"} + assert callback.call_args_list[0].args == ("1",) + + +def test_calls_convert_from_request_param(mocker) -> None: + def epoch_to_datetime(epoch: str): + return pendulum.from_timestamp(int(epoch)).to_date_string() + + callback = mocker.Mock(side_effect=epoch_to_datetime) + start = 1 + one_day_later = 60 * 60 * 24 + incremental_config: IncrementalConfig = { + "start_param": "since", + "end_param": "until", + "cursor_path": "updated_at", + "initial_value": str(start), + "end_value": str(one_day_later), + "convert": callback, + } + + (incremental_obj, incremental_param, _) = setup_incremental_object({}, incremental_config) + assert incremental_param is not None + assert incremental_obj is not None + created_param = _set_incremental_params({}, incremental_obj, incremental_param, callback) + assert created_param == {"since": "1970-01-01", "until": "1970-01-02"} + assert callback.call_args_list[0].args == (str(start),) + assert callback.call_args_list[1].args == (str(one_day_later),) + + +def test_default_convert_is_identity() -> None: + start = 1 + one_day_later = 60 * 60 
* 24 + incremental_config: IncrementalConfig = { + "start_param": "since", + "end_param": "until", + "cursor_path": "updated_at", + "initial_value": str(start), + "end_value": str(one_day_later), + } + + (incremental_obj, incremental_param, _) = setup_incremental_object({}, incremental_config) + assert incremental_param is not None + assert incremental_obj is not None + created_param = _set_incremental_params({}, incremental_obj, incremental_param, None) + assert created_param == {"since": str(start), "until": str(one_day_later)} + + +def test_incremental_param_transform_is_deprecated(incremental_with_init) -> None: + """Tests that deprecated interface works but issues deprecation warning""" + + def epoch_to_datetime(epoch: str): + return pendulum.from_timestamp(int(epoch)) + + param_config = { + "since": { + "type": "incremental", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + "transform": epoch_to_datetime, + } + } + + with pytest.deprecated_call(): + (incremental_obj, incremental_param, convert) = setup_incremental_object(param_config, None) + + assert incremental_param == IncrementalParam(start="since", end=None) + assert convert == epoch_to_datetime + + assert incremental_with_init == incremental_obj + + +def test_incremental_endpoint_config_transform_is_deprecated( + incremental_with_init_and_end, +) -> None: + """Tests that deprecated interface works but issues deprecation warning""" + + def epoch_to_datetime(epoch): + return pendulum.from_timestamp(int(epoch)) + + resource_config_incremental: IncrementalConfig = { + "start_param": "since", + "end_param": "until", + "cursor_path": "updated_at", + "initial_value": "2024-01-01T00:00:00Z", + "end_value": "2024-06-30T00:00:00Z", + "transform": epoch_to_datetime, # type: ignore[typeddict-unknown-key] + } + + with pytest.deprecated_call(): + (incremental_obj, incremental_param, convert) = setup_incremental_object( + {}, resource_config_incremental + ) + assert incremental_param == IncrementalParam(start="since", end="until") + assert convert == epoch_to_datetime + assert incremental_with_init_and_end == incremental_obj diff --git a/tests/sources/rest_api/configurations/test_paginator_config.py b/tests/sources/rest_api/configurations/test_paginator_config.py new file mode 100644 index 0000000000..6513daf15c --- /dev/null +++ b/tests/sources/rest_api/configurations/test_paginator_config.py @@ -0,0 +1,161 @@ +from typing import get_args + +import pytest + +import dlt +import dlt.common +import dlt.common.exceptions +import dlt.extract +from dlt.common.jsonpath import compile_path +from dlt.sources.helpers.rest_client.paginators import ( + HeaderLinkPaginator, + JSONResponseCursorPaginator, + JSONResponsePaginator, + OffsetPaginator, + PageNumberPaginator, +) +from dlt.sources.rest_api import ( + rest_api_source, +) +from dlt.sources.rest_api.config_setup import ( + PAGINATOR_MAP, + create_paginator, +) +from dlt.sources.rest_api.typing import ( + PaginatorConfig, + PaginatorType, + RESTAPIConfig, +) + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + from dlt.sources.helpers.rest_client.paginators import ( + JSONResponsePaginator as JSONLinkPaginator, + ) + + +from .source_configs import ( + PAGINATOR_TYPE_CONFIGS, +) + + +@pytest.mark.parametrize("paginator_type", get_args(PaginatorType)) +def test_paginator_shorthands(paginator_type: PaginatorConfig) -> None: + try: + create_paginator(paginator_type) + except ValueError as v_ex: + # offset paginator cannot 
be instantiated + assert paginator_type == "offset" + assert "offset" in str(v_ex) + + +@pytest.mark.parametrize("paginator_type_config", PAGINATOR_TYPE_CONFIGS) +def test_paginator_type_configs(paginator_type_config: PaginatorConfig) -> None: + paginator = create_paginator(paginator_type_config) + if paginator_type_config["type"] == "auto": # type: ignore[index] + assert paginator is None + else: + # assert types and default params + assert isinstance(paginator, PAGINATOR_MAP[paginator_type_config["type"]]) # type: ignore[index] + # check if params are bound + if isinstance(paginator, HeaderLinkPaginator): + assert paginator.links_next_key == "next_page" + if isinstance(paginator, PageNumberPaginator): + assert paginator.current_value == 10 + assert paginator.base_index == 1 + assert paginator.param_name == "page" + assert paginator.total_path == compile_path("response.pages") + assert paginator.maximum_value is None + if isinstance(paginator, OffsetPaginator): + assert paginator.current_value == 0 + assert paginator.param_name == "offset" + assert paginator.limit == 100 + assert paginator.limit_param == "limit" + assert paginator.total_path == compile_path("total") + assert paginator.maximum_value == 1000 + if isinstance(paginator, JSONLinkPaginator): + assert paginator.next_url_path == compile_path("response.nex_page_link") + if isinstance(paginator, JSONResponseCursorPaginator): + assert paginator.cursor_path == compile_path("cursors.next") + assert paginator.cursor_param == "cursor" + + +def test_paginator_instance_config() -> None: + paginator = OffsetPaginator(limit=100) + assert create_paginator(paginator) is paginator + + +def test_page_number_paginator_creation() -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + "paginator": { + "type": "page_number", + "page_param": "foobar", + "total_path": "response.pages", + "base_page": 1, + "maximum_page": 5, + }, + }, + "resources": ["posts"], + } + try: + rest_api_source(config) + except dlt.common.exceptions.DictValidationException: + pytest.fail("DictValidationException was unexpectedly raised") + + +def test_allow_deprecated_json_response_paginator(mock_api_server) -> None: + """ + Delete this test as soon as we stop supporting the deprecated key json_response + for the JSONLinkPaginator + """ + config: RESTAPIConfig = { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts", + "paginator": { + "type": "json_response", + "next_url_path": "links.next", + }, + }, + }, + ], + } + + rest_api_source(config) + + +def test_allow_deprecated_json_response_paginator_2(mock_api_server) -> None: + """ + Delete this test as soon as we stop supporting the deprecated key json_response + for the JSONLinkPaginator + """ + config: RESTAPIConfig = { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts", + "paginator": JSONResponsePaginator(next_url_path="links.next"), + }, + }, + ], + } + + rest_api_source(config) + + +def test_error_message_invalid_paginator() -> None: + with pytest.raises(ValueError) as e: + create_paginator("non_existing_method") # type: ignore + assert ( + str(e.value) + == "Invalid paginator: non_existing_method. Available options: json_link, json_response," + " header_link, auto, single_page, cursor, offset, page_number." 
+ ) diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py new file mode 100644 index 0000000000..a0ca7ce890 --- /dev/null +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -0,0 +1,332 @@ +import re +from copy import deepcopy + +import pytest +from graphlib import CycleError # type: ignore + +from dlt.sources.rest_api import ( + rest_api_resources, + rest_api_source, +) +from dlt.sources.rest_api.config_setup import ( + _bind_path_params, + process_parent_data_item, +) +from dlt.sources.rest_api.typing import ( + EndpointResource, + ResolvedParam, + RESTAPIConfig, +) + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + from dlt.sources.helpers.rest_client.paginators import ( + JSONResponsePaginator as JSONLinkPaginator, + ) + + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + pass + + +def test_bind_path_param() -> None: + three_params: EndpointResource = { + "name": "comments", + "endpoint": { + "path": "{org}/{repo}/issues/{id}/comments", + "params": { + "org": "dlt-hub", + "repo": "dlt", + "id": { + "type": "resolve", + "field": "id", + "resource": "issues", + }, + }, + }, + } + tp_1 = deepcopy(three_params) + _bind_path_params(tp_1) + + # do not replace resolved params + assert tp_1["endpoint"]["path"] == "dlt-hub/dlt/issues/{id}/comments" # type: ignore[index] + # bound params popped + assert len(tp_1["endpoint"]["params"]) == 1 # type: ignore[index] + assert "id" in tp_1["endpoint"]["params"] # type: ignore[index] + + tp_2 = deepcopy(three_params) + tp_2["endpoint"]["params"]["id"] = 12345 # type: ignore[index] + _bind_path_params(tp_2) + assert tp_2["endpoint"]["path"] == "dlt-hub/dlt/issues/12345/comments" # type: ignore[index] + assert len(tp_2["endpoint"]["params"]) == 0 # type: ignore[index] + + # param missing + tp_3 = deepcopy(three_params) + with pytest.raises(ValueError) as val_ex: + del tp_3["endpoint"]["params"]["id"] # type: ignore[index, union-attr] + _bind_path_params(tp_3) + # path is a part of an exception + assert tp_3["endpoint"]["path"] in str(val_ex.value) # type: ignore[index] + + # path without params + tp_4 = deepcopy(three_params) + tp_4["endpoint"]["path"] = "comments" # type: ignore[index] + # no unbound params + del tp_4["endpoint"]["params"]["id"] # type: ignore[index, union-attr] + tp_5 = deepcopy(tp_4) + _bind_path_params(tp_4) + assert tp_4 == tp_5 + + # resolved param will remain unbounded and + tp_6 = deepcopy(three_params) + tp_6["endpoint"]["path"] = "{org}/{repo}/issues/1234/comments" # type: ignore[index] + with pytest.raises(NotImplementedError): + _bind_path_params(tp_6) + + +def test_process_parent_data_item() -> None: + resolve_param = ResolvedParam( + "id", {"field": "obj_id", "resource": "issues", "type": "resolve"} + ) + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_param, None + ) + assert bound_path == "dlt-hub/dlt/issues/12345/comments" + assert parent_record == {} + + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_param, ["obj_id"] + ) + assert parent_record == {"_issues_obj_id": 12345} + + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{id}/comments", + {"obj_id": 12345, "obj_node": "node_1"}, + resolve_param, + ["obj_id", "obj_node"], + ) + assert 
parent_record == {"_issues_obj_id": 12345, "_issues_obj_node": "node_1"} + + # test nested data + resolve_param_nested = ResolvedParam( + "id", {"field": "some_results.obj_id", "resource": "issues", "type": "resolve"} + ) + item = {"some_results": {"obj_id": 12345}} + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{id}/comments", item, resolve_param_nested, None + ) + assert bound_path == "dlt-hub/dlt/issues/12345/comments" + + # param path not found + with pytest.raises(ValueError) as val_ex: + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{id}/comments", {"_id": 12345}, resolve_param, None + ) + assert "Transformer expects a field 'obj_id'" in str(val_ex.value) + + # included path not found + with pytest.raises(ValueError) as val_ex: + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{id}/comments", + {"obj_id": 12345, "obj_node": "node_1"}, + resolve_param, + ["obj_id", "node"], + ) + assert "in order to include it in child records under _issues_node" in str(val_ex.value) + + +def test_two_resources_can_depend_on_one_parent_resource() -> None: + user_id = { + "user_id": { + "type": "resolve", + "field": "id", + "resource": "users", + } + } + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + "users", + { + "name": "user_details", + "endpoint": { + "path": "user/{user_id}/", + "params": user_id, # type: ignore[typeddict-item] + }, + }, + { + "name": "meetings", + "endpoint": { + "path": "meetings/{user_id}/", + "params": user_id, # type: ignore[typeddict-item] + }, + }, + ], + } + resources = rest_api_source(config).resources + assert resources["meetings"]._pipe.parent.name == "users" + assert resources["user_details"]._pipe.parent.name == "users" + + +def test_dependent_resource_cannot_bind_multiple_parameters() -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + "users", + { + "name": "user_details", + "endpoint": { + "path": "user/{user_id}/{group_id}", + "params": { + "user_id": { + "type": "resolve", + "field": "id", + "resource": "users", + }, + "group_id": { + "type": "resolve", + "field": "group", + "resource": "users", + }, + }, + }, + }, + ], + } + with pytest.raises(ValueError) as e: + rest_api_resources(config) + + error_part_1 = re.escape( + "Multiple resolved params for resource user_details: [ResolvedParam(param_name='user_id'" + ) + error_part_2 = re.escape("ResolvedParam(param_name='group_id'") + assert e.match(error_part_1) + assert e.match(error_part_2) + + +def test_one_resource_cannot_bind_two_parents() -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + "users", + "groups", + { + "name": "user_details", + "endpoint": { + "path": "user/{user_id}/{group_id}", + "params": { + "user_id": { + "type": "resolve", + "field": "id", + "resource": "users", + }, + "group_id": { + "type": "resolve", + "field": "id", + "resource": "groups", + }, + }, + }, + }, + ], + } + + with pytest.raises(ValueError) as e: + rest_api_resources(config) + + error_part_1 = re.escape( + "Multiple resolved params for resource user_details: [ResolvedParam(param_name='user_id'" + ) + error_part_2 = re.escape("ResolvedParam(param_name='group_id'") + assert e.match(error_part_1) + assert e.match(error_part_2) + + +def test_resource_dependent_dependent() -> None: + config: RESTAPIConfig = { + "client": { + "base_url": 
"https://api.example.com", + }, + "resources": [ + "locations", + { + "name": "location_details", + "endpoint": { + "path": "location/{location_id}", + "params": { + "location_id": { + "type": "resolve", + "field": "id", + "resource": "locations", + }, + }, + }, + }, + { + "name": "meetings", + "endpoint": { + "path": "/meetings/{room_id}", + "params": { + "room_id": { + "type": "resolve", + "field": "room_id", + "resource": "location_details", + }, + }, + }, + }, + ], + } + + resources = rest_api_source(config).resources + assert resources["meetings"]._pipe.parent.name == "location_details" + assert resources["location_details"]._pipe.parent.name == "locations" + + +def test_circular_resource_bindingis_invalid() -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "chicken", + "endpoint": { + "path": "chicken/{egg_id}/", + "params": { + "egg_id": { + "type": "resolve", + "field": "id", + "resource": "egg", + }, + }, + }, + }, + { + "name": "egg", + "endpoint": { + "path": "egg/{chicken_id}/", + "params": { + "chicken_id": { + "type": "resolve", + "field": "id", + "resource": "chicken", + }, + }, + }, + }, + ], + } + + with pytest.raises(CycleError) as e: + rest_api_resources(config) + assert e.match(re.escape("'nodes are in a cycle', ['chicken', 'egg', 'chicken']")) diff --git a/tests/sources/rest_api/configurations/test_response_actions_config.py b/tests/sources/rest_api/configurations/test_response_actions_config.py new file mode 100644 index 0000000000..c9889b1e09 --- /dev/null +++ b/tests/sources/rest_api/configurations/test_response_actions_config.py @@ -0,0 +1,138 @@ +import pytest +from typing import List + +from dlt.sources.rest_api import ( + rest_api_source, +) + +from dlt.sources.rest_api.config_setup import ( + create_response_hooks, + _handle_response_action, +) +from dlt.sources.rest_api.typing import ( + RESTAPIConfig, + ResponseAction, +) + +try: + from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator +except ImportError: + pass + + +def test_create_multiple_response_actions(): + def custom_hook(response, *args, **kwargs): + return response + + response_actions: List[ResponseAction] = [ + custom_hook, + {"status_code": 404, "action": "ignore"}, + {"content": "Not found", "action": "ignore"}, + {"status_code": 200, "content": "some text", "action": "ignore"}, + ] + hooks = create_response_hooks(response_actions) + assert len(hooks["response"]) == 4 + + response_actions_2: List[ResponseAction] = [ + custom_hook, + {"status_code": 200, "action": custom_hook}, + ] + hooks_2 = create_response_hooks(response_actions_2) + assert len(hooks_2["response"]) == 2 + + +def test_response_action_raises_type_error(mocker): + class C: + pass + + response = mocker.Mock() + response.status_code = 200 + + with pytest.raises(ValueError) as e_1: + _handle_response_action(response, {"status_code": 200, "action": C()}) # type: ignore[typeddict-item] + assert e_1.match("does not conform to expected type") + + with pytest.raises(ValueError) as e_2: + _handle_response_action(response, {"status_code": 200, "action": 123}) # type: ignore[typeddict-item] + assert e_2.match("does not conform to expected type") + + assert ("ignore", None) == _handle_response_action( + response, {"status_code": 200, "action": "ignore"} + ) + assert ("foobar", None) == _handle_response_action( + response, {"status_code": 200, "action": "foobar"} + ) + + +def test_parses_hooks_from_response_actions(mocker): + response = 
mocker.Mock() + response.status_code = 200 + + hook_1 = mocker.Mock() + hook_2 = mocker.Mock() + + assert (None, [hook_1]) == _handle_response_action( + response, {"status_code": 200, "action": hook_1} + ) + assert (None, [hook_1, hook_2]) == _handle_response_action( + response, {"status_code": 200, "action": [hook_1, hook_2]} + ) + + +def test_config_validation_for_response_actions(mocker): + mock_response_hook_1 = mocker.Mock() + mock_response_hook_2 = mocker.Mock() + config_1: RESTAPIConfig = { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "response_actions": [ + { + "status_code": 200, + "action": mock_response_hook_1, + }, + ], + }, + }, + ], + } + + rest_api_source(config_1) + + config_2: RESTAPIConfig = { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "response_actions": [ + mock_response_hook_1, + mock_response_hook_2, + ], + }, + }, + ], + } + + rest_api_source(config_2) + + config_3: RESTAPIConfig = { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "response_actions": [ + { + "status_code": 200, + "action": [mock_response_hook_1, mock_response_hook_2], + }, + ], + }, + }, + ], + } + + rest_api_source(config_3) diff --git a/tests/sources/rest_api/conftest.py b/tests/sources/rest_api/conftest.py new file mode 100644 index 0000000000..8ef4e41255 --- /dev/null +++ b/tests/sources/rest_api/conftest.py @@ -0,0 +1,270 @@ +import base64 +from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode + +import pytest +import requests_mock + +from dlt.sources.helpers.rest_client import RESTClient + +from tests.sources.helpers.rest_client.api_router import APIRouter +from tests.sources.helpers.rest_client.paginators import ( + PageNumberPaginator, + OffsetPaginator, + CursorPaginator, +) + + +MOCK_BASE_URL = "https://api.example.com" +DEFAULT_PAGE_SIZE = 5 +DEFAULT_TOTAL_PAGES = 5 +DEFAULT_LIMIT = 10 + + +router = APIRouter(MOCK_BASE_URL) + + +def generate_posts(count=DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES): + return [{"id": i, "title": f"Post {i}"} for i in range(count)] + + +def generate_comments(post_id, count=50): + return [ + {"id": i, "post_id": post_id, "body": f"Comment {i} for post {post_id}"} + for i in range(count) + ] + + +def get_page_number(qs, key="page", default=1): + return int(qs.get(key, [default])[0]) + + +def create_next_page_url(request, paginator, use_absolute_url=True): + scheme, netloc, path, _, _ = urlsplit(request.url) + query = urlencode(paginator.next_page_url_params) + if use_absolute_url: + return urlunsplit([scheme, netloc, path, query, ""]) + else: + return f"{path}?{query}" + + +def paginate_by_page_number( + request, records, records_key="data", use_absolute_url=True, index_base=1 +): + page_number = get_page_number(request.qs, default=index_base) + paginator = PageNumberPaginator(records, page_number, index_base=index_base) + + response = { + records_key: paginator.page_records, + **paginator.metadata, + } + + if paginator.next_page_url_params: + response["next_page"] = create_next_page_url(request, paginator, use_absolute_url) + + return response + + +@pytest.fixture(scope="module") +def mock_api_server(): + with requests_mock.Mocker() as m: + + @router.get(r"/posts_no_key(\?page=\d+)?$") + def posts_no_key(request, context): + return paginate_by_page_number(request, generate_posts(), records_key=None) + + @router.get(r"/posts(\?page=\d+)?$") + def 
posts(request, context): + return paginate_by_page_number(request, generate_posts()) + + @router.get(r"/posts_zero_based(\?page=\d+)?$") + def posts_zero_based(request, context): + return paginate_by_page_number(request, generate_posts(), index_base=0) + + @router.get(r"/posts_header_link(\?page=\d+)?$") + def posts_header_link(request, context): + records = generate_posts() + page_number = get_page_number(request.qs) + paginator = PageNumberPaginator(records, page_number) + + response = paginator.page_records + + if paginator.next_page_url_params: + next_page_url = create_next_page_url(request, paginator) + context.headers["Link"] = f'<{next_page_url}>; rel="next"' + + return response + + @router.get(r"/posts_relative_next_url(\?page=\d+)?$") + def posts_relative_next_url(request, context): + return paginate_by_page_number(request, generate_posts(), use_absolute_url=False) + + @router.get(r"/posts_offset_limit(\?offset=\d+&limit=\d+)?$") + def posts_offset_limit(request, context): + records = generate_posts() + offset = int(request.qs.get("offset", [0])[0]) + limit = int(request.qs.get("limit", [DEFAULT_LIMIT])[0]) + paginator = OffsetPaginator(records, offset, limit) + + return { + "data": paginator.page_records, + **paginator.metadata, + } + + @router.get(r"/posts_cursor(\?cursor=\d+)?$") + def posts_cursor(request, context): + records = generate_posts() + cursor = int(request.qs.get("cursor", [0])[0]) + paginator = CursorPaginator(records, cursor) + + return { + "data": paginator.page_records, + **paginator.metadata, + } + + @router.get(r"/posts/(\d+)/comments") + def post_comments(request, context): + post_id = int(request.url.split("/")[-2]) + return paginate_by_page_number(request, generate_comments(post_id)) + + @router.get(r"/posts/\d+$") + def post_detail(request, context): + post_id = request.url.split("/")[-1] + return {"id": int(post_id), "body": f"Post body {post_id}"} + + @router.get(r"/posts/\d+/some_details_404") + def post_detail_404(request, context): + """Return 404 for post with id > 0. 
Used to test ignoring 404 errors.""" + post_id = int(request.url.split("/")[-2]) + if post_id < 1: + return {"id": post_id, "body": f"Post body {post_id}"} + else: + context.status_code = 404 + return {"error": "Post not found"} + + @router.get(r"/posts_under_a_different_key$") + def posts_with_results_key(request, context): + return paginate_by_page_number(request, generate_posts(), records_key="many-results") + + @router.post(r"/posts/search$") + def search_posts(request, context): + body = request.json() + page_size = body.get("page_size", DEFAULT_PAGE_SIZE) + page_count = body.get("page_count", DEFAULT_TOTAL_PAGES) + page_number = body.get("page", 1) + + # Simulate a search with filtering + records = generate_posts(page_size * page_count) + ids_greater_than = body.get("ids_greater_than", 0) + records = [r for r in records if r["id"] > ids_greater_than] + + total_records = len(records) + total_pages = (total_records + page_size - 1) // page_size + start_index = (page_number - 1) * page_size + end_index = start_index + page_size + records_slice = records[start_index:end_index] + + return { + "data": records_slice, + "next_page": page_number + 1 if page_number < total_pages else None, + } + + @router.get("/protected/posts/basic-auth") + def protected_basic_auth(request, context): + auth = request.headers.get("Authorization") + creds = "user:password" + creds_base64 = base64.b64encode(creds.encode()).decode() + if auth == f"Basic {creds_base64}": + return paginate_by_page_number(request, generate_posts()) + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.get("/protected/posts/bearer-token") + def protected_bearer_token(request, context): + auth = request.headers.get("Authorization") + if auth == "Bearer test-token": + return paginate_by_page_number(request, generate_posts()) + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.get("/protected/posts/bearer-token-plain-text-error") + def protected_bearer_token_plain_text_erorr(request, context): + auth = request.headers.get("Authorization") + if auth == "Bearer test-token": + return paginate_by_page_number(request, generate_posts()) + context.status_code = 401 + return "Unauthorized" + + @router.get("/protected/posts/api-key") + def protected_api_key(request, context): + api_key = request.headers.get("x-api-key") + if api_key == "test-api-key": + return paginate_by_page_number(request, generate_posts()) + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.post("/oauth/token") + def oauth_token(request, context): + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.post("/oauth/token-expires-now") + def oauth_token_expires_now(request, context): + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 0} + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.post("/auth/refresh") + def refresh_token(request, context): + body = request.json() + if body.get("refresh_token") == "valid-refresh-token": + return {"access_token": "new-valid-token"} + context.status_code = 401 + return {"error": "Invalid refresh token"} + + @router.post("/custom-oauth/token") + def custom_oauth_token(request, context): + qs = parse_qs(request.text) + if ( + qs.get("grant_type")[0] == "account_credentials" + and qs.get("account_id")[0] == "test-account-id" + and request.headers["Authorization"] + == "Basic 
dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==" + ): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + + router.register_routes(m) + + yield m + + +@pytest.fixture +def rest_client() -> RESTClient: + return RESTClient( + base_url="https://api.example.com", + headers={"Accept": "application/json"}, + ) + + +def oauth_authorize(request): + qs = parse_qs(request.text) + grant_type = qs.get("grant_type")[0] + if "jwt-bearer" in grant_type: + return True + if "client_credentials" in grant_type: + return ( + qs["client_secret"][0] == "test-client-secret" + and qs["client_id"][0] == "test-client-id" + ) + + +def assert_pagination(pages, page_size=DEFAULT_PAGE_SIZE, total_pages=DEFAULT_TOTAL_PAGES): + assert len(pages) == total_pages + for i, page in enumerate(pages): + assert page == [ + {"id": i, "title": f"Post {i}"} for i in range(i * page_size, (i + 1) * page_size) + ] diff --git a/tests/sources/rest_api/integration/__init__.py b/tests/sources/rest_api/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/rest_api/integration/test_offline.py b/tests/sources/rest_api/integration/test_offline.py new file mode 100644 index 0000000000..2c1f48537b --- /dev/null +++ b/tests/sources/rest_api/integration/test_offline.py @@ -0,0 +1,329 @@ +from typing import Any, List, Optional +from unittest import mock + +import pytest +from requests import Request, Response + +import dlt +from dlt.common import pendulum +from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.sources.helpers.rest_client.paginators import BaseReferencePaginator +from dlt.sources.rest_api import ( + ClientConfig, + Endpoint, + EndpointResource, + RESTAPIConfig, + rest_api_source, +) +from tests.sources.rest_api.conftest import DEFAULT_PAGE_SIZE, DEFAULT_TOTAL_PAGES +from tests.utils import assert_load_info, assert_query_data, load_table_counts + + +def test_load_mock_api(mock_api_server): + pipeline = dlt.pipeline( + pipeline_name="rest_api_mock", + destination="duckdb", + dataset_name="rest_api_mock", + full_refresh=True, + ) + + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "posts", + { + "name": "post_comments", + "endpoint": { + "path": "posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + }, + }, + { + "name": "post_details", + "endpoint": { + "path": "posts/{post_id}", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + }, + }, + ], + } + ) + + load_info = pipeline.run(mock_source) + print(load_info) + assert_load_info(load_info) + table_names = [t["name"] for t in pipeline.default_schema.data_tables()] + table_counts = load_table_counts(pipeline, *table_names) + + assert table_counts.keys() == {"posts", "post_comments", "post_details"} + + assert table_counts["posts"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES + assert table_counts["post_details"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES + assert table_counts["post_comments"] == 50 * DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES + + with pipeline.sql_client() as client: + posts_table = client.make_qualified_table_name("posts") + posts_details_table = client.make_qualified_table_name("post_details") + post_comments_table = client.make_qualified_table_name("post_comments") + + print(pipeline.default_schema.to_pretty_yaml()) + + assert_query_data( + pipeline, + f"SELECT 
title FROM {posts_table} ORDER BY id limit 25", + [f"Post {i}" for i in range(25)], + ) + + assert_query_data( + pipeline, + f"SELECT body FROM {posts_details_table} ORDER BY id limit 25", + [f"Post body {i}" for i in range(25)], + ) + + assert_query_data( + pipeline, + f"SELECT body FROM {post_comments_table} ORDER BY post_id, id limit 5", + [f"Comment {i} for post 0" for i in range(5)], + ) + + +def test_ignoring_endpoint_returning_404(mock_api_server): + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "posts", + { + "name": "post_details", + "endpoint": { + "path": "posts/{post_id}/some_details_404", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + "response_actions": [ + { + "status_code": 404, + "action": "ignore", + }, + ], + }, + }, + ], + } + ) + + res = list(mock_source.with_resources("posts", "post_details").add_limit(1)) + + assert res[:5] == [ + {"id": 0, "body": "Post body 0"}, + {"id": 0, "title": "Post 0"}, + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + ] + + +def test_source_with_post_request(mock_api_server): + class JSONBodyPageCursorPaginator(BaseReferencePaginator): + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: + self._next_reference = response.json().get("next_page") + + def update_request(self, request: Request) -> None: + if request.json is None: + request.json = {} + + request.json["page"] = self._next_reference + + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "search_posts", + "endpoint": { + "path": "/posts/search", + "method": "POST", + "json": {"ids_greater_than": 50, "page_size": 25, "page_count": 4}, + "paginator": JSONBodyPageCursorPaginator(), + }, + } + ], + } + ) + + res = list(mock_source.with_resources("search_posts")) + + for i in range(49): + assert res[i] == {"id": 51 + i, "title": f"Post {51 + i}"} + + +def test_unauthorized_access_to_protected_endpoint(mock_api_server): + pipeline = dlt.pipeline( + pipeline_name="rest_api_mock", + destination="duckdb", + dataset_name="rest_api_mock", + full_refresh=True, + ) + + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "/protected/posts/bearer-token-plain-text-error", + ], + } + ) + + with pytest.raises(PipelineStepFailed) as e: + pipeline.run(mock_source) + assert e.match("401 Client Error") + + +def test_posts_under_results_key(mock_api_server): + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts_under_a_different_key", + "data_selector": "many-results", + "paginator": "json_link", + }, + }, + ], + } + ) + + res = list(mock_source.with_resources("posts").add_limit(1)) + + assert res[:5] == [ + {"id": 0, "title": "Post 0"}, + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + {"id": 4, "title": "Post 4"}, + ] + + +def test_posts_without_key(mock_api_server): + mock_source = rest_api_source( + { + "client": { + "base_url": "https://api.example.com", + "paginator": "header_link", + }, + "resources": [ + { + "name": "posts_no_key", + "endpoint": { + "path": "posts_no_key", + }, + }, + ], + } + ) + + res = list(mock_source.with_resources("posts_no_key").add_limit(1)) + + assert res[:5] == [ + {"id": 0, "title": "Post 0"}, + 
{"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + {"id": 4, "title": "Post 4"}, + ] + + +def test_load_mock_api_typeddict_config(mock_api_server): + pipeline = dlt.pipeline( + pipeline_name="rest_api_mock", + destination="duckdb", + dataset_name="rest_api_mock", + full_refresh=True, + ) + + mock_source = rest_api_source( + RESTAPIConfig( + client=ClientConfig(base_url="https://api.example.com"), + resources=[ + "posts", + EndpointResource( + name="post_comments", + endpoint=Endpoint( + path="posts/{post_id}/comments", + params={ + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + ), + ), + ], + ) + ) + + load_info = pipeline.run(mock_source) + print(load_info) + assert_load_info(load_info) + table_names = [t["name"] for t in pipeline.default_schema.data_tables()] + table_counts = load_table_counts(pipeline, *table_names) + + assert table_counts.keys() == {"posts", "post_comments"} + + assert table_counts["posts"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES + assert table_counts["post_comments"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES * 50 + + +def test_posts_with_inremental_date_conversion(mock_api_server) -> None: + start_time = pendulum.from_timestamp(1) + one_day_later = start_time.add(days=1) + config: RESTAPIConfig = { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "path": "posts", + "incremental": { + "start_param": "since", + "end_param": "until", + "cursor_path": "updated_at", + "initial_value": str(start_time.int_timestamp), + "end_value": str(one_day_later.int_timestamp), + "convert": lambda epoch: pendulum.from_timestamp( + int(epoch) + ).to_date_string(), + }, + }, + }, + ], + } + RESTClient = dlt.sources.helpers.rest_client.RESTClient + with mock.patch.object(RESTClient, "paginate") as mock_paginate: + source = rest_api_source(config).add_limit(1) + _ = list(source.with_resources("posts")) + assert mock_paginate.call_count == 1 + _, called_kwargs = mock_paginate.call_args_list[0] + assert called_kwargs["params"] == {"since": "1970-01-01", "until": "1970-01-02"} + assert called_kwargs["path"] == "posts" diff --git a/tests/sources/rest_api/integration/test_processing_steps.py b/tests/sources/rest_api/integration/test_processing_steps.py new file mode 100644 index 0000000000..bbe90dda06 --- /dev/null +++ b/tests/sources/rest_api/integration/test_processing_steps.py @@ -0,0 +1,245 @@ +from typing import Any, Callable, Dict, List + +import dlt +from dlt.sources.rest_api import RESTAPIConfig, rest_api_source + + +def _make_pipeline(destination_name: str): + return dlt.pipeline( + pipeline_name="rest_api", + destination=destination_name, + dataset_name="rest_api_data", + full_refresh=True, + ) + + +def test_rest_api_source_filtered(mock_api_server) -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] == 1}, # type: ignore[typeddict-item] + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + assert len(data) == 1 + assert data[0]["title"] == "Post 1" + + +def test_rest_api_source_exclude_columns(mock_api_server) -> None: + def exclude_columns(columns: List[str]) -> Callable[..., Any]: + def pop_columns(record: Dict[str, Any]) -> Dict[str, Any]: + for col in columns: + record.pop(col) + return record + + return pop_columns 
+ + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + { + "map": exclude_columns(["title"]), # type: ignore[typeddict-item] + }, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + + assert all("title" not in record for record in data) + + +def test_rest_api_source_anonymize_columns(mock_api_server) -> None: + def anonymize_columns(columns: List[str]) -> Callable[..., Any]: + def empty_columns(record: Dict[str, Any]) -> Dict[str, Any]: + for col in columns: + record[col] = "dummy" + return record + + return empty_columns + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + { + "map": anonymize_columns(["title"]), # type: ignore[typeddict-item] + }, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + + assert all(record["title"] == "dummy" for record in data) + + +def test_rest_api_source_map(mock_api_server) -> None: + def lower_title(row): + row["title"] = row["title"].lower() + return row + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"map": lower_title}, # type: ignore[typeddict-item] + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + + assert all(record["title"].startswith("post ") for record in data) + + +def test_rest_api_source_filter_and_map(mock_api_server) -> None: + def id_by_10(row): + row["id"] = row["id"] * 10 + return row + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"map": id_by_10}, # type: ignore[typeddict-item] + {"filter": lambda x: x["id"] == 10}, # type: ignore[typeddict-item] + ], + }, + { + "name": "posts_2", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] == 10}, # type: ignore[typeddict-item] + {"map": id_by_10}, # type: ignore[typeddict-item] + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + assert len(data) == 1 + assert data[0]["title"] == "Post 1" + + data = list(mock_source.with_resources("posts_2")) + assert len(data) == 1 + assert data[0]["id"] == 100 + assert data[0]["title"] == "Post 10" + + +def test_rest_api_source_filtered_child(mock_api_server) -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] in (1, 2)}, # type: ignore[typeddict-item] + ], + }, + { + "name": "comments", + "endpoint": { + "path": "/posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + }, + "processing_steps": [ + {"filter": lambda x: x["id"] == 1}, # type: ignore[typeddict-item] + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("comments")) + assert len(data) == 2 + + +def test_rest_api_source_filtered_and_map_child(mock_api_server) -> None: + def extend_body(row): + row["body"] = f"{row['_posts_title']} - 
{row['body']}" + return row + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] in (1, 2)}, # type: ignore[typeddict-item] + ], + }, + { + "name": "comments", + "endpoint": { + "path": "/posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + }, + "include_from_parent": ["title"], + "processing_steps": [ + {"map": extend_body}, # type: ignore[typeddict-item] + {"filter": lambda x: x["body"].startswith("Post 2")}, # type: ignore[typeddict-item] + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("comments")) + assert data[0]["body"] == "Post 2 - Comment 0 for post 2" diff --git a/tests/sources/rest_api/integration/test_response_actions.py b/tests/sources/rest_api/integration/test_response_actions.py new file mode 100644 index 0000000000..36a7990db3 --- /dev/null +++ b/tests/sources/rest_api/integration/test_response_actions.py @@ -0,0 +1,135 @@ +from dlt.common import json +from dlt.sources.helpers.requests import Response +from dlt.sources.rest_api import create_response_hooks, rest_api_source + + +def test_response_action_on_status_code(mock_api_server, mocker): + mock_response_hook = mocker.Mock() + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "post_details", + "endpoint": { + "path": "posts/1/some_details_404", + "response_actions": [ + { + "status_code": 404, + "action": mock_response_hook, + }, + ], + }, + }, + ], + } + ) + + list(mock_source.with_resources("post_details").add_limit(1)) + + mock_response_hook.assert_called_once() + + +def test_response_action_on_every_response(mock_api_server, mocker): + def custom_hook(request, *args, **kwargs): + return request + + mock_response_hook = mocker.Mock(side_effect=custom_hook) + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "response_actions": [ + mock_response_hook, + ], + }, + }, + ], + } + ) + + list(mock_source.with_resources("posts").add_limit(1)) + + mock_response_hook.assert_called_once() + + +def test_multiple_response_actions_on_every_response(mock_api_server, mocker): + def custom_hook(response, *args, **kwargs): + return response + + mock_response_hook_1 = mocker.Mock(side_effect=custom_hook) + mock_response_hook_2 = mocker.Mock(side_effect=custom_hook) + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "response_actions": [ + mock_response_hook_1, + mock_response_hook_2, + ], + }, + }, + ], + } + ) + + list(mock_source.with_resources("posts").add_limit(1)) + + mock_response_hook_1.assert_called_once() + mock_response_hook_2.assert_called_once() + + +def test_response_actions_called_in_order(mock_api_server, mocker): + def set_encoding(response: Response, *args, **kwargs) -> Response: + assert response.encoding != "windows-1252" + response.encoding = "windows-1252" + return response + + def add_field(response: Response, *args, **kwargs) -> Response: + assert response.encoding == "windows-1252" + payload = response.json() + for record in payload["data"]: + record["custom_field"] = "foobar" + modified_content: bytes = json.dumps(payload).encode("utf-8") + response._content = 
modified_content + return response + + mock_response_hook_1 = mocker.Mock(side_effect=set_encoding) + mock_response_hook_2 = mocker.Mock(side_effect=add_field) + + response_actions = [ + mock_response_hook_1, + {"status_code": 200, "action": mock_response_hook_2}, + ] + hooks = create_response_hooks(response_actions) + assert len(hooks.get("response")) == 2 + + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + { + "name": "posts", + "endpoint": { + "response_actions": [ + mock_response_hook_1, + {"status_code": 200, "action": mock_response_hook_2}, + ], + }, + }, + ], + } + ) + + data = list(mock_source.with_resources("posts").add_limit(1)) + + mock_response_hook_1.assert_called_once() + mock_response_hook_2.assert_called_once() + + assert all(record["custom_field"] == "foobar" for record in data) diff --git a/tests/sources/rest_api/test_rest_api_pipeline_template.py b/tests/sources/rest_api/test_rest_api_pipeline_template.py new file mode 100644 index 0000000000..ef30b63a7f --- /dev/null +++ b/tests/sources/rest_api/test_rest_api_pipeline_template.py @@ -0,0 +1,20 @@ +import dlt +import pytest +from dlt.common.typing import TSecretStrValue + + +# NOTE: needs github secrets to work +@pytest.mark.parametrize( + "example_name", + ( + "load_github", + "load_pokemon", + ), +) +def test_all_examples(example_name: str) -> None: + from dlt.sources import rest_api_pipeline + + # reroute token location from secrets + github_token: TSecretStrValue = dlt.secrets.get("sources.github.access_token") + dlt.secrets["sources.rest_api_pipeline.github.access_token"] = github_token + getattr(rest_api_pipeline, example_name)() diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py new file mode 100644 index 0000000000..f6b97a7f47 --- /dev/null +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -0,0 +1,116 @@ +import dlt +import pytest +from dlt.sources.rest_api.typing import RESTAPIConfig +from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator + +from dlt.sources.rest_api import rest_api_source +from tests.utils import ALL_DESTINATIONS, assert_load_info, load_table_counts + + +def _make_pipeline(destination_name: str): + return dlt.pipeline( + pipeline_name="rest_api", + destination=destination_name, + dataset_name="rest_api_data", + full_refresh=True, + ) + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +def test_rest_api_source(destination_name: str) -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://pokeapi.co/api/v2/", + }, + "resource_defaults": { + "endpoint": { + "params": { + "limit": 1000, + }, + } + }, + "resources": [ + { + "name": "pokemon_list", + "endpoint": "pokemon", + }, + "berry", + "location", + ], + } + data = rest_api_source(config) + pipeline = _make_pipeline(destination_name) + load_info = pipeline.run(data) + print(load_info) + assert_load_info(load_info) + table_names = [t["name"] for t in pipeline.default_schema.data_tables()] + table_counts = load_table_counts(pipeline, *table_names) + + assert table_counts.keys() == {"pokemon_list", "berry", "location"} + + assert table_counts["pokemon_list"] == 1302 + assert table_counts["berry"] == 64 + assert table_counts["location"] == 1036 + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +def test_dependent_resource(destination_name: str) -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://pokeapi.co/api/v2/", + }, + 
"resource_defaults": { + "endpoint": { + "params": { + "limit": 1000, + }, + } + }, + "resources": [ + { + "name": "pokemon_list", + "endpoint": { + "path": "pokemon", + "paginator": SinglePagePaginator(), + "data_selector": "results", + "params": { + "limit": 2, + }, + }, + "selected": False, + }, + { + "name": "pokemon", + "endpoint": { + "path": "pokemon/{name}", + "params": { + "name": { + "type": "resolve", + "resource": "pokemon_list", + "field": "name", + }, + }, + }, + }, + ], + } + + data = rest_api_source(config) + pipeline = _make_pipeline(destination_name) + load_info = pipeline.run(data) + assert_load_info(load_info) + table_names = [t["name"] for t in pipeline.default_schema.data_tables()] + table_counts = load_table_counts(pipeline, *table_names) + + assert set(table_counts.keys()) == { + "pokemon", + "pokemon__types", + "pokemon__stats", + "pokemon__moves__version_group_details", + "pokemon__moves", + "pokemon__game_indices", + "pokemon__forms", + "pokemon__abilities", + } + + assert table_counts["pokemon"] == 2 diff --git a/tests/sources/sql_database/__init__.py b/tests/sources/sql_database/__init__.py new file mode 100644 index 0000000000..f10ab98368 --- /dev/null +++ b/tests/sources/sql_database/__init__.py @@ -0,0 +1 @@ +# almost all tests are in tests/load since a postgres instance is required for this to work diff --git a/tests/sources/sql_database/test_arrow_helpers.py b/tests/sources/sql_database/test_arrow_helpers.py new file mode 100644 index 0000000000..8328bed89b --- /dev/null +++ b/tests/sources/sql_database/test_arrow_helpers.py @@ -0,0 +1,114 @@ +from datetime import date, datetime, timezone # noqa: I251 +from uuid import uuid4 + +import pyarrow as pa +import pytest + +from dlt.sources.sql_database.arrow_helpers import row_tuples_to_arrow + + +@pytest.mark.parametrize("all_unknown", [True, False]) +def test_row_tuples_to_arrow_unknown_types(all_unknown: bool) -> None: + """Test inferring data types with pyarrow""" + + rows = [ + ( + 1, + "a", + 1.1, + True, + date.today(), + uuid4(), + datetime.now(timezone.utc), + [1, 2, 3], + ), + ( + 2, + "b", + 2.2, + False, + date.today(), + uuid4(), + datetime.now(timezone.utc), + [4, 5, 6], + ), + ( + 3, + "c", + 3.3, + True, + date.today(), + uuid4(), + datetime.now(timezone.utc), + [7, 8, 9], + ), + ] + + # Some columns don't specify data type and should be inferred + columns = { + "int_col": {"name": "int_col", "data_type": "bigint", "nullable": False}, + "str_col": {"name": "str_col", "data_type": "text", "nullable": False}, + "float_col": {"name": "float_col", "nullable": False}, + "bool_col": {"name": "bool_col", "data_type": "bool", "nullable": False}, + "date_col": {"name": "date_col", "nullable": False}, + "uuid_col": {"name": "uuid_col", "nullable": False}, + "datetime_col": { + "name": "datetime_col", + "data_type": "timestamp", + "nullable": False, + }, + "array_col": {"name": "array_col", "nullable": False}, + } + + if all_unknown: + for col in columns.values(): + col.pop("data_type", None) + + # Call the function + result = row_tuples_to_arrow(rows, columns, tz="UTC") # type: ignore[arg-type] + + # Result is arrow table containing all columns in original order with correct types + assert result.num_columns == len(columns) + result_col_names = [f.name for f in result.schema] + expected_names = list(columns) + assert result_col_names == expected_names + + assert pa.types.is_int64(result[0].type) + assert pa.types.is_string(result[1].type) + assert pa.types.is_float64(result[2].type) + assert 
pa.types.is_boolean(result[3].type) + assert pa.types.is_date(result[4].type) + assert pa.types.is_string(result[5].type) + assert pa.types.is_timestamp(result[6].type) + assert pa.types.is_list(result[7].type) + + +pytest.importorskip("sqlalchemy", minversion="2.0") + + +def test_row_tuples_to_arrow_detects_range_type() -> None: + from sqlalchemy.dialects.postgresql import Range # type: ignore[attr-defined] + + # Applies to NUMRANGE, DATERANGE, etc sql types. Sqlalchemy returns a Range dataclass + IntRange = Range + + rows = [ + (IntRange(1, 10),), + (IntRange(2, 20),), + (IntRange(3, 30),), + ] + result = row_tuples_to_arrow( + rows=rows, # type: ignore[arg-type] + columns={"range_col": {"name": "range_col", "nullable": False}}, + tz="UTC", + ) + assert result.num_columns == 1 + assert pa.types.is_struct(result[0].type) + + # Check range has all fields + range_type = result[0].type + range_fields = {f.name: f for f in range_type} + assert pa.types.is_int64(range_fields["lower"].type) + assert pa.types.is_int64(range_fields["upper"].type) + assert pa.types.is_boolean(range_fields["empty"].type) + assert pa.types.is_string(range_fields["bounds"].type) diff --git a/tests/sources/sql_database/test_sql_database_pipeline_template.py b/tests/sources/sql_database/test_sql_database_pipeline_template.py new file mode 100644 index 0000000000..88c05ea333 --- /dev/null +++ b/tests/sources/sql_database/test_sql_database_pipeline_template.py @@ -0,0 +1,22 @@ +import pytest + + +# TODO: not all template functions are tested here +# we may be able to test more in tests/load/sources +@pytest.mark.parametrize( + "example_name", + ( + "load_select_tables_from_database", + # "load_entire_database", + "load_standalone_table_resource", + "select_columns", + "specify_columns_to_load", + "test_pandas_backend_verbatim_decimals", + "select_with_end_value_and_row_order", + "my_sql_via_pyarrow", + ), +) +def test_all_examples(example_name: str) -> None: + from dlt.sources import sql_database_pipeline + + getattr(sql_database_pipeline, example_name)() diff --git a/tests/sources/test_pipeline_templates.py b/tests/sources/test_pipeline_templates.py new file mode 100644 index 0000000000..0743a21fef --- /dev/null +++ b/tests/sources/test_pipeline_templates.py @@ -0,0 +1,61 @@ +import pytest + + +@pytest.mark.parametrize( + "example_name", + ("load_all_datatypes",), +) +def test_debug_pipeline(example_name: str) -> None: + from dlt.sources.pipeline_templates import debug_pipeline + + getattr(debug_pipeline, example_name)() + + +@pytest.mark.parametrize( + "example_name", + ("load_arrow_tables",), +) +def test_arrow_pipeline(example_name: str) -> None: + from dlt.sources.pipeline_templates import arrow_pipeline + + getattr(arrow_pipeline, example_name)() + + +@pytest.mark.parametrize( + "example_name", + ("load_dataframe",), +) +def test_dataframe_pipeline(example_name: str) -> None: + from dlt.sources.pipeline_templates import dataframe_pipeline + + getattr(dataframe_pipeline, example_name)() + + +@pytest.mark.parametrize( + "example_name", + ("load_stuff",), +) +def test_default_pipeline(example_name: str) -> None: + from dlt.sources.pipeline_templates import default_pipeline + + getattr(default_pipeline, example_name)() + + +@pytest.mark.parametrize( + "example_name", + ("load_chess_data",), +) +def test_requests_pipeline(example_name: str) -> None: + from dlt.sources.pipeline_templates import requests_pipeline + + getattr(requests_pipeline, example_name)() + + +@pytest.mark.parametrize( + "example_name", + 
("load_api_data", "load_sql_data", "load_pandas_data"), +) +def test_intro_pipeline(example_name: str) -> None: + from dlt.sources.pipeline_templates import intro_pipeline + + getattr(intro_pipeline, example_name)() diff --git a/tests/tools/clean_athena.py b/tests/tools/clean_athena.py new file mode 100644 index 0000000000..163cf4a4e7 --- /dev/null +++ b/tests/tools/clean_athena.py @@ -0,0 +1,20 @@ +"""WARNING: Running this script will drop add schemas in the athena destination set up in your secrets.toml""" + +import dlt +from dlt.destinations.exceptions import DatabaseUndefinedRelation + +if __name__ == "__main__": + pipeline = dlt.pipeline(pipeline_name="drop_athena", destination="athena") + + with pipeline.sql_client() as client: + with client.execute_query("SHOW DATABASES") as cur: + dbs = cur.fetchall() + for db in dbs: + db = db[0] + sql = f"DROP SCHEMA `{db}` CASCADE;" + try: + print(sql) + with client.execute_query(sql): + pass # + except DatabaseUndefinedRelation: + print("Could not delete schema") diff --git a/tests/tools/clean_redshift.py b/tests/tools/clean_redshift.py index 96364d68fb..2783820cc5 100644 --- a/tests/tools/clean_redshift.py +++ b/tests/tools/clean_redshift.py @@ -1,32 +1,34 @@ -from dlt.destinations.impl.postgres.postgres import PostgresClient -from dlt.destinations.impl.postgres.sql_client import psycopg2 -from psycopg2.errors import InsufficientPrivilege, InternalError_, SyntaxError +"""WARNING: Running this script will drop add schemas in the redshift destination set up in your secrets.toml""" -CONNECTION_STRING = "" +import dlt +from dlt.destinations.exceptions import ( + DatabaseUndefinedRelation, + DatabaseTerminalException, + DatabaseTransientException, +) if __name__ == "__main__": - # connect - connection = psycopg2.connect(CONNECTION_STRING) - connection.set_isolation_level(0) + pipeline = dlt.pipeline(pipeline_name="drop_redshift", destination="redshift") - # list all schemas - with connection.cursor() as curr: - curr.execute("""select s.nspname as table_schema, + with pipeline.sql_client() as client: + with client.execute_query("""select s.nspname as table_schema, s.oid as schema_id, u.usename as owner from pg_catalog.pg_namespace s join pg_catalog.pg_user u on u.usesysid = s.nspowner - order by table_schema;""") - schemas = [row[0] for row in curr.fetchall()] - - # delete all schemas, skipp expected errors - with connection.cursor() as curr: - print(f"Deleting {len(schemas)} schemas") - for schema in schemas: - print(f"Deleting {schema}...") + order by table_schema;""") as cur: + dbs = [row[0] for row in cur.fetchall()] + for db in dbs: + if db.startswith("<"): + continue + sql = f"DROP SCHEMA {db} CASCADE;" try: - curr.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE;") - except (InsufficientPrivilege, InternalError_, SyntaxError) as ex: - print(ex) - pass - print(f"Deleted {schema}...") + print(sql) + with client.execute_query(sql): + pass # + except ( + DatabaseUndefinedRelation, + DatabaseTerminalException, + DatabaseTransientException, + ): + print("Could not delete schema") diff --git a/tests/utils.py b/tests/utils.py index 7ae8a361b4..63acb96be7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,7 +3,7 @@ import platform import sys from os import environ -from typing import Any, Iterable, Iterator, Literal, Union, get_args +from typing import Any, Iterable, Iterator, Literal, Union, get_args, List from unittest.mock import patch import pytest @@ -18,17 +18,21 @@ from dlt.common.configuration.specs.config_providers_context import ( 
ConfigProvidersContext, ) -from dlt.common.pipeline import PipelineContext, SupportsPipeline +from dlt.common.pipeline import LoadInfo, PipelineContext, SupportsPipeline from dlt.common.runtime.init import init_logging from dlt.common.runtime.telemetry import start_telemetry, stop_telemetry from dlt.common.schema import Schema from dlt.common.storages import FileStorage from dlt.common.storages.versioned_storage import VersionedStorage -from dlt.common.typing import StrAny, TDataItem +from dlt.common.typing import DictStrAny, StrAny, TDataItem from dlt.common.utils import custom_environ, uniq_id TEST_STORAGE_ROOT = "_storage" +ALL_DESTINATIONS = dlt.config.get("ALL_DESTINATIONS", list) or [ + "duckdb", +] + # destination constants IMPLEMENTED_DESTINATIONS = { @@ -333,3 +337,47 @@ def is_running_in_github_fork() -> bool: skipifgithubfork = pytest.mark.skipif( is_running_in_github_fork(), reason="Skipping test because it runs on a PR coming from fork" ) + + +def assert_load_info(info: LoadInfo, expected_load_packages: int = 1) -> None: + """Asserts that expected number of packages was loaded and there are no failed jobs""" + assert len(info.loads_ids) == expected_load_packages + # all packages loaded + assert all(package.state == "loaded" for package in info.load_packages) is True + # no failed jobs in any of the packages + info.raise_on_failed_jobs() + + +def load_table_counts(p: dlt.Pipeline, *table_names: str) -> DictStrAny: + """Returns row counts for `table_names` as dict""" + with p.sql_client() as c: + query = "\nUNION ALL\n".join( + [ + f"SELECT '{name}' as name, COUNT(1) as c FROM {c.make_qualified_table_name(name)}" + for name in table_names + ] + ) + with c.execute_query(query) as cur: + rows = list(cur.fetchall()) + return {r[0]: r[1] for r in rows} + + +def assert_query_data( + p: dlt.Pipeline, + sql: str, + table_data: List[Any], + schema_name: str = None, + info: LoadInfo = None, +) -> None: + """Asserts that query selecting single column of values matches `table_data`. If `info` is provided, second column must contain one of load_ids in `info`""" + with p.sql_client(schema_name=schema_name) as c: + with c.execute_query(sql) as cur: + rows = list(cur.fetchall()) + assert len(rows) == len(table_data) + for r, d in zip(rows, table_data): + row = list(r) + # first element comes from the data + assert row[0] == d + # the second is load id + if info: + assert row[1] in info.loads_ids
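
Editor's note, not part of the patch: a minimal sketch of how the new helpers added to tests/utils.py (assert_load_info, load_table_counts, assert_query_data) are meant to be combined, mirroring test_load_mock_api and test_rest_api_source from this diff. It assumes the duckdb extra is installed and that the code runs inside the dlt repo; the pipeline name, dataset name, and inline sample rows are illustrative assumptions only.

# illustrative sketch, assuming dlt with the duckdb extra; names are hypothetical
import dlt
from tests.utils import assert_load_info, assert_query_data, load_table_counts


def demo_new_test_helpers() -> None:
    pipeline = dlt.pipeline(
        pipeline_name="utils_demo",  # hypothetical name, not taken from the patch
        destination="duckdb",
        dataset_name="utils_demo_data",
        full_refresh=True,
    )
    # load three sample rows into a "posts" table
    info = pipeline.run(
        [{"id": i, "title": f"Post {i}"} for i in range(3)], table_name="posts"
    )

    # exactly one load package was created and no jobs failed
    assert_load_info(info)
    # row counts are returned per table name, e.g. {"posts": 3}
    assert load_table_counts(pipeline, "posts") == {"posts": 3}
    # compare a single selected column against expected values, as the mock API tests do
    with pipeline.sql_client() as client:
        posts_table = client.make_qualified_table_name("posts")
    assert_query_data(
        pipeline,
        f"SELECT title FROM {posts_table} ORDER BY id",
        [f"Post {i}" for i in range(3)],
    )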