diff --git a/.codegen.json b/.codegen.json new file mode 100644 index 0000000000..406e422040 --- /dev/null +++ b/.codegen.json @@ -0,0 +1,10 @@ +{ + "version": { + "src/databricks/labs/remorph/__about__.py": "__version__ = \"$VERSION\"" + }, + "toolchain": { + "required": ["hatch"], + "pre_setup": ["hatch env create"], + "prepend_path": ".venv/bin" + } +} diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..15db44226b --- /dev/null +++ b/.editorconfig @@ -0,0 +1,22 @@ +# Top-most EditorConfig file. +root = true + +# Universal settings. +[*] +indent_style = space +indent_size = 4 +tab_width = 8 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[Makefile] +indent_style = tab +indent_size = tab + +[*.{json,yml}] +indent_size = 2 + +[*.scala] +indent_size = 2 diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml new file mode 100644 index 0000000000..13210dfa44 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,92 @@ +# See https://docs.github.com/en/communities/using-templates-to-encourage-useful-issues-and-pull-requests/syntax-for-issue-forms +# and https://docs.github.com/en/communities/using-templates-to-encourage-useful-issues-and-pull-requests/syntax-for-githubs-form-schema +name: Bug Report +description: Something is not working in Remorph +title: "[BUG]: " +labels: ["bug", "needs-triage"] +# assignees: +# - remorph-write +body: + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the bug you encountered. + options: + - label: I have searched the existing issues + required: true + - type: dropdown + id: category + attributes: + label: Category of Bug / Issue + description: Please select the category that best describes the bug / issue you are reporting. + options: + - TranspileParserError + - TranspileValidationError + - TranspileLateralColumnAliasError + - ReconcileError + - Other + validations: + required: true + - type: textarea + attributes: + label: Current Behavior + description: | + A concise description of what you're experiencing. + **Do not paste links to attachments with logs and/or images, as all issues with attachments will get deleted.** + Use the `Relevant log output` field to paste redacted log output without personal identifying information (PII). + You can Ctrl/Cmd+V the screenshot, which would appear as a rendered image if it doesn't contain any PII. + validations: + required: false + - type: textarea + attributes: + label: Expected Behavior + description: A concise description of what you expected to happen. + validations: + required: false + - type: textarea + attributes: + label: Steps To Reproduce + description: Steps to reproduce the behavior. + placeholder: | + 1. In this environment... + 1. With this config... + 1. Run '...' + 1. See error... + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output or Exception details + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. + render: shell + - type: textarea + id: query + attributes: + label: Sample Query + description: Please copy and paste anonymized Query. This will be automatically formatted into code, so no need for backticks. 
+ render: shell + - type: dropdown + id: os + attributes: + label: Operating System + description: Which operating system do you have Remorph installed on? + options: + - macOS + - Linux + - Windows + validations: + required: true + - type: dropdown + id: version + attributes: + label: Version + description: What version of our software are you running? + options: + - latest via Databricks CLI + - v0.1.5 + - v0.1.4 + - other + default: 0 + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..df7a6310d1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,9 @@ +blank_issues_enabled: false +contact_links: + - name: General Databricks questions + url: https://help.databricks.com/ + about: Issues related to Databricks and not related to Remorph + + - name: Remorph Documentation + url: https://github.com/databrickslabs/remorph/tree/main/docs + about: Documentation about Remorph diff --git a/.github/ISSUE_TEMPLATE/feature.yml b/.github/ISSUE_TEMPLATE/feature.yml new file mode 100644 index 0000000000..0ea335adf5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature.yml @@ -0,0 +1,45 @@ +# See https://docs.github.com/en/communities/using-templates-to-encourage-useful-issues-and-pull-requests/syntax-for-issue-forms +# and https://docs.github.com/en/communities/using-templates-to-encourage-useful-issues-and-pull-requests/syntax-for-githubs-form-schema +name: Feature Request +description: Something new needs to happen with Remorph +title: "[FEATURE]: " +labels: ["enhancement", "needs-triage"] +# assignees: +# - remorph-write +body: + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the feature request you're willing to submit + options: + - label: I have searched the existing issues + required: true + - type: dropdown + id: category + attributes: + label: Category of feature request + description: Please select the category that best describes the feature you are requesting for. + options: + - Transpile + - Reconcile + - Other + validations: + required: true + - type: textarea + attributes: + label: Problem statement + description: A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + validations: + required: true + - type: textarea + attributes: + label: Proposed Solution + description: A clear and concise description of what you want to happen. + validations: + required: true + - type: textarea + attributes: + label: Additional Context + description: Add any other context, references or screenshots about the feature request here. + validations: + required: false diff --git a/.github/dependabot.yml b/.github/dependabot.yml index f1747992ce..64f0b80f32 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,4 +3,25 @@ updates: - package-ecosystem: "pip" directory: "/" schedule: - interval: "daily" \ No newline at end of file + interval: "daily" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + - package-ecosystem: "maven" + directory: "/" + schedule: + interval: "daily" + ignore: + # Ignore updates for Databricks Connect: the version in use needs to match the testing infrastructure. + - dependency-name: "com.databricks:databricks-connect" + # Ignore non-patch updates for Scala: we manually manage the Scala version. 
+ - dependency-name: "org.scala-lang:scala-library" + update-types: + # (Scala 2 patch releases are binary compatible, so they're the only type allowed.) + - "version-update:semver-minor" + - "version-update:semver-major" + # Mockito from 5.x requires JDK 11, but we are using JDK 8. + - dependency-name: "org.mockito:mockito-core" + versions: + - ">=5.0.0" diff --git a/.github/scripts/setup_spark_remote.sh b/.github/scripts/setup_spark_remote.sh new file mode 100755 index 0000000000..490e2248b2 --- /dev/null +++ b/.github/scripts/setup_spark_remote.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +set -xve + +mkdir -p "$HOME"/spark +cd "$HOME"/spark || exit 1 + +version=$(wget -O - https://dlcdn.apache.org/spark/ | grep 'href="spark' | grep -v 'preview' | sed 's::\n:g' | sed -n 's/.*>//p' | tr -d spark- | tr -d / | sort -r --version-sort | head -1) +if [ -z "$version" ]; then + echo "Failed to extract Spark version" + exit 1 +fi + +spark=spark-${version}-bin-hadoop3 +spark_connect="spark-connect_2.12" + +mkdir -p "${spark}" + + +SERVER_SCRIPT=$HOME/spark/${spark}/sbin/start-connect-server.sh + +## check the spark version already exist ,if not download the respective version +if [ -f "${SERVER_SCRIPT}" ];then + echo "Spark Version already exists" +else + if [ -f "${spark}.tgz" ];then + echo "${spark}.tgz already exists" + else + wget "https://dlcdn.apache.org/spark/spark-${version}/${spark}.tgz" + fi + tar -xvf "${spark}.tgz" +fi + +cd "${spark}" || exit 1 +## check spark remote is running,if not start the spark remote +result=$(${SERVER_SCRIPT} --packages org.apache.spark:${spark_connect}:"${version}" > "$HOME"/spark/log.out; echo $?) + +if [ "$result" -ne 0 ]; then + count=$(tail "${HOME}"/spark/log.out | grep -c "SparkConnectServer running as process") + if [ "${count}" == "0" ]; then + echo "Failed to start the server" + exit 1 + fi + # Wait for the server to start by pinging localhost:4040 + echo "Waiting for the server to start..." + for i in {1..30}; do + if nc -z localhost 4040; then + echo "Server is up and running" + break + fi + echo "Server not yet available, retrying in 5 seconds..." + sleep 5 + done + + if ! 
nc -z localhost 4040; then + echo "Failed to start the server within the expected time" + exit 1 + fi +fi +echo "Started the Server" diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index d7d41c6f9f..4c5a4d0097 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -15,17 +15,17 @@ on: - main env: - HATCH_VERSION: 1.7.0 + HATCH_VERSION: 1.9.1 jobs: - ci: + test-python: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: cache: 'pip' cache-dependency-path: '**/pyproject.toml' @@ -34,17 +34,28 @@ jobs: - name: Install hatch run: pip install hatch==$HATCH_VERSION + - name: Setup Spark Remote + run: | + chmod +x $GITHUB_WORKSPACE/.github/scripts/setup_spark_remote.sh + $GITHUB_WORKSPACE/.github/scripts/setup_spark_remote.sh + - name: Run unit tests - run: hatch run unit:test + run: hatch run test + + - name: Publish test coverage + uses: codecov/codecov-action@v5 + with: + codecov_yml_path: codecov.yml + token: ${{ secrets.CODECOV_TOKEN }} - fmt: + fmt-python: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: cache: 'pip' cache-dependency-path: '**/pyproject.toml' @@ -53,5 +64,220 @@ jobs: - name: Install hatch run: pip install hatch==$HATCH_VERSION - - name: Verify linting - run: hatch run lint:verify + - name: Reformat code + run: make fmt-python + + - name: Fail on differences + run: | + # Exit with status code 1 if there are differences (i.e. unformatted files) + git diff --exit-code + + fmt-scala: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: corretto + java-version: 11 + + # GitHub Action team seems not to have cycles to make the cache work properly, hence this hack + # See https://github.com/actions/setup-java/issues/255 + # See https://github.com/actions/setup-java/issues/577 + - name: Cache Maven + uses: actions/cache@v4 + with: + path: ~/.m2 + key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }} + + - name: Reformat code + run: make fmt-scala + + - name: Fail on differences + run: | + # Exit with status code 1 if there are differences (i.e. unformatted files) + git diff --exit-code + + python-no-pylint-disable: + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' && (github.event.action == 'opened' || github.event.action == 'synchronize') + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Verify no lint disabled in the new code + run: | + git fetch origin $GITHUB_BASE_REF:$GITHUB_BASE_REF + git diff $GITHUB_BASE_REF...$(git branch --show-current) >> diff_data.txt + python tests/unit/no_cheat.py diff_data.txt >> cheats.txt + COUNT=$(cat cheats.txt | wc -c) + if [ ${COUNT} -gt 1 ]; then + cat cheats.txt + exit 1 + fi + + test-scala: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: corretto + java-version: 11 + + # GitHub Action team seems not to have cycles to make the cache work properly, hence this hack. 
+ # See https://github.com/actions/setup-java/issues/255 + # See https://github.com/actions/setup-java/issues/577 + - name: Cache Maven + uses: actions/cache@v4 + with: + path: ~/.m2 + key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }} + + - name: Install Python + uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Install hatch + run: pip install hatch==$HATCH_VERSION + + - name: Initialize Python virtual environment for StandardInputPythonSubprocess + run: make dev + + - name: Run Unit Tests with Maven + run: mvn --update-snapshots scoverage:report --file pom.xml --fail-at-end + + - name: Upload remorph-core jars as Artifacts + uses: actions/upload-artifact@v4 + with: + name: remorph-core-jars + path: ~/.m2/repository/com/databricks/labs/remorph* + + - name: Publish JUnit report + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + files: | + **/TEST-com.databricks.labs.remorph.coverage*.xml + comment_title: 'Coverage tests results' + check_name: 'Coverage Tests Results' + fail_on: 'nothing' + continue-on-error: true + + - name: Publish test coverage + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + + coverage-tests-with-make: + runs-on: ubuntu-latest + env: + INPUT_DIR_PARENT: . + OUTPUT_DIR: ./test-reports + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install Python + uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: 3.10.x + + - name: Install hatch + run: pip install hatch==$HATCH_VERSION + + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: corretto + java-version: 11 + + # GitHub Action team seems not to have cycles to make the cache work properly, hence this hack. + # See https://github.com/actions/setup-java/issues/255 + # See https://github.com/actions/setup-java/issues/577 + - name: Cache Maven + uses: actions/cache@v4 + with: + path: ~/.m2 + key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }} + + - name: Install Python + uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Install hatch + run: pip install hatch==$HATCH_VERSION + + - name: Install Databricks CLI + uses: databricks/setup-cli@main + + - name: Initialize Python virtual environment for StandardInputPythonSubprocess + run: make dev + + - name: Create dummy test file + run: | + mkdir $INPUT_DIR_PARENT/snowflake + mkdir $INPUT_DIR_PARENT/tsql + echo "SELECT * FROM t;" >> $INPUT_DIR_PARENT/snowflake/dummy_test.sql + echo "SELECT * FROM t;" >> $INPUT_DIR_PARENT/tsql/dummy_test.sql + shell: bash + + - name: Dry run coverage tests with make + run: make dialect_coverage_report + env: # this is a temporary hack + DATABRICKS_HOST: any + DATABRICKS_TOKEN: any + + - name: Verify report file + if: ${{ hashFiles('./test-reports/') == '' }} + run: | + echo "No file produced in tests-reports/" + exit 1 + + + antlr-grammar-linting: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: corretto + java-version: 11 + + # GitHub Action team seems not to have cycles to make the cache work properly, hence this hack. 
+ # See https://github.com/actions/setup-java/issues/255 + # See https://github.com/actions/setup-java/issues/577 + - name: Cache Maven + uses: actions/cache@v4 + with: + path: ~/.m2 + key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }} + + - name: Run Lint Test with Maven + run: mvn compile -DskipTests --update-snapshots -B exec:java -pl linter --file pom.xml -Dexec.args="-i core/src/main/antlr4 -o .venv/linter/grammar -c true" + continue-on-error: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000000..90f8e08977 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,47 @@ +name: Release + +on: + push: + tags: + - 'v*' + +jobs: + publish: + runs-on: ubuntu-latest + environment: release + permissions: + # Used to authenticate to PyPI via OIDC and sign the release's artifacts with sigstore-python. + id-token: write + # Used to attach signing artifacts to the published release. + contents: write + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Build wheels + run: | + pip install hatch==1.7.0 + hatch build + + - name: Draft release + uses: softprops/action-gh-release@v2 + with: + files: | + dist/databricks_*.whl + dist/databricks_*.tar.gz + + - uses: pypa/gh-action-pypi-publish@release/v1 + name: Publish package distributions to PyPI + + - name: Sign artifacts with Sigstore + uses: sigstore/gh-action-sigstore-python@v3.0.0 + with: + inputs: | + dist/databricks_*.whl + dist/databricks_*.tar.gz + release-signing-artifacts: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..411b2cf462 --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +.venv +.DS_Store +*.pyc +__pycache__ +dist +.idea +/htmlcov/ +*.iml +target/ +.coverage* +coverage.* +*.iws +/core/gen/ +/antlrlinter/gen/ +*.tokens +spark-warehouse/ +remorph_transpile/ +/linter/gen/ +/linter/src/main/antlr4/library/gen/ +.databricks-login.json +/core/src/main/antlr4/com/databricks/labs/remorph/parsers/*/gen/ diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000000..3c0c32ce62 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,20 @@ +align = none +align.openParenDefnSite = false +align.openParenCallSite = false +align.tokens = [] +importSelectors = "singleLine" +optIn = { + configStyleArguments = false +} + +danglingParentheses.preset = false +docstrings.style = Asterisk +docstrings.wrap = false +maxColumn = 120 +runner.dialect = scala212 +fileOverride { + "glob:**/src/**/scala-2.13/**.scala" { + runner.dialect = scala213 + } +} +version = 3.8.0 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..23d4d5ee30 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,335 @@ +# Version changelog + +## 0.9.0 + +* Added support for format_datetime function in presto to Databricks ([#1250](https://github.com/databrickslabs/remorph/issues/1250)). A new `format_datetime` function has been added to the `Parser` class in the `presto.py` file to provide support for formatting datetime values in Presto on Databricks. This function utilizes the `DateFormat.from_arg_list` method from the `local_expression` module to format datetime values according to a specified format string. 
To ensure compatibility and consistency between Presto and Databricks, a new test file `test_format_datetime_1.sql` has been added, containing SQL queries that demonstrate the usage of the `format_datetime` function in Presto and its equivalent in Databricks, `DATE_FORMAT`. This standalone change adds new functionality without modifying any existing code. +* Added support for SnowFlake `SUBSTR` ([#1238](https://github.com/databrickslabs/remorph/issues/1238)). This commit enhances the library's SnowFlake support by adding the `SUBSTR` function, which was previously unsupported and existed only as an alternative to `SUBSTRING`. The project now fully supports both functions, and the `SUBSTRING` function can be used interchangeably with `SUBSTR` via the new `withConversionStrategy(SynonymOf("SUBSTR"))` method. Additionally, this commit supersedes a previous pull request that lacked a GPG signature and includes a test for the `SUBSTR` function. The `ARRAY_SLICE` function has also been updated to match SnowFlake's behavior, and the project now supports a more comprehensive list of SQL functions with their corresponding arity. +* Added support for json_size function in presto ([#1236](https://github.com/databrickslabs/remorph/issues/1236)). A new `json_size` function for Presto has been added, which determines the size of a JSON object or array and returns an integer. Two new methods, `_build_json_size` and `get_json_object`, have been implemented to handle JSON objects and arrays differently, and the Parser and Tokenizer classes of the Presto class have been updated to include the new json_size function. An alternative implementation for Databricks using SQL functions is provided, and a test case is added to cover a fixed `is not null` error for json_extract in the Databricks generator. Additionally, a new test file for Presto has been added to test the functionality of the `json_extract` function in Presto, and a new method `GetJsonObject` is introduced to extract a JSON object from a given path. The `json_extract` function has also been updated to extract the value associated with a specified key from JSON data in both Presto and Databricks. +* Enclosed subqueries in parenthesis ([#1232](https://github.com/databrickslabs/remorph/issues/1232)). This PR introduces changes to the ExpressionGenerator and LogicalPlanGenerator classes to ensure that subqueries are correctly enclosed in parentheses during code generation. Previously, subqueries were not always enclosed in parentheses, leading to incorrect code. This issue has been addressed by enclosing subqueries in parentheses in the `in` and `scalarSubquery` methods, and by adding new match cases for `ir.Filter` in the `LogicalPlanGenerator` class. The changes also take care to avoid doubling enclosing parentheses in the `.. IN(SELECT...)` pattern. New methods have not been added, and existing functionality has been modified to ensure that subqueries are correctly enclosed in parentheses, leading to the generation of correct SQL code. Test cases have been included in a separate PR. These changes improve the correctness of the generated code, avoiding issues such as `SELECT * FROM SELECT * FROM t WHERE a > `a` WHERE a > 'b'` and ensuring that the generated code includes parentheses around subqueries. +* Fixed serialization of MultipleErrors ([#1177](https://github.com/databrickslabs/remorph/issues/1177)). 
In the latest release, the encoding of errors in the `com.databricks.labs.remorph.coverage` package has been improved with an update to the `encoders.scala` file. The change involves a fix for serializing `MultipleErrors` instances using the `asJson` method on each error instead of just the message. This modification ensures that all relevant information about each error is included in the encoded output, improving the accuracy of serialization for `MultipleErrors` class. Users who handle multiple errors and require precise serialization representation will benefit from this enhancement, as it guarantees comprehensive information encoding for each error instance. +* Fixed presto strpos and array_average functions ([#1196](https://github.com/databrickslabs/remorph/issues/1196)). This PR introduces new classes `Locate` and `NamedStruct` in the `local_expression.py` file to handle the `STRPOS` and `ARRAY_AVERAGE` functions in a Databricks environment, ensuring compatibility with Presto SQL. The `STRPOS` function, used to locate the position of a substring within a string, now uses the `Locate` class and emits a warning regarding differences in implementation between Presto and Databricks SQL. A new method `_build_array_average` has been added to handle the `ARRAY_AVERAGE` function in Databricks, which calculates the average of an array, accommodating nulls, integers, and doubles. Two SQL test cases have been added to demonstrate the use of the `ARRAY_AVERAGE` function with arrays containing integers and doubles. These changes promote compatibility and consistent behavior between Presto and Databricks when dealing with `STRPOS` and `ARRAY_AVERAGE` functions, enhancing the ability to migrate between the systems smoothly. +* Handled presto Unnest cross join to Databricks lateral view ([#1209](https://github.com/databrickslabs/remorph/issues/1209)). This release introduces new features and updates for handling Presto UNNEST cross joins in Databricks, utilizing the lateral view feature. New methods have been added to improve efficiency and robustness when handling UNNEST cross joins. Additionally, new test cases have been implemented for Presto and Databricks to ensure compatibility and consistency between the two systems in handling UNNEST cross joins, array construction and flattening, and parsing JSON data. Some limitations and issues remain, which will be addressed in future work. The acceptance tests have also been updated, with certain tests now expected to pass, while others may still fail. This release aims to improve the functionality and compatibility of Presto and Databricks when handling UNNEST cross joins and JSON data. +* Implemented remaining TSQL set operations ([#1227](https://github.com/databrickslabs/remorph/issues/1227)). This pull request enhances the TSql parser by adding support for parsing and converting the set operations `UNION [ALL]`, `EXCEPT`, and `INTERSECT` to the Intermediate Representation (IR). Initially, the grammar recognized these operations, but they were not being converted to the IR. This change resolves issues [#1126](https://github.com/databrickslabs/remorph/issues/1126) and [#1102](https://github.com/databrickslabs/remorph/issues/1102) and includes new unit, transpiler, and functional tests, ensuring the correct behavior of these set operations, including precedence rules. 
The commit also introduces a new test file, `union-all.sql`, demonstrating the correct handling of simple `UNION ALL` operations, ensuring consistent output across TSQL and Databricks SQL platforms. +* Supported multiple columns in order by clause in for ARRAYAGG ([#1228](https://github.com/databrickslabs/remorph/issues/1228)). This commit enhances the ARRAYAGG and LISTAGG functions by adding support for multiple columns in the order by clause and sorting in both ascending and descending order. A new method, sortArray, has been introduced to handle multiple sort orders. The changes also improve the functionality of the ARRAYAGG function in the Snowflake dialect by supporting multiple columns in the ORDER BY clause, with an optional DESC keyword for each column. The `WithinGroupParams` dataclass has been updated in the local expression module to include a list of tuples for the order columns and their sorting direction. These changes provide increased flexibility and control over the output of the ARRAYAGG and LISTAGG functions +* Added TSQL parser support for `(LHS) UNION RHS` queries ([#1211](https://github.com/databrickslabs/remorph/issues/1211)). In this release, we have implemented support for a new form of UNION in the TSQL parser, specifically for queries formatted as `(SELECT a from b) UNION [ALL] SELECT x from y`. This allows the union of two SELECT queries with an optional ALL keyword to include duplicate rows. The implementation includes a new case statement in the `TSqlRelationBuilder` class that handles this form of UNION, creating a `SetOperation` object with the left-hand side and right-hand side of the union, and an `is_all` flag based on the presence of the ALL keyword. Additionally, we have added support for parsing right-associative UNION clauses in TSQL queries, enhancing the flexibility and expressiveness of the TSQL parser for more complex and nuanced queries. The commit also includes new test cases to verify the correct translation of TSQL set operations to Databricks SQL, resolving issue [#1127](https://github.com/databrickslabs/remorph/issues/1127). This enhancement allows for more accurate parsing of TSQL queries that use the UNION operator in various formats. +* Added support for inline columns in CTEs ([#1184](https://github.com/databrickslabs/remorph/issues/1184)). In this release, we have added support for inline columns in Common Table Expressions (CTEs) in Snowflake across various components of our open-source library. This includes updates to the AST (Abstract Syntax Tree) for better TSQL translation and the introduction of the new case class `KnownInterval` for handling intervals. We have also implemented a new method, `DealiasInlineColumnExpressions`, in the `SnowflakePlanParser` class to parse inline columns in CTEs and modify the class constructor to include this new method. Additionally, a new private case class `InlineColumnExpression` has been introduced to allow for more efficient processing of Snowflake CTEs. The SnowflakeToDatabricksTranspiler has also been updated to support inline columns in CTEs, as demonstrated by a new test case. These changes improve compatibility, precision, and usability of the codebase, providing a better overall experience for software engineers working with CTEs in Snowflake. +* Implemented AST for positional column identifiers ([#1181](https://github.com/databrickslabs/remorph/issues/1181)). 
The recent change introduces an Abstract Syntax Tree (AST) for positional column identifiers in the Snowflake project, specifically in the `ExpressionGenerator` class. The new `NameOrPosition` type represents a column identifier, either by name or position. The `Id` and `Position` classes inherit from `NameOrPosition`, and the `nameOrPosition` method has been added to check and return the appropriate SQL representation. However, due to Databricks' lack of positional column identifier support, the generator side does not yet support this feature. This means that the schema of the table is required to properly translate queries involving positional column identifiers. This enhancement increases the system's flexibility in handling Snowflake's query structures, with the potential for more comprehensive generator-side support in the future. +* Implemented GROUP BY ALL ([#1180](https://github.com/databrickslabs/remorph/issues/1180)). The `GROUP BY ALL` clause is now supported in the LogicalPlanGenerator class of the remorph project, with the addition of a new case to handle the GroupByAll type and updated implementation for the Pivot type. A new case object called `GroupByAll` has been added to the relations.scala file's sealed trait "GroupType". A new test case has been implemented in the SnowflakeToDatabricksTranspilerTest class to check the correct transpilation of the `GROUP BY ALL` clause from Snowflake SQL syntax to Databricks SQL syntax. These changes allow for more flexibility and control in grouping operations and enable the implementation of specific functionality for the GROUP BY ALL clause in Snowflake, improving compatibility with Snowflake SQL syntax. + +Dependency updates: + + * Bump codecov/codecov-action from 4 to 5 ([#1210](https://github.com/databrickslabs/remorph/pull/1210)). + * Bump sqlglot from 25.30.0 to 25.32.1 ([#1254](https://github.com/databrickslabs/remorph/pull/1254)). + +## 0.8.0 + +* Added IR for stored procedures ([#1161](https://github.com/databrickslabs/remorph/issues/1161)). In this release, we have made significant enhancements to the project by adding support for stored procedures. We have introduced a new `CreateVariable` case class to manage variable creation within the intermediate representation (IR), and removed the `SetVariable` case class as it is now redundant. A new `CaseStatement` class has been added to represent SQL case statements with value match, and a `CompoundStatement` class has been implemented to enable encapsulation of a sequence of logical plans within a single compound statement. The `DeclareCondition`, `DeclareContinueHandler`, and `DeclareExitHandler` case classes have been introduced to handle conditional logic and exit handlers in stored procedures. New classes `DeclareVariable`, `ElseIf`, `ForStatement`, `If`, `Iterate`, `Leave`, `Loop`, `RepeatUntil`, `Return`, `SetVariable`, and `Signal` have been added to the project to provide more comprehensive support for procedural language features and control flow management in stored procedures. We have also included SnowflakeCommandBuilder support for stored procedures and updated the `visitExecuteTask` method to handle stored procedure calls using the `SetVariable` method. +* Added Variant Support ([#998](https://github.com/databrickslabs/remorph/issues/998)). In this commit, support for the Variant datatype has been added to the create table functionality, enhancing the system's compatibility with Snowflake's datatypes. 
A new VariantType has been introduced, which allows for more comprehensive handling of data during create table operations. Additionally, a `remarks VARIANT` line is added in the CREATE TABLE statement and the corresponding spec test has been updated. The Variant datatype is a flexible datatype that can store different types of data, such as arrays, objects, and strings, offering increased functionality for users working with variant data. Furthermore, this change will enable the use of the Variant datatype in Snowflake tables and improves the data modeling capabilities of the system. +* Added `PySpark` generator ([#1026](https://github.com/databrickslabs/remorph/issues/1026)). The engineering team has developed a new `PySpark` generator for the `com.databricks.labs.remorph.generators` package. This addition introduces a new parameter, `logical`, of type `Generator[ir.LogicalPlan, String]`, in the `SQLGenerator` for SQL queries. A new abstract class `BasePythonGenerator` has been added, which extends the `Generator` class and generates Python code. A `ExpressionGenerator` class has also been added, which extends `BasePythonGenerator` and is responsible for generating Python code for `ir.Expression` objects. A new `LogicalPlanGenerator` class has been added, which extends `BasePythonGenerator` and is responsible for generating Python code for a given `ir.LogicalPlan`. A new `StatementGenerator` class has been implemented, which converts `Statement` objects into Python code. A new Python-generating class, `PythonGenerator`, has been added, which includes the implementation of an abstract syntax tree (AST) for Python in Scala. This AST includes classes for various Python language constructs. Additionally, new implicit classes for `PythonInterpolator`, `PythonOps`, and `PythonSeqOps` have been added to allow for the creation of PySpark code using the Remorph framework. The `AndOrToBitwise` rule has been implemented to convert `And` and `Or` expressions to their bitwise equivalents. The `DotToFCol` rule has been implemented to transform code that references columns using dot notation in a DataFrame to use the `col` function with a string literal of the column name instead. A new `PySparkStatements` object and `PySparkExpressions` class have been added, which provide functionality for transforming expressions in a data processing pipeline to PySpark equivalents. The `SnowflakeToPySparkTranspiler` class has been added to transpile Snowflake queries to PySpark code. A new `PySpark` generator has been added to the `Transpiler` class, which is implemented as an instance of the `SqlGenerator` class. This change enhances the `Transpiler` class with a new `PySpark` generator and improves serialization efficiency. +* Added `debug-bundle` command for folder-to-folder translation ([#1045](https://github.com/databrickslabs/remorph/issues/1045)). In this release, we have introduced a `debug-bundle` command to the remorph project's CLI, specifically added to the `proxy_command` function, which already includes `debug-script`, `debug-me`, and `debug-coverage` commands. This new command enhances the tool's debugging capabilities, allowing developers to generate a bundle of translated queries for folder-to-folder translation tasks. The `debug-bundle` command accepts three flags: `dialect`, `src`, and `dst`, specifying the SQL dialect, source directory, and destination directory, respectively. 
Furthermore, the update includes refactoring the `FileSetGenerator` class in the `orchestration` package of the `com.databricks.labs.remorph.generators` package, adding a `debug-bundle` command to the `Main` object, and updating the `FileQueryHistoryProvider` method in the `ApplicationContext` trait. These improvements focus on providing a convenient way to convert folder-based SQL scripts to other formats like SQL and PySpark, enhancing the translation capabilities of the project. +* Added `ruff` Python formatter proxy ([#1038](https://github.com/databrickslabs/remorph/issues/1038)). In this release, we have added support for the `ruff` Python formatter in our project's continuous integration and development workflow. We have also introduced a new `FORMAT` stage in the `WorkflowStage` object in the `Result` Scala object to include formatting as a separate step in the workflow. A new `RuffFormatter` class has been added to format Python code using the `ruff` tool, and a `StandardInputPythonSubprocess` class has been included to run a Python subprocess and capture its output and errors. Additionally, we have added a proxy for the `ruff` formatter to the SnowflakeToPySparkTranspilerTest for Scala to improve the readability of the transpiled Python code generated by the SnowflakeToPySparkTranspiler. Lastly, we have introduced a new `ruff` formatter proxy in the test code for the transpiler library to enforce format and style conventions in Python code. These changes aim to improve the development and testing experience for the project and ensure that the code follows the desired formatting and style standards. +* Added baseline for translating workflows ([#1042](https://github.com/databrickslabs/remorph/issues/1042)). In this release, several new features have been added to the open-source library to improve the translation of workflows. A new dependency for the Jackson YAML data format library, version 2.14.0, has been added to the pom.xml file to enable processing YAML files and converting them to Java objects. A new `FileSet` class has been introduced, which provides an in-memory data structure to manage a set of files, allowing users to add, retrieve, and remove files by name and persist the contents of the files to the file system. A new `FileSetGenerator` class has been added that generates a `FileSet` object from a `JobNode` object, enabling the translation of workflows by generating all necessary files for a workspace. A new `DefineJob` class has been developed to define a new rule for processing `JobNode` objects in the Remorph system, converting instances of `SuccessPy` and `SuccessSQL` into `PythonNotebookTask` and `SqlNotebookTask` objects, respectively. Additionally, various new classes, such as `GenerateBundleFile`, `QueryHistoryToQueryNodes`, `ReformatCode`, `TryGeneratePythonNotebook`, `TryGenerateSQL`, `TrySummarizeFailures`, `InformationFile`, `SuccessPy`, `SuccessSQL`, `FailedQuery`, `Migration`, `PartialQuery`, `QueryPlan`, `RawMigration`, `Comment`, and `PlanComment`, have been introduced to provide a more comprehensive and nuanced job orchestration framework. The `Library` case class has been updated to better separate concerns between library configuration and code assets. These changes address issue [#1042](https://github.com/databrickslabs/remorph/issues/1042) and provide a more robust and flexible workflow translation solution. 
+* Added correct generation of `databricks.yml` for `QueryHistory` ([#1044](https://github.com/databrickslabs/remorph/issues/1044)). The FileSet class in the FileSet.scala file has been updated to include a new method that correctly generates the `databricks.yml` file for the `QueryHistory` feature. This file is used for orchestrating cross-compiled queries, creating three files in total - two SQL notebooks with translated and formatted queries and a `databricks.yml` file to define an asset bundle for the queries. The new method in the FileSet class writes the content to the file using the `Files.write` method from the `java.nio.file` package instead of the previously used `PrintWriter`. The FileSetGenerator class has been updated to include the new `databricks.yml` file generation, and new rules and methods have been added to improve the accuracy and consistency of schema definitions in the generated orchestration files. Additionally, the DefineJob and DefineSchemas classes have been introduced to simplify the orchestration generation process. +* Added documentation around Transformation ([#1043](https://github.com/databrickslabs/remorph/issues/1043)). In this release, the Transformation class in our open-source library has been enhanced with detailed documentation, type parameters, and new methods. The class represents a stateful computation that produces an output of type Out while managing a state of type State. The new methods include map and flatMap for modifying the output and chaining transformations, as well as run and runAndDiscardState for executing the computation with a given initial state and producing a Result containing the output and state. Additionally, we have introduced a new trait called TransformationConstructors that provides constructors for successful transformations, error transformations, lifted results, state retrieval, replacement, and updates. The CodeGenerator trait in our code generation library has also been updated with several new methods for more control and flexibility in the code generation process. These include commas and spaces for formatting output, updateGenCtx for updating the GeneratorContext, nest and unnest for indentation, withIndentedBlock for producing nested blocks of code, and withGenCtx for creating transformations that use the current GeneratorContext. +* Added tests for Snow ARRAY_REMOVE function ([#979](https://github.com/databrickslabs/remorph/issues/979)). In this release, we have added tests for the Snowflake ARRAY_REMOVE function in the SnowflakeToDatabricksTranspilerTest. The tests, currently ignored, demonstrate the usage of the ARRAY_REMOVE function with different data types, such as integers and doubles. A TODO comment is included for a test case involving VARCHAR casting, to be enabled once the necessary casting functionality is implemented. This update enhances the library's capabilities and ensures that the ARRAY_REMOVE function can handle a variety of data types. Software engineers can refer to these tests to understand the usage of the ARRAY_REMOVE function in the transpiler and the planned casting functionality. +* Avoid non local return ([#1052](https://github.com/databrickslabs/remorph/issues/1052)). In this release, the `render` method of the `generators` package object in the `com.databricks.labs.remorph` package has been updated to avoid using non-local returns and follow recommended coding practices. 
Instead of returning early from the method, it now uses `Option` to track failures and a `try-catch` block to handle exceptions. In cases of exception during string concatenation, the method sets the `failureOpt` variable to `Some(lift(KoResult(WorkflowStage.GENERATE, UncaughtException(e))))`. Additionally, the test file "CodeInterpolatorSpec.scala" has been modified to fix an issue with exception handling. In the updated code, new variables for each argument are introduced, and the problematic code is placed within an interpolated string, allowing for proper exception handling. This release enhances the robustness and reliability of the code interpolator and ensures that the method follows recommended coding practices. +* Collect errors in `Phase` ([#1046](https://github.com/databrickslabs/remorph/issues/1046)). The open-source library Remorph has received significant updates, focusing on enhancing error collection and simplifying the Transformation class. The changes include a new method `recordError` in the abstract Phase trait and its concrete implementations for collecting errors during each phase. The Transformation class has been simplified by removing the unused Phase parameter, while the Generator, CodeGenerator, and FileSetGenerator have been specialized to use Transformation without the Phase parameter. The TryGeneratePythonNotebook, TryGenerateSQL, CodeInterpolator, and TBASeqOps classes have been updated for a more concise and focused state. The imports have been streamlined, and the PySparkGenerator, SQLGenerator, and PlanParser have been modified to remove the unused Phase type parameter. A new test file, TransformationTest, has been added to check the error collection functionality in the Transformation class. Overall, these enhancements improve the reliability, readability, and maintainability of the Remorph library. +* Correctly generate `F.fn_name` for builtin PySpark functions ([#1037](https://github.com/databrickslabs/remorph/issues/1037)). This commit introduces changes to the generation of `F.fn_name` for builtin PySpark functions in the PySparkExpressions.scala file, specifically for PySpark's builtin functions (`fn`). It includes a new case to handle these functions by converting them to their lowercase equivalent in Python using `Locale.getDefault`. Additionally, changes have been made to handle window specifications more accurately, such as using `ImportClassSideEffect` with `windowSpec` and generating and applying a window function (`fn`) over it. The `LAST_VALUE` function has been modified to `LAST` in the SnowflakeToDatabricksTranspilerTest.scala file, and new methods such as `First`, `Last`, `Lag`, `Lead`, and `NthValue` have been added to the SnowflakeCallMapper class. These changes improve the accuracy, flexibility, and compatibility of PySpark when working with built-in functions and window specifications, making the codebase more readable, maintainable, and efficient. +* Create Command Extended ([#1033](https://github.com/databrickslabs/remorph/issues/1033)). In this release, the open-source library has been updated with several new features related to table management and SQL code generation. A new method `replaceTable` has been added to the `LogicalPlanGenerator` class, which generates SQL code for a `ReplaceTableCommand` IR node and replaces an existing table with the same name if it already exists. 
Additionally, support has been added for generating SQL code for an `IdentityConstraint` IR node, which specifies whether a column is an auto-incrementing identity column. The `CREATE TABLE` statement has been updated to include the `AUTOINCREMENT` and `REPLACE` constraints, and a new `IdentityConstraint` case class has been introduced to extend the capabilities of the `UnnamedConstraint` class. The `TSqlDDLBuilder` class has also been updated to handle the `IDENTITY` keyword more effectively. A new command implementation with `AUTOINCREMENT` and `REPLACE` constraints has been added, and a new SQL script has been included in the functional tests for testing CREATE DDL statements with identity columns. Finally, the SQL transpiler has been updated to support the `CREATE OR REPLACE PROCEDURE` syntax, providing more flexibility and convenience for users working with stored procedures in their SQL code. These updates aim to improve the functionality and ease of use of the open-source library for software engineers working with SQL code and table management. +* Don't draft automated releases ([#995](https://github.com/databrickslabs/remorph/issues/995)). In this release, we have made a modification to the release.yml file in the .github/workflows directory by removing the "draft: true" line. This change removes the creation of draft releases in the automated release process, simplifying it and making it more straightforward for users to access new versions of the software. The job section of the release.yml file now only includes the `release` job, with the "release-signing-artifacts: true" still enabled, ensuring that the artifacts are signed. This improvement enhances the overall release process, making it more efficient and user-friendly. +* Enhance the Snow ARRAY_SORT function support ([#994](https://github.com/databrickslabs/remorph/issues/994)). With this release, the Snowflake ARRAY_SORT function now supports Boolean literals as parameters, improving its functionality. The changes include validating Boolean parameters in SnowflakeCallMapper.scala, throwing a TranspileException for unsupported arguments, and simplifying the IR using the DBSQL SORT_ARRAY function. Additionally, new test cases have been added to SnowflakeCallMapperSpec for the ARRAY_SORT and ARRAY_SLICE functions. The SnowflakeToDatabricksTranspilerTest class has also been updated with new test cases that cover the enhanced ARRAY_SORT function's various overloads and combinations of Boolean literals, NULLs, and a custom sorting function. This ensures that invalid usage is caught during transpilation, providing better error handling and supporting more use cases. +* Ensure that unparsable text is not lost in the generated output ([#1012](https://github.com/databrickslabs/remorph/issues/1012)). In this release, we have implemented an enhancement to the error handling strategy in the ANTLR-generated parsers for SQL. This change records where parsing failed and gathers un-parsable input, preserving them as custom error nodes in the ParseTree at strategic points. The new custom error strategy allows visitors for higher level rules such as `sqlCommand` in Snowflake and `sqlClauses` in TSQL to check for an error node in the children and generate an Ir node representing the un-parsed text. Additionally, new methods have been introduced to handle error recovery, find the highest context in the tree for the particular parser, and recursively find the context. 
A separate improvement is planned to ensure the PLanParser no longer stops when syntax errors are discovered, allowing safe traversal of the ParseTree. This feature is targeted towards software engineers working with SQL parsing and aims to improve error handling and recovery. +* Fetch table definitions for TSQL ([#986](https://github.com/databrickslabs/remorph/issues/986)). This pull request introduces a new `TableDefinition` case class that encapsulates metadata properties for tables in TSQL, such as catalog name, schema name, table name, location, table format, view definition, columns, table size, and comments. A `TSqlTableDefinitions` class has been added with methods to retrieve table definitions, all schemas, and all catalogs from TSQL. The `SnowflakeTypeBuilder` is updated to parse data types from TSQL. The `SnowflakeTableDefinitions` class has been refactored to use the new `TableDefinition` case class and retrieve table definitions more efficiently. The changes also include adding two new test cases to verify the correct retrieval of table definitions and catalogs for TSQL. +* Fixed handling of projected expressions in `TreeNode` ([#1159](https://github.com/databrickslabs/remorph/issues/1159)). In this release, we have addressed the handling of projected expressions in the `TreeNode` class, resolving issues [#1072](https://github.com/databrickslabs/remorph/issues/1072) and [#1159](https://github.com/databrickslabs/remorph/issues/1159). The `expressions` method in the `Plan` abstract class has been modified to include the `final` keyword, restricting overriding in subclasses. This method now returns all expressions present in a query from the current plan operator and its descendants. Additionally, we have introduced a new private method, `seqToExpressions`, used for recursively finding all expressions from a given sequence. The `Project` class, representing a relational algebra operation that projects a subset of columns in a table, now utilizes a new `columns` parameter instead of `expressions`. Similar changes have been applied to other classes extending `UnaryNode`, such as `Join`, `Deduplicate`, and `Hint`. The `values` parameter of the `Values` class has also been altered to accurately represent input values. A new test class, `JoinTest`, has been introduced to verify the correct propagation of expressions in join nodes, ensuring intended data transformations. +* Handling any_keys_match from presto ([#1048](https://github.com/databrickslabs/remorph/issues/1048)). In this commit, we have added support for the `any_keys_match` Presto function in Databricks by implementing it using existing Databricks functions. The `any_keys_match` function checks if any keys in a map match a given condition. Specifically, we have introduced two new classes, `MapKeys` and `ArrayExists`, which allow us to extract keys from the input map and check if any of the keys satisfy the given condition using the `exists` function. This is accomplished by renaming `exists` to `array_exists` to better reflect its purpose. Additionally, we have provided a Databricks SQL query that mimics the behavior of the `any_keys_match` function in Presto and added tests to ensure that it works as expected. These changes enable users to perform equivalent operations with a consistent syntax in Databricks and Presto. +* Improve IR for job nodes ([#1041](https://github.com/databrickslabs/remorph/issues/1041)). 
The open-source library has undergone improvements to the Intermediate Representation (IR) for job nodes, as indicated by the commit message "Improve IR for job nodes." This release introduces several significant changes, including: Refactoring of the `JobNode` class to extend the `TreeNode` class and the addition of a new abstract class `LeafJobNode` that overrides the `children` method to always return an empty `Seq`. Enhancements to the `ClusterSpec` case class, which now includes a `toSDK` method that properly initializes and sets the values of the fields in the SDK `ClusterSpec` object. Improvements to the `NewClusterSpec` class, which updates the types of several optional fields and introduces changes to the `toSDK` method for better conversion to the SDK format. Removal of the `Job` class, which previously represented a job in the IR of workflows. Changes to the `JobCluster` case class, which updates the `newCluster` attribute from `ClusterSpec` to `NewClusterSpec`. Update to the `JobEmailNotifications` class, which now extends `LeafJobNode` and includes new methods and overwrites existing ones from `LeafJobNode`. Improvements to the `JobNotificationSettings` class, which replaces the original `toSDK` method with a new implementation for more accurate SDK representation of job notification settings. Refactoring of the `JobParameterDefinition` class, which updates the `toSDK` method for more efficient conversion to the SDK format. These changes simplify the representation of job nodes, align the codebase more closely with the underlying SDK, and improve overall code maintainability and compatibility with other Remorph components. +* Query History From Folder ([#991](https://github.com/databrickslabs/remorph/issues/991)). The Estimator class in the Remorph project has been updated to enhance the query history interface by adding metadata from reading from a folder, improving its ability to handle queries from different users and increasing the accuracy of estimation reports. The Anonymizer class has also been updated to handle cases where the user field is missing, ensuring the anonymization process can proceed smoothly and preventing potential errors. A new FileQueryHistory class has been added to provide query history functionality by reading metadata from a specified folder. The SnowflakeQueryHistory class has been updated to directly implement the history() method and include new fields in the ExecutedQuery objects, such as 'id', 'source', 'timestamp', 'duration', 'user', and 'filename'. A new ExecutedQuery case class has been introduced, which now includes optional `user` and `filename` fields, and a new QueryHistoryProvider trait has been added with a method history() that returns a QueryHistory object containing a sequence of ExecutedQuery objects, enhancing the query history interface's flexibility and power. Test suites and test data for the Anonymizer and TableGraph classes have been updated to accommodate these changes, allowing for more comprehensive testing of query history functionality. A FileQueryHistorySpec test file has been added to test the FileQueryHistory class's ability to correctly extract queries from SQL files, ensuring the class works as expected. +* Rework serialization using circe+jackson ([#1163](https://github.com/databrickslabs/remorph/issues/1163)). 
In pull request [#1163](https://github.com/databrickslabs/remorph/issues/1163), the serialization mechanism in the project has been refactored to use the Circe and Jackson libraries, replacing the existing ujson library. This change includes the addition of the Circe, Circe-generic-extras, and Circe-jackson libraries, which are licensed under the Apache 2.0 license. The project now includes the copyright notices and license information for all open-source projects that have contributed code to it, ensuring compliance with open-source licenses. The `CoverageTest` class has been updated to incorporate error encoding using Circe and Jackson libraries, and the `EstimationReport` case classes no longer have implicit `ReadWriter` instances defined using macroRW. Instead, circe and Jackson encode and decode instances are likely defined elsewhere in the codebase. The `BaseQueryRunner` abstract class has been updated to handle both parsing and transpilation errors in a more uniform way, using a `failures` field instead of `transpilation_error` or `parsing_error`. Additionally, a new file, `encoders.scala`, has been introduced, which defines encoders for serializing objects to JSON using the Circe and Jackson libraries. These changes aim to improve serialization and deserialization performance and capabilities, simplify the codebase, and ensure consistent and readable output. +* Some window functions does not support window frame conditions ([#999](https://github.com/databrickslabs/remorph/issues/999)). The Snowflake expression builder has been updated to correct the default window frame specifications for certain window functions and modify the behavior of the ORDER BY clause in these functions. This change ensures that the expression builder generates the correct SQL syntax for unsupported functions like "LAG", "DENSE_RANK", "LEAD", "PERCENT_RANK", "RANK", and "ROW_NUMBER", improving the compatibility and reliability of the generated queries. Additionally, a unit test for the `SnowflakeExpressionBuilder` has been updated to account for changes in the way window functions are handled, enhancing the accuracy of the builder in generating valid SQL for window functions in Snowflake. +* Split workflow definitions into sensible packages ([#1039](https://github.com/databrickslabs/remorph/issues/1039)). The AutoScale class has been refactored and moved to a new package, `com.databricks.labs.remorph.intermediate.workflows.clusters`, extending `JobNode` from `com.databricks.labs.remorph.intermediate.workflows`. It now includes a case class for auto-scaling that takes optional integer arguments `maxWorkers` and `minWorkers`, and a single method `apply` that creates and configures a cluster using the Databricks SDK's `ComputeService`. The AwsAttributes and AzureAttributes classes have also been moved to the `com.databricks.labs.remorph.intermediate.workflows.clusters` package and now extend `JobNode`. These classes manage AWS and Azure-related attributes for compute resources in a workflow. The ClientsTypes case class has been moved to a new clusters sub-package within the workflows package and now extends `JobNode`, and the ClusterLogConf class has been moved to a new clusters package. The JobDeployment class has been refactored and moved to `com.databricks.labs.remorph.intermediate.workflows.jobs`, and the JobEmailNotifications, JobsHealthRule, and WorkspaceStorageInfo classes have been moved to new packages and now import `JobNode`. 
These changes improve the organization and maintainability of the codebase, making it easier to understand and navigate. +* TO_NUMBER/TO_DECIMAL/TO_NUMERIC without precision and scale ([#1053](https://github.com/databrickslabs/remorph/issues/1053)). This pull request introduces improvements to the transpilation process for handling cases where precision and scale are not specified for TO_NUMBER, TO_DECIMAL, or TO_NUMERIC Snowflake functions. The updated transpiler now automatically applies default values when these parameters are omitted, with precision set to the maximum allowed value of 38 and scale set to 0. A new method has been added to manage these cases, and four new test cases have been included to verify the transpilation of TO_NUMBER and TO_DECIMAL functions without specified precision and scale, and with various input formats. This change ensures consistent behavior across different SQL dialects for cases where precision and scale are not explicitly defined in the conversion functions. +* Table comments captured as part of Snowflake Table Definition ([#989](https://github.com/databrickslabs/remorph/issues/989)). In this release, we have added support for capturing table comments as part of Snowflake Table Definitions in the remorph library. This includes modifying the TableDefinition case class to include an optional comment field, and updating the SQL query in the SnowflakeTableDefinitions class to retrieve table comments. A new integration test for Snowflake table definitions has also been introduced to ensure the proper functioning of the new feature. This test creates a connection to the Snowflake database, retrieves a list of table definitions, and checks for the presence of table comments. These changes are part of our ongoing efforts to improve metadata capture for Snowflake tables (Note: The commit message references issue [#945](https://github.com/databrickslabs/remorph/issues/945) on GitHub, which this pull request is intended to close.) +* Use Transformation to get rid of the ctx parameter in generators ([#1040](https://github.com/databrickslabs/remorph/issues/1040)). The `Generating` class has undergone significant changes, removing the `ctx` parameter and introducing two new phases, `Parsing` and `BuildingAst`, in the sealed trait `Phase`. The `Parsing` phase extends `Phase` with a previous phase of `Init` and contains the source code and filename. The `BuildingAst` phase extends `Phase` with a previous phase of `Parsing` and contains the parsed tree and the previous phase. The `Optimizing` phase now contains the unoptimized plan and the previous phase. The `Generating` phase now contains the optimized plan, the current node, the total statements, the transpiled statements, the `GeneratorContext`, and the previous phase. Additionally, the `TransformationConstructors` trait has been updated to allow for the creation of Transformation instances specific to a certain phase of a workflow. The `runQuery` method in the `BaseQueryRunner` abstract class has been updated to use a new `transpile` method provided by the `Transpiler` trait, and the `Estimator` class in the `Estimation` module has undergone changes to remove the `ctx` parameter in generators. Overall, these changes simplify the implementation, improve code maintainability, and enable better separation of concerns in the codebase. +* With Recursive ([#1000](https://github.com/databrickslabs/remorph/issues/1000)). 
In this release, we have introduced several enhancements for `With Recursive` statements in SQL parsing and processing for the Snowflake database. A new IR (Intermediate Representation) for With Recursive CTE (Common Table Expression) has been implemented in the SnowflakeAstBuilder.scala file. A new case class, WithRecursiveCTE, has been added to the SnowflakeRelationBuilder class in the databricks/labs/remorph project, which extends RelationCommon and includes two members: ctes and query. The buildColumns method in the SnowflakeRelationBuilder class has been updated to handle cases where columnList is null and extract column names differently. Additionally, a new test has been added in SnowflakeAstBuilderSpec.scala that verifies the correct handling of a recursive CTE query. These enhancements improve the support for recursive queries in the Snowflake database, enabling more powerful and flexible querying capabilities for developers and data analysts working with complex data structures. +* [chore] fixed query coverage report ([#1160](https://github.com/databrickslabs/remorph/issues/1160)). In this release, we have addressed the issue [#1160](https://github.com/databrickslabs/remorph/issues/1160) related to the query coverage report. We have implemented changes to the `QueryRunner` abstract class in the `com.databricks.labs.remorph.coverage` package. The `ReportEntryReport` constructor now accepts a new parameter `parsed`, which is set to 1 if there is no transpilation error and 0 otherwise. Previously, `parsed` was always set to 1, regardless of the presence of a transpilation error. We also updated the `extractQueriesFromFile` and `extractQueriesFromFolder` methods in the `FileQueryHistory` class to return a single `ExecutedQuery` instance, allowing for better query coverage reporting. Additionally, we modified the behavior of the `history()` method of the `fileQueryHistory` object in the `FileQueryHistorySpec` test case. The method now returns a query history object with a single query having a `source` including the text "SELECT * FROM table1;" and "SELECT * FROM table2;", effectively merging the previous two queries into one. These changes ensure that the report accurately reflects whether the query was successfully transpiled, parsed, and stored in the query history. It is crucial to test thoroughly any parts of the code that rely on the `history()` method to return separate queries, as the behavior of the method has changed. + + +## 0.7.0 + +* Added private key authentication for sf ([#917](https://github.com/databrickslabs/remorph/issues/917)). This commit adds support for private key authentication to the Snowflake data source connector, providing users with more flexibility and security. The `cryptography` library is used to process the user-provided private key, with priority given to the `pem_private_key` secret, followed by the `sfPassword` secret. If neither secret is found, an exception is raised. However, password-based authentication is still used when JDBC options are provided, as Spark JDBC does not currently support private key authentication. A new exception class, `InvalidSnowflakePemPrivateKey`, has been added for handling invalid or malformed private keys. Additionally, new tests have been included for reading data with private key authentication, handling malformed private keys, and checking for missing authentication keys. The notice has been updated to include the `cryptography` library's copyright and license information. 
+* Added support for `PARSE_JSON` and `VARIANT` datatype ([#906](https://github.com/databrickslabs/remorph/issues/906)). This commit introduces support for the `PARSE_JSON` function and `VARIANT` datatype in the Snowflake parser, addressing issue [#894](https://github.com/databrickslabs/remorph/issues/894). The implementation removes the experimental dialect, enabling support for the `VARIANT` datatype and using `PARSE_JSON` for it. The `variant_explode` function is also utilized. During transpilation to Snowflake, whenever the `:` operator is encountered in the SELECT statement, everything will be treated as a `VARIANT` on the Databricks side to handle differences between Snowflake and Databricks in accessing variant types. These changes are implemented using ANTLR. +* Added upgrade script and modified metrics sql ([#990](https://github.com/databrickslabs/remorph/issues/990)). In this release, the open-source library has been updated with several improvements to the metrics system, database upgrades, and setup process. The setup_spark_remote.sh script now checks if the Spark server is running by pinging localhost:4040, rather than sleeping for a fixed time, allowing for faster execution and more accurate server status determination. The metrics table's insert statement has been updated to cast values to Bigint for better handling of larger counts. An upgrade script has been added to facilitate required modifications, and the setup_spark_remote.sh script has been modified to validate URLs. A new Python file for upgrading the metrics table's data types has been added, which includes a function to recreate the table with the correct data types for specific columns. The upgrade_common module now includes several functions for upgrading database tables, and a new unit test file, test_upgrade_common.py, has been added with test cases for these functions. Lastly, the upgrade script for v0.4.0 has been updated to simplify the process of checking if the main table in the reconcile metadata needs to be recreated and to add an `operation_name` column. These changes improve the library's functionality, accuracy, and robustness, particularly for larger datasets and upgrading processes, enhancing the overall user experience. +* Basic CTAS Implementation ([#968](https://github.com/databrickslabs/remorph/issues/968)). This pull request adds basic support for the CREATE TABLE AS SELECT (CTAS) statement in Snowflake, enabling users to create a new table by selecting data from an existing table or query. In the LogicalPlanGenerator class, a new method has been implemented to handle CTAS statements, which generates the appropriate SQL command for creating a table based on the result of a select query. The SnowflakeDDLBuilder class now includes a relationBuilder class member for building relations based on Snowflake DDL input, and the visitCreateTableAsSelect method has been overridden to parse CTAS statements and construct corresponding IR objects. The test suite has been expanded to include a new spec for CTAS statements and a test case for the CTAS statement "CREATE TABLE t1 AS SELECT c1, c2 FROM t2;". Additionally, a new test file "test_ctas_complex.sql" has been added, containing SQL statements for creating a new table by selecting columns from multiple tables and computing new columns through various data manipulations. 
The implementation also includes adding new SQL statements for CTAS in both Snowflake and Databricks dialects, allowing for testing the CTAS functionality for both. +* Create repeatable estimator for Snowflake query history ([#924](https://github.com/databrickslabs/remorph/issues/924)). This commit introduces a new coverage estimation tool for analyzing query history in a database, initially implemented for Snowflake. The tool parses and transpiles query history into Databricks SQL and reports on the percentage of query history it can process. It includes a new `SnowflakePlanParser` class that handles Snowflake query plans, a `SqlGenerator` class that generates Databricks SQL from optimized logical plans, and a `dialect` method that returns the dialect string. The long-term plan is to extend this functionality to other supported databases and dialects and include a report on SQL complexity. Additionally, test cases have been added to the `AnonymizerTest` class to ensure the correct functionality of the `Anonymizer` class, which anonymizes executed queries when provided with a `PlanParser` object. The `Anonymizer` class is intended to be used as part of the coverage estimation tool, which will provide analysis of query history for various databases. +* Created a mapping dict for algo for each dialect at layer level ([#934](https://github.com/databrickslabs/remorph/issues/934)). A series of changes has been implemented to improve the reconciliation process and the handling of hash algorithms in the open-source library. A dictionary mapping hash algorithms to dialects has been introduced at the layer level to enhance the reconciliation process. The `get_hash_transform` function now accepts a new `layer` argument and returns a list of hash algorithms from the `HashAlgoMapping` dictionary. A new `HashAlgoMapping` class has been added to map algorithms to a dialect for hashing, replacing the previous `DialectHashConfig` class. A new function `get_dialect` has been introduced to retrieve the dialect based on the layer. The `_hash_transform` function and the `build_query` method have been updated to use the `layer` parameter when determining the dialect. These changes provide more precise control over the algorithm used for hash transformation based on the source layer and the target dialect, resulting in improved reconciliation performance and accuracy. +* Fetch TableDefinitions from Snowflake ([#904](https://github.com/databrickslabs/remorph/issues/904)). A new `SnowflakeTableDefinitions` class has been added to simplify the discovery of Snowflake table metadata. This class establishes a connection with a Snowflake database through a Connection object, and provides methods such as `getDataType` and `getTableDefinitionQuery` to parse data types and generate queries for table definitions. Moreover, it includes a `getTableDefinitions` method to retrieve all table definitions in a Snowflake database as a sequence of `TableDefinition` objects, which encapsulates various properties of each table. The class also features methods for retrieving all catalogs and schemas in a Snowflake database. Alongside `SnowflakeTableDefinitions`, a new test class, `SnowflakeTableDefinitionTest`, has been introduced to verify the behavior of `getTableDefinitions` and ensure that the class functions as intended, adhering to the desired behavior. +* Guide user on missing configuration file ([#930](https://github.com/databrickslabs/remorph/issues/930)).
In this commit, the `_verify_recon_table_config` method in the `runner.py` file of the `databricks/labs/remorph` package has been updated to handle missing reconcile table configurations during installation. When the reconcile table configuration is not found, an error message will now display the name of the requested configuration file. This enhancement helps users identify the specific configuration file they need to provide in their workspace, addressing issue [#919](https://github.com/databrickslabs/remorph/issues/919). This commit is co-authored by Ludovic Claude. +* Implement more missing visitor functions for Snowflake and TSQL ([#975](https://github.com/databrickslabs/remorph/issues/975)). In this release, we have added several missing visitor methods for the Snowflake and TSQL builder classes to improve the reliability and maintainability of our parser. Previously, when a visitor method was missing, the default visitor was called, causing the visit of all children of the ParseTree, which was not ideal. This could lead to incorrect results due to a slight modification in the ANTLR grammar inadvertently breaking the visitor. In this release, we have implemented several new visitor methods for both Snowflake and TSQL builder classes, including the `visitDdlCommand` method in the `SnowflakeDDLBuilder` class and the `visitDdlClause` method in the `TSqlDDLBuilder` class. These new methods ensure that the ParseTree is traversed correctly and that the correct IR node is returned. The `visitDdlCommand` method checks for different types of DDL commands, such as create, alter, drop, and undrop, and calls the appropriate method for each type. The `visitDdlClause` method contains a sequence of methods corresponding to different DDL clauses and calls the first non-null method in the sequence. These changes significantly improve the robustness of our parser and enhance the reliability of our code. +* Introduce typed errors ([#981](https://github.com/databrickslabs/remorph/issues/981)). This commit introduces typed errors in the form of a new class, `UnexpectedNode`, and several case classes including `ParsingError`, `UnsupportedDataType`, `WrongNumberOfArguments`, `UnsupportedArguments`, and `UnsupportedDateTimePart` in various packages, as part of the ongoing effort to replace exception throwing with returning `Result` types in future pull requests. These changes will improve error handling and provide more context and precision for errors, facilitating debugging and maintenance of the remorph library and data type generation functionality. The `TranspileException` class is now constructed with specific typed error instances, and the `ErrorCollector` and `ErrorDetail` classes have been updated to use `ParsingError`. Additionally, the `SnowflakeCallMapper` and `SnowflakeTimeUnits` classes have been updated to use the new typed error mechanism, providing more precise error handling for Snowflake-specific functions and expressions. +* Miscellaneous improvements to Snowflake parser ([#952](https://github.com/databrickslabs/remorph/issues/952)). This diff brings several miscellaneous improvements to the Snowflake parser in the open-source library, targeting increased parse and transpilation success rates. The modifications include updating the `colDecl` rule to allow optional data types, introducing an `objectField` rule, and enabling date and timestamp literals as strings. 
Additionally, the parser has been refined to handle identifiers more efficiently, such as hashes within the AnonymizerTest. The expected Ast for certain test cases has also been updated to improve parser accuracy. These changes aim to create a more robust and streamlined Snowflake parser, minimizing parsing errors and enhancing overall user experience for project adopters. Furthermore, the error handling and reporting capabilities of the Snowflake parser have been improved with new case classes, `IndividualError` and `ErrorsSummary`, and updated error messages. +* Moved intermediate package out of parsers ([#972](https://github.com/databrickslabs/remorph/issues/972)). In this release, the `intermediate` package has been refactored out of the `parsers` package, aligning with the design principle that parsers should depend on the intermediate representation instead of the other way around. This change affects various classes and methods across the project, all of which have been updated to import the `intermediate` package from its new location. No new functionality has been introduced, but the refactoring improves the package structure and dependency management. The `EstimationAnalyzer` class in the `coverage/estimation` package has been updated to import classes from the new location of the `intermediate` package, and its `evaluateTree` method has been updated to use the new import path for `LogicalPlan` and `Expression`. Other affected classes include `SnowflakeTableDefinitions`, `SnowflakeLexer`, `SnowflakeParser`, `SnowflakeTypeBuilder`, `GeneratorContext`, `DataTypeGenerator`, `IRHelpers`, and multiple test files. +* Patch Function without Brackets ([#907](https://github.com/databrickslabs/remorph/issues/907)). This commit introduces new lexer and parser rules to handle Snowflake SQL functions without parentheses, specifically impacting CURRENT_DATE, CURRENT_TIME, CURRENT_TIMESTAMP, LOCALTIME, and LOCALTIMESTAMP. The new rules allow these functions to be used without parentheses, consistent with Snowflake SQL. This change fixes functional tests and includes documentation for the affected functions. However, there is a pending task to add or fix more test cases to ensure comprehensive testing of the new rules. Additionally, the syntax of the SELECT statement for the CURRENT_TIMESTAMP function has been updated, removing the need for the parameter 'col1'. This change simplifies the syntax for certain SQL functions in the codebase and improves the consistency and reliability of the functional tests. +* Root Table ([#936](https://github.com/databrickslabs/remorph/issues/936)). This PR introduces a new class `TableGraph` that extends `DependencyGraph` and implements the `LazyLogging` trait. This class builds a graph of tables and their dependencies based on query history and table definitions. It provides methods to add nodes and edges, build the graph, and retrieve root, upstream, and downstream tables. The `DependencyGraph` trait offers a more structured and flexible way to handle table dependencies. This change is part of the Root Table feature (issue [#936](https://github.com/databrickslabs/remorph/issues/936)) that identifies root tables in a graph of table dependencies, closing issue [#23](https://github.com/databrickslabs/remorph/issues/23). The PR includes a new `TableGraphTest` class that demonstrates the use of these methods and verifies their behavior for better data flow understanding and optimization.
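  As a rough illustration of the dependency-graph idea behind the Root Table entry above, here is a minimal, self-contained Scala sketch; the class and method names are invented for the example and do not mirror the real `TableGraph`/`DependencyGraph` implementation.

  ```scala
  // Toy directed dependency graph: an edge upstream -> downstream means that the
  // downstream table is derived from the upstream table.
  final class DependencyGraphSketch {
    private var nodes = Set.empty[String]
    private var edges = Set.empty[(String, String)] // (upstream, downstream)

    def addNode(table: String): Unit = nodes += table

    def addEdge(upstream: String, downstream: String): Unit = {
      addNode(upstream)
      addNode(downstream)
      edges += (upstream -> downstream)
    }

    // Root tables are never the target of an edge: nothing feeds into them.
    def rootTables: Set[String] = nodes.filter(n => !edges.exists(_._2 == n))

    // Direct parents and children of a table.
    def upstream(table: String): Set[String] = edges.collect { case (u, d) if d == table => u }
    def downstream(table: String): Set[String] = edges.collect { case (u, d) if u == table => d }
  }

  object TableGraphSketch extends App {
    val g = new DependencyGraphSketch
    g.addEdge("raw_orders", "stg_orders")
    g.addEdge("stg_orders", "fct_orders")
    println(g.rootTables)               // Set(raw_orders)
    println(g.downstream("stg_orders")) // Set(fct_orders)
  }
  ```

  A root table here is simply a node with no incoming edges, which is the property the Root Table feature relies on to identify the origin tables of a lineage graph built from query history.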
+* Snowflake Merge Implementation ([#964](https://github.com/databrickslabs/remorph/issues/964)). In this release, we have implemented the Merge statement for the Snowflake parser, which enables updating or deleting rows in a target table based on matches with a source table, and inserting new rows into the target table when there are no matches. This feature includes updates to the SnowflakeDMLBuilder and SnowflakeExpressionBuilder classes, allowing for proper handling of column names and MERGE queries. Additionally, we have added test cases to the SnowflakeASTBuilder, SnowflakeDMLBuilderSpec, and SnowflakeToDatabricksTranspiler to ensure the accurate translation and execution of MERGE statements for the Snowflake dialect. These changes bring important database migration and synchronization capabilities to our open-source library, improving its functionality and usability for software engineers. +* TSQL: Implement CREATE TABLE ([#911](https://github.com/databrickslabs/remorph/issues/911)). This commit implements the TSQL CREATE TABLE command and its various options and forms, including CTAS, graph node syntax, and analytics variants, as well as syntactical differences for SQL Server. The DDL and DML visitors have been moved from the AST and Relation visitors to separate classes for better responsibility segregation. The LogicalPlanGenerator class has been updated to generate unique constraints, primary keys, foreign keys, check constraints, default value constraints, and identity constraints for the CREATE TABLE command. Additionally, new classes for generating SQL options and handling unresolved options during transpilation have been added to enhance the parser's capability to manage various options and forms. These changes improve the transpilation of TSQL code and the organization of the codebase, making it easier to maintain and extend. +* Transpile Snow ARRAY_SORT function ([#973](https://github.com/databrickslabs/remorph/issues/973)). In this release, we have implemented support for the Snowflake ARRAY_SORT function in our open-source library. This feature has been added as part of issue [#973](https://github.com/databrickslabs/remorph/issues/973), and it involves the addition of two new private methods, `arraySort` and `makeArraySort`, to the `SnowflakeCallMapper` class. The `arraySort` method takes a sequence of expressions as input and sorts the array using the `makeArraySort` method. The `makeArraySort` method handles both null and non-null values, sorts the array in ascending or descending order based on the provided parameter, and determines the position of null or small values based on the nulls first parameter. The sorted array is then returned as an `ir.ArraySort` expression. This functionality allows for the sorting of arrays in Snowflake SQL to be translated to equivalent code in the target language. This enhancement simplifies the process of working with arrays in Snowflake SQL and provides users with a more streamlined experience. +* Transpile Snow MONTHS_BETWEEN function correctly ([#963](https://github.com/databrickslabs/remorph/issues/963)). In this release, the remorph library's SnowflakeCallMapper class in the com/databricks/labs/remorph/parsers/snowflake/rules package has been updated to handle the MONTHS_BETWEEN function. A new case has been added that creates a MonthsBetween object with the first two arguments of the function call and a boolean value of true. 
This change enhances compatibility and ensures that the output accurately reflects the intended functionality. Additionally, new test cases have been introduced to the SnowflakeCallMapperSpec for the transpilation of the MONTHS_BETWEEN function. These test cases demonstrate accurate mapping of the function to the MonthsBetween class and proper casting of inputs as dates or timestamps, improving the reliability and precision of date and time calculations. +* Updated Installation to handle install errors ([#962](https://github.com/databrickslabs/remorph/issues/962)). In this release, we've made significant improvements to the `remorph` project, addressing and resolving installation errors that were occurring during the installation process in development mode. We've introduced a new `ProductInfo` class in the `wheels` module, which provides information about the products being installed. This change replaces the use of `WheelsV2` in two test functions. Additionally, we've updated the `workspace_installation` method in `application.py` to handle installation errors more effectively, addressing the dependency on workspace `.remorph` due to wheels. We've also added new methods to `installation.py` to manage local and remote version files, and updated the `_upgrade_reconcile_workflow` function to ensure the correct wheel path is used during installation. These changes improve the overall quality of the codebase, making it easier for developers to adopt and maintain the project, and ensure a more seamless installation experience for users. +* Updated catalog operations logging ([#910](https://github.com/databrickslabs/remorph/issues/910)). In this release, the setup process for the catalog, schema, and volume in the configurator module has been simplified and improved. The previous implementation repeatedly prompted the user for input until the correct input was provided or a maximum number of attempts was reached. The updated code now checks if the catalog, schema, or volume already exists and either uses it or prompts the user to create it once. If the user does not have the necessary privileges to use the catalog, schema, or volume, an error message is logged and the installation is aborted. New methods have been added to check for necessary privileges, such as `has_necessary_catalog_access`, `has_necessary_schema_access`, and `has_necessary_volume_access`, which return a boolean indicating whether the user has the necessary privileges and log an error message with the missing privileges if not. The logging for catalog operations in the install.py file has also been updated to check for privileges at the end of the process and list any missing privileges for each catalog object. Additionally, changes have been made to the unit tests for the ResourceConfigurator class to ensure that the system handles cases where the user does not have the necessary permissions to access catalogs, schemas, or volumes, preventing unauthorized access and maintaining the security and integrity of the system. +* Updated remorph reconcile workflow to use wheels instead of pypi ([#884](https://github.com/databrickslabs/remorph/issues/884)). In this release, the installation process for the Remorph library has been updated to allow for the use of locally uploaded wheel files instead of downloading the package from PyPI. 
This change includes updates to the `install` and `_deploy_jobs` methods in the `recon.py` file to accept a new `wheel_paths` argument, which is used to pass the path of the Remorph wheel file to the `deploy_recon_job` method. The `_upgrade_reconcile_workflow` function in the `v0.4.0_add_main_table_operation_name_column.py` file has also been updated to upload the wheel package to the workspace and pass its path to the `deploy_reconcile_job` method. Additionally, the `deploy_recon_job` method in the `JobDeployment` class now accepts a new `wheel_file` argument, which represents the name of the wheel file for the remorph library. These changes address issues faced by customers with no public internet access and enable the use of new features before they are released on PyPI. The `test_recon.py` file in the `tests/unit/deployment` directory has also been updated to reflect these changes. +* Upgrade script Implementation ([#777](https://github.com/databrickslabs/remorph/issues/777)). In this release, we've implemented an upgrade script as part of pull request [#777](https://github.com/databrickslabs/remorph/issues/777), which resolves issue [#769](https://github.com/databrickslabs/remorph/issues/769). This change introduces a new `Upgrades` class in `application.py` that accepts `product_info` and `installation` as parameters and includes a cached property `wheels` for improved performance. Additionally, we've added new methods to the `WorkspaceInstaller` class for handling upgrade-related tasks, including the creation of a `ProductInfo` object, interacting with the Databricks SDK, and handling potential errors. We've also added a test case to ensure that upgrades are applied correctly on more recent versions. These changes are part of our ongoing effort to enhance the management and application of upgrades to installed products. +* bug fix for to_array function ([#961](https://github.com/databrickslabs/remorph/issues/961)). A bug fix has been implemented to improve the `TO_ARRAY` function in our open-source library. Previously, this function expected only one parameter, but it has been updated to accept two parameters, with the second being optional. This change brings the function in line with other functions in the class, improving flexibility and ensuring backward compatibility. The `TO_ARRAY` function is used to convert a given expression to an array if it is not null and return null otherwise. The commit also includes updates to the `Generator` class, where a new entry for the `ToArray` expression has been added to the `expression_map` dictionary. Additionally, a new `ToArray` class has been introduced as a subclass of `Func`, allowing the function to handle a variable number of arguments more gracefully. Relevant updates have been made to the functional tests for the `to_array` function for both Snowflake and Databricks SQL, demonstrating its handling of null inputs and comparing it with the corresponding ARRAY function in each SQL dialect. Overall, these changes enhance the functionality and adaptability of the `TO_ARRAY` function. +* feat: Implement all of TSQL predicates except for SOME ALL ANY ([#922](https://github.com/databrickslabs/remorph/issues/922)). In this commit, we have implemented the IR generation for several TSQL predicates including IN, IS, BETWEEN, LIKE, EXISTS, and FREETEXT, thereby improving the parser's ability to handle a wider range of TSQL syntax. 
The `TSqlParser` class has been updated with new methods and changes to existing ones, including the addition of new labeled expressions to the `predicate` rule. Additionally, we have corrected an error in the LIKE predicate's implementation, allowing the ESCAPE character to accept a full expression that evaluates to a single character at runtime, rather than assuming it to be a single character at parse time. These changes provide more flexibility and adherence to the TSQL standard, enhancing the overall functionality of the project for our adopters. + + +## 0.6.0 + +* Added query history retrieval from Snowflake ([#874](https://github.com/databrickslabs/remorph/issues/874)). This release introduces query history retrieval from Snowflake, enabling expanded compatibility and data source options for the system. The update includes adding the Snowflake JDBC driver and its dependencies to the `pom.xml` file, and the implementation of a new `SnowflakeQueryHistory` class to retrieve query history from Snowflake. The `Anonymizer` object is also added to anonymize query histories by fingerprinting queries based on their structure. Additionally, several case classes are added to represent various types of data related to query execution and table definitions in a Snowflake database. A new `EnvGetter` class is also included to retrieve environment variables for use in testing. Test files for the `Anonymizer` and `SnowflakeQueryHistory` classes are added to ensure proper functionality. +* Added support for `ALTER TABLE`: `ADD COLUMNS`, `DROP COLUMNS`, `RENAME COLUMNS`, and `DROP CONSTRAINTS` ([#861](https://github.com/databrickslabs/remorph/issues/861)). In this release, support for various `ALTER TABLE` SQL commands has been added to our open-source library, including `ADD COLUMNS`, `DROP COLUMNS`, `RENAME COLUMNS`, and `DROP CONSTRAINTS`. These features have been implemented in the `LogicalPlanGenerator` class, which now includes a new private method `alterTable` that takes a context and an `AlterTableCommand` object and returns an `ALTER TABLE` SQL statement. Additionally, a new sealed trait `TableAlteration` has been introduced, with four case classes extending it to handle specific table alteration operations. The `SnowflakeTypeBuilder` class has also been updated to parse and build Snowflake-specific SQL types for these commands. These changes provide improved functionality for managing and manipulating tables in Snowflake, making it easier for users to work with and modify their data. The new functionality has been tested using the `SnowflakeToDatabricksTranspilerTest` class, which specifies Snowflake `ALTER TABLE` commands and the expected transpiled results. +* Added support for `STRUCT` types and conversions ([#852](https://github.com/databrickslabs/remorph/issues/852)). This change adds support for `STRUCT` types and conversions in the system by implementing new `StructType`, `StructField`, and `StructExpr` classes for parsing, data type inference, and code generation. It also maps the `OBJECT_CONSTRUCT` from Snowflake and introduces updates to various case classes such as `JsonExpr`, `Struct`, and `Star`. These improvements enhance the system's capability to handle complex data structures, ensuring better compatibility with external data sources and expanding the range of transformations available for users. 
Additionally, the changes include the addition of test cases to verify the functionality of generating SQL data types for `STRUCT` expressions and handling JSON literals more accurately. +* Minor upgrades to Snowflake parameter processing ([#871](https://github.com/databrickslabs/remorph/issues/871)). This commit includes minor upgrades to Snowflake parameter processing, enhancing the consistency and readability of the code. The changes normalize parameter generation to use `${}` syntax for clarity and to align with Databricks notebook examples. An extra coverage test for variable references within strings has been added. The specific changes include updating a SELECT statement in a Snowflake SQL query to use ${} for parameter processing. The commit also introduces a new SQL file for functional tests related to Snowflake's parameter processing, which includes commented out and alternate syntax versions of a query. This commit is part of continuous efforts to improve the functionality, reliability, and usability of the Snowflake parameter processing feature. +* Patch/reconcile support temp views ([#901](https://github.com/databrickslabs/remorph/issues/901)). The latest update to the remorph-reconcile library adds support for temporary views, a new feature that was not previously available. With this change, the system can now handle `global_temp` for temporary views by modifying the `_get_schema_query` method to return a query for the `global_temp` schema if the schema name is set as such. Additionally, the `read_data` method was updated to correctly handle the namespace and catalog for temporary views. The new variable `namespace_catalog` has been introduced, which is set to `hive_metastore` if the catalog is not set, and to the original catalog with the added schema otherwise. The `table_with_namespace` variable is then updated to use the `namespace_catalog` and table name, allowing for correct querying of temporary views. These modifications enable remorph-reconcile to work seamlessly with temporary views, enhancing its flexibility and functionality. The updated unit tests reflect these changes, with assertions to ensure that the correct SQL statements are being generated and executed for temporary views. +* Reconcile Table Recon JSON filename updates ([#866](https://github.com/databrickslabs/remorph/issues/866)). The Remorph project has implemented a change to the naming convention and placement of the configuration file for the table reconciliation process. The configuration file, previously named according to individual preference, must now follow the pattern `recon_config___.json` and be placed in the `.remorph` directory within the Databricks Workspace. Examples of Table Recon filenames for Snowflake, Oracle, and Databricks source systems have been provided for reference. Additionally, the `data_source` field in the config file has been updated to accurately reflect the data source. The case of the filename should now match the case of `SOURCE_CATALOG_OR_SCHEMA` as defined in the config. Compliance with this new naming convention and placement is required for the successful execution of the table reconciliation process. +* [snowflake] parse parameters ([#855](https://github.com/databrickslabs/remorph/issues/855)). The open-source library has undergone changes related to Scalafmt configuration, Snowflake SQL parsing, and the introduction of a new `ExpressionGenerator` class method. 
The Scalafmt configuration change introduces a new `docstrings.wrap` option set to `false`, disabling docstring wrapping at the specified column limit. The `danglingParentheses.preset` option is also set to `false`, disabling the formatting rule for unnecessary parentheses. In Snowflake SQL parsing, new token types, lexer modes, and parser rules have been added to improve the parsing of string literals and other elements. A new `variable` method in the `ExpressionGenerator` class generates SQL expressions for `ir.Variable` objects. A new `Variable` case class has been added to represent a variable in an expression, and the `SchemaReference` case class now takes a single child expression. The `SnowflakeDDLBuilder` class has a new method, `extractString`, to safely extract strings from ANTLR4 context objects. The `SnowflakeErrorStrategy` object now includes new parameters for parsing Snowflake syntax, and the Snowflake LexerSpec test class has new methods for filling tokens from an input string and dumping the token list. Tests have been added for various string literal scenarios, and the SnowflakeAstBuilderSpec includes a new test case for handling the `translate amps` functionality. The Snowflake SQL queries in the test file have been updated to standardize parameter referencing syntax, improving consistency and readability. +* fixed current_date() generation ([#890](https://github.com/databrickslabs/remorph/issues/890)). This release includes a fix for an issue with the generation of the `current_date()` function in SQL queries, specifically for the Snowflake dialect. A test case in the `sqlglot-incorrect` category has been updated to use the correct syntax for the `CURRENT_DATE` function, which includes parentheses (`SELECT CURRENT_DATE() FROM tabl;`). Additionally, the `current_date()` function is now called consistently throughout the tests, either as `CURRENT_DATE` or `CURRENT_DATE()`, depending on the syntax required by Snowflake. No new methods were added, and the existing functionality was changed only to correct the `current_date()` generation. This improvement ensures accurate and consistent generation of the `current_date()` function across different SQL dialects, enhancing the reliability and accuracy of the tests. + + +## 0.5.0 + +* Added Translation Support for `!` as `commands` and `&` for `Parameters` ([#771](https://github.com/databrickslabs/remorph/issues/771)). This commit adds translation support for using "!" as commands and "&" as parameters in Snowflake code within the remorph tool, enhancing compatibility with Snowflake syntax. The "!set exit_on_error=true" command, which previously caused an error, is now treated as a comment and prepended with `--` in the output. The "&" symbol, previously unrecognized, is converted to its Databricks equivalent "$", which represents parameters, allowing for proper handling of Snowflake SQL code containing "!" commands and "&" parameters. These changes improve the compatibility and robustness of remorph with Snowflake code and enable more efficient processing of Snowflake SQL statements. Additionally, the commit introduces a new test suite for Snowflake commands, enhancing code coverage and ensuring proper functionality of the transpiler. +* Added `LET` and `DECLARE` statements parsing in Snowflake PL/SQL procedures ([#548](https://github.com/databrickslabs/remorph/issues/548)). 
This commit introduces support for parsing `DECLARE` and `LET` statements in Snowflake PL/SQL procedures, enabling variable declaration and assignment. It adds new grammar rules, refactors code using ScalaSubquery, and implements IR visitors for `DECLARE` and `LET` statements with Variable Assignment and ResultSet Assignment. The `RETURN` statement and parameterized expressions are also now supported. Note that `CURSOR` is not yet covered. These changes allow for improved processing and handling of Snowflake PL/SQL code, enhancing the overall functionality of the library. +* Added logger statements in get_schema function ([#756](https://github.com/databrickslabs/remorph/issues/756)). In this release, enhanced logging has been implemented in the Metadata (Schema) fetch functions, specifically in the `get_schema` function and other metadata fetch functions within Oracle, SnowflakeDataSource modules. The changes include logger statements that log the schema query, start time, and end time, providing better visibility into the performance and behavior of these functions during debugging or monitoring. The logging functionality is implemented using the built-in `logging` module and timestamps are obtained using the `datetime` module. In the SnowflakeDataSource class, RuntimeError or PySparkException will be raised if the user's current role lacks the necessary privileges to access the specified Information Schema object. The INFORMATION_SCHEMA table in Snowflake is used to fetch the schema, with the query modified to handle unquoted and quoted identifiers and the ordinal position of columns. The `get_schema_query` function has also been updated for better formatting for the SQL query used to fetch schema information. The schema fetching method remains unchanged, but these enhancements provide more detailed logging for debugging and monitoring purposes. +* Aggregates Reconcile CLI Implementation ([#770](https://github.com/databrickslabs/remorph/issues/770)). The `Aggregates Reconcile CLI Implementation` commit introduces a new command-line interface (CLI) for reconcile jobs, specifically for aggregated data. This change adds a new parameter, "operation_name", to the run method in the runner.py file, which determines the type of reconcile operation to perform. A new function, _trigger_reconcile_aggregates, has been implemented to reconcile aggregate data based on provided configurations and log the reconciliation process outcome. Additionally, new methods for defining job parameters and settings, such as `max_concurrent_runs` and "parameters", have been included. This CLI implementation enhances the customizability and control of the reconciliation process for users, allowing them to focus on specific use cases and data aggregations. The changes also include new test cases in test_runner.py to ensure the proper behavior of the ReconcileRunner class when the `aggregates-reconcile` operation_name is set. +* Aggregates Reconcile Updates ([#784](https://github.com/databrickslabs/remorph/issues/784)). This commit introduces significant updates to the `Table Deployment` feature, enabling it to support `Aggregate Tables` deployment and modifying the persistence logic for tables. Notable changes include the addition of a new `aggregates` attribute to the `Table` class in the configuration, which allows users to specify aggregate functions and optionally group by specific columns. 
The reconcile process now captures mismatch data, missing rows in the source, and missing rows in the target in the recon metrics tables. Furthermore, the aggregates reconcile process supports various aggregate functions like min, max, count, sum, avg, median, mode, percentile, stddev, and variance. The documentation has been updated to reflect these improvements. The commit also removes the `percentile` function from the reconciliation configuration and modifies the `aggregate_metrics` SQL query, enhancing the flexibility of the `Table Deployment` feature for `Aggregate Tables`. Users should note that the `percentile` function is no longer a valid option and should update their code accordingly. +* Aggregates Reconcile documentation ([#779](https://github.com/databrickslabs/remorph/issues/779)). In this commit, the Aggregates Reconcile utility has been enhanced with new documentation and visualizations for improved understanding and usability. The utility now includes a flow diagram, visualization, and README file illustrating how it compares specific aggregate metrics between source and target data residing on Databricks. A new configuration sample is added, showcasing the reconciliation of two tables using various aggregate functions, join columns, transformations, filters, and JDBC ReaderOptions configurations. The commit also introduces two Mermaid flowchart diagrams, depicting the reconciliation process with and without a `group by` operation. Additionally, new flow diagram visualizations in PNG and GIF formats have been added, aiding in understanding the process flow of the Aggregates Reconcile feature. The reconcile configuration samples in the documentation have also been updated with a spelling correction for clarity. +* Bump sqlglot from 25.6.1 to 25.8.1 ([#749](https://github.com/databrickslabs/remorph/issues/749)). In this version update, the `sqlglot` dependency has been bumped from 25.6.1 to 25.8.1, bringing several bug fixes and new features related to various SQL dialects such as BigQuery, DuckDB, and T-SQL. Notable changes include support for BYTEINT in BigQuery, improved parsing and transpilation of StrToDate in ClickHouse, and support for SUMMARIZE in DuckDB. Additionally, there are bug fixes for DuckDB and T-SQL, including wrapping left IN clause json extract arrow operand and handling JSON_QUERY with a single argument. The update also includes refactors and changes to the ANNOTATORS and PARSER modules to improve dialect-aware annotation and consistency. This pull request is compatible with `sqlglot` version 25.6.1 and below and includes a detailed list of commits and their corresponding changes. +* Generate window functions ([#772](https://github.com/databrickslabs/remorph/issues/772)). In this release, we have added support for generating SQL `WINDOW` and `SortOrder` expressions in the `ExpressionGenerator` class. This enhancement includes the ability to generate a `WINDOW` expression with a window function, partitioning and ordering clauses, and an optional window frame, using the `window` and `frameBoundary` methods. The `sortOrder` method now generates the SQL `SortOrder` expression, which includes the expression to sort by, sort direction, and null ordering. Additional methods `orNull` and `doubleQuote` return a string representing a NULL value and a string enclosed in double quotes, respectively. These changes provide increased flexibility for handling more complex expressions in SQL. 
Additionally, new test cases have been added to the `ExpressionGeneratorTest` to ensure the correct generation of SQL window functions, specifically the `ROW_NUMBER()` function with various partitioning, ordering, and framing specifications. These updates improve the robustness and functionality of the `ExpressionGenerator` class for generating SQL window functions. +* Implement TSQL specific function call mapper ([#765](https://github.com/databrickslabs/remorph/issues/765)). This commit introduces several new features to enhance compatibility between TSQL and Databricks SQL. A new method, `interval`, has been added to generate a Databricks SQL compatible string for intervals in a TSQL expression. The `expression` method has been updated to handle certain functions directly, improving translation efficiency. Specifically, the DATEADD function is now translated to Databricks SQL's DATE_ADD, ADD_MONTHS, and xxx + INTERVAL n {days|months|etc} constructs. The changes also include a new sealed trait `KnownIntervalType`, a new case class `KnownInterval`, and a new class `TSqlCallMapper` for mapping TSQL functions to Databricks SQL equivalents. Furthermore, the commit introduces new tests for TSQL specific function call mappers, ensuring proper translation of TSQL functions to Databricks SQL compatible constructs. These improvements collectively facilitate better integration and compatibility between TSQL and Databricks SQL. +* Improve TSQL and Snowflake parser and lexer ([#757](https://github.com/databrickslabs/remorph/issues/757)). In this release, the open-source library's Snowflake and TSQL lexers and parsers have been improved for better functionality and robustness. For the Snowflake lexer, unnecessary escape sequence processing has been removed, and various options have been corrected to be simple strings. The lexer now accepts a question mark as a placeholder for prepared statements in Snowflake statements. The TSQL lexer has undergone minor improvements, such as aligning the catch-all rule name with Snowflake. The Snowflake parser now accepts the question mark as a `PARAM` placeholder and simplifies the `typeFileformat` rule to accept a single `STRING` token. Additionally, several new keywords have been added to the TSQL lexer, improving consistency and clarity. These changes aim to simplify lexer and parser rules, enhance option handling and placeholders, and ensure consistency between Snowflake and TSQL. +* Patch Information Schema Predicate Pushdown for Snowflake ([#764](https://github.com/databrickslabs/remorph/issues/764)). In this release, we have implemented Information Schema Predicate Pushdown for Snowflake, resolving issue [#7](https://github.com/databrickslabs/remorph/issues/7) +* TSQL: Implement correct grammar for CREATE TABLE in all forms ([#796](https://github.com/databrickslabs/remorph/issues/796)). In this release, the TSqlLexer's CREATE TABLE statement grammar has been updated and expanded to support new keywords and improve accuracy. The newly added keywords 'EDGE', 'FILETABLE', 'NODE', and `NODES` enable correct parsing of CREATE TABLE statements using graph nodes and FILETABLE functionality. Existing keywords such as 'DROP_EXISTING', 'DYNAMIC', 'FILENAME', and `FILTER` have been refined for better precision. Furthermore, the introduction of the `tableIndices` rule standardizes the order of columns in the table. 
These enhancements improve the T-SQL parser's robustness and consistency, benefiting users in creating and managing tables in their databases. +* TSQL: Implement grammar for CREATE DATABASE and CREATE DATABASE SCOPED OPTION ([#788](https://github.com/databrickslabs/remorph/issues/788)). In this release, we have implemented the TSQL grammar for `CREATE DATABASE` and `CREATE DATABASE SCOPED OPTION` statements, addressing inconsistencies with TSQL documentation. The implementation was initially intended to cover the entire process from grammar to code generation. However, to simplify other DDL statements, the work was split into separate grammar-only pull requests. The diff introduces new methods such as `createDatabaseScopedCredential`, `createDatabaseOption`, and `databaseFilestreamOption`, while modifying the existing `createDatabase` method. The `createDatabaseScopedCredential` method handles the creation of a database scoped credential, which was previously part of `createDatabaseOption`. The `createDatabaseOption` method now focuses on handling individual options, while `databaseFilestreamOption` deals with filesystem specifications. Note that certain options, like `DEFAULT_LANGUAGE`, `DEFAULT_FULLTEXT_LANGUAGE`, and more, have been marked as TODO and will be addressed in future updates. +* TSQL: Improve transpilation coverage ([#766](https://github.com/databrickslabs/remorph/issues/766)). In this update, various enhancements have been made to improve the coverage of TSQL transpilation and address bugs in code generation, particularly for the `ExpressionGenerator` class in the `com/databricks/labs/remorph/generators/sql` package, and the `TSqlExpressionBuilder`, `TSqlFunctionBuilder`, `TSqlCallMapper`, and `QueryRunner` classes. Changes include adding support for new cases, modifying code generation behavior, improving test coverage, and updating existing tests for better TSQL code generation. Specific additions include new methods for handling bitwise operations, converting CHECKSUM_AGG calls to a sequence of MD5 function calls, and handling Fn instances. The `QueryRunner` class has been updated to include both the actual and expected outputs in error messages for better debugging purposes. Additionally, the test file for the `DATEADD` function has been updated to ensure proper syntax and consistency. All these modifications aim to improve the reliability, accuracy, and compatibility of TSQL transpilation, ensuring better functionality and coverage for the Remorph library's transformation capabilities. +* [chore] speedup build process by not running unit tests twice ([#842](https://github.com/databrickslabs/remorph/issues/842)). In this commit, the build process for the open-source library has been optimized by removing the execution of unit tests during the build phase in the Maven build process. A new plugin for the Apache Maven Surefire Plugin has been added, with the group ID set to "org.apache.maven.plugins", artifact ID set to "maven-surefire-plugin", and version set to "3.1.2". The configuration for this plugin includes a `skipTests` attribute set to "true", ensuring that tests are not run twice, thereby improving the build process speed. The existing ScalaTest Maven plugin configuration remains unchanged, allowing Scala tests to still be executed during the test phase. 
Additionally, the Maven Compiler Plugin has been upgraded to version 3.11.0, and the release parameter has been set to 8, ensuring that the Java compiler used during the build process is compatible with Java 8. The version numbers for several libraries, including os-lib, mainargs, ujson, scalatest, and exec-maven-plugin, are now being defined using properties, allowing Maven to manage and cache these libraries more efficiently. These changes improve the build process's performance and reliability without affecting the existing functionality. +* [internal] better errors for call mapper ([#816](https://github.com/databrickslabs/remorph/issues/816)). In this release, the `ExpressionGenerator` class in the `com.databricks.labs.remorph.generators.sql` package has been updated to handle exceptions during the conversion of input functions to Databricks expressions. A try-catch block has been added to catch `IndexOutOfBoundsException` and provide a more descriptive error message, including the name of the problematic function and the error message associated with the exception. A `TranspileException` with the message `not implemented` is now thrown when encountering a function for which a translation to Databricks expressions is not available. The `IsTranspiledFromSnowflakeQueryRunner` class in the `com.databricks.labs.remorph.coverage` package has also been updated to include the name of the exception class in the error message for better error identification when a non-fatal error occurs during parsing. Additionally, the import statement for `Formatter` has been moved to ensure alphabetical order. These changes improve error handling and readability, thereby enhancing the overall user experience for developers interacting with the codebase. +* [snowflake] map more functions to Databricks SQL ([#826](https://github.com/databrickslabs/remorph/issues/826)). This commit introduces new private methods `andPredicate` and `orPredicate` to the ExpressionGenerator class in the `com.databricks.labs.remorph.generators.sql` package, enhancing the generation of SQL expressions for AND and OR logical operators, and improving readability and correctness of complex logical expressions. The LogicalPlanGenerator class in the `sql` package now supports more flexibility in inserting data into a target relation, enabling users to choose between overwriting the existing data or appending to it. The `FROM_JSON` function in the CallMapper class has been updated to accommodate an optional third argument, providing more flexibility in handling JSON-related transformations. A new class, `CastParseJsonToFromJson`, has been introduced to improve the performance of data processing pipelines that involve parsing JSON data in Snowflake using the `PARSE_JSON` function. Additional Snowflake SQL functions have been mapped to Databricks SQL IR, enhancing compatibility and functionality. The ExpressionGeneratorTest class now generates predicates without parentheses, simplifying and improving readability. Mappings for several Snowflake functions to Databricks SQL have been added, enhancing compatibility with Databricks SQL. The `sqlFiles` sequence in the `NestedFiles` class is now sorted before being mapped to `AcceptanceTest` objects, ensuring consistent order for testing or debugging purposes. A semicolon has been added to the end of a SQL query in a test file for Snowflake DML insert functionality, ensuring proper query termination. 
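+
+  As a rough illustration of the function mapping described in the entry above, the following self-contained Scala sketch shows a name-level rename table; `SnowflakeFunctionMapSketch` and everything in it are illustrative stand-ins rather than remorph's actual `SnowflakeCallMapper`, and only the `PARSE_JSON` → `FROM_JSON` pairing is taken from the entry itself.
+
+  ```scala
+  object SnowflakeFunctionMapSketch {
+    // Illustrative rename table only; the real mappings live in remorph's call mappers
+    // and operate on IR expressions, not plain strings.
+    private val renames: Map[String, String] = Map(
+      "PARSE_JSON" -> "FROM_JSON" // other mappings omitted
+    )
+
+    // Fall back to the original (upper-cased) name when no rename is known.
+    def toDatabricksName(snowflakeName: String): String =
+      renames.getOrElse(snowflakeName.toUpperCase, snowflakeName.toUpperCase)
+  }
+  ```
+
+  For example, `SnowflakeFunctionMapSketch.toDatabricksName("parse_json")` returns `FROM_JSON`, while unknown names pass through unchanged.
+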
+* [sql] generate `INSERT INTO ...` ([#823](https://github.com/databrickslabs/remorph/issues/823)). In this release, we have made significant updates to our open-source library. The ExpressionGenerator.scala file has been updated to convert boolean values to lowercase instead of uppercase when generating INSERT INTO statements, ensuring SQL code consistency. A new method `insert` has been added to the `LogicalPlanGenerator` class to generate INSERT INTO SQL statements based on the `InsertIntoTable` input. We have introduced a new case class `InsertIntoTable` that extends `Modification` to simplify the API for DML operations other than SELECT. The SQL ExpressionGenerator now generates boolean literals in lowercase, and new test cases have been added to ensure the correct generation of INSERT and JOIN statements. Lastly, we have added support for generating INSERT INTO statements in SQL for specified database tables, improving cross-platform compatibility. These changes aim to enhance the library's functionality and ease of use for software engineers. +* [sql] generate basic JSON access ([#835](https://github.com/databrickslabs/remorph/issues/835)). In this release, we have added several new features and improvements to our open-source library. The `ExpressionGenerator` class now includes a new method, `jsonAccess`, which generates SQL code to access a JSON object's properties, handling different types of elements in the path. The `TO_JSON` function in the `StructsToJson` class has been updated to accept an optional expression as an argument, enhancing its flexibility. The `SnowflakeCallMapper` class now includes a new method, `lift`, and a new feature to generate basic JSON access, with corresponding updates to test cases and methods. The SQL logical plan generator has been refined to generate star projections with escaped identifiers, handling complex table and database names. We have also added new methods and test cases to the `SnowflakeCallMapper` class to convert Snowflake structs into JSON strings and cast Snowflake values to specific data types. These changes improve the library's ability to handle complex JSON data structures, enhance functionality, and ensure the quality of generated SQL code. +* [sql] generate basic `CREATE TABLE` definition ([#829](https://github.com/databrickslabs/remorph/issues/829)). In this release, the open-source library's SQL generation capabilities have been enhanced with the addition of a new `createTable` method to the `LogicalPlanGenerator` class. This method generates a `CREATE TABLE` definition for a given `ir.CreateTableCommand`, producing a SQL statement with a comma-separated list of column definitions. Each column definition includes the column name, data type, and any applicable constraints, generated using the `DataTypeGenerator.generateDataType` method and the newly-introduced `constraint` method. Additionally, the `project` method has been updated to incorporate a `FROM` clause in the generated SQL statement when the input of the project node is not `ir.NoTable()`. These improvements extend the functionality of the `LogicalPlanGenerator` class, allowing it to generate `CREATE TABLE` statements for input catalog ASTs, thereby better supporting data transformation use cases. A new test for the `CreateTableCommand` has been added to the `LogicalPlanGeneratorTest` class to validate the correct transpilation of the `CreateTableCommand` to a `CREATE TABLE` SQL statement. 
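+
+  To make the column-definition assembly described above concrete, here is a minimal, self-contained sketch; `ColumnDef` and `createTableSql` are simplified stand-ins, not remorph's IR or the actual `LogicalPlanGenerator`/`DataTypeGenerator` API.
+
+  ```scala
+  object CreateTableSketch {
+    // Simplified stand-in for a column definition: name, data type, optional constraints.
+    final case class ColumnDef(name: String, dataType: String, constraints: Seq[String] = Nil)
+
+    // Render "CREATE TABLE t (col type constraints..., ...)" from the column list.
+    def createTableSql(table: String, columns: Seq[ColumnDef]): String = {
+      val cols = columns
+        .map(c => (Seq(c.name, c.dataType) ++ c.constraints).mkString(" "))
+        .mkString(", ")
+      s"CREATE TABLE $table ($cols)"
+    }
+  }
+  ```
+
+  For example, `createTableSql("t", Seq(ColumnDef("id", "INT", Seq("NOT NULL")), ColumnDef("name", "STRING")))` yields `CREATE TABLE t (id INT NOT NULL, name STRING)`.
+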
+* [sql] generate basic `TABLESAMPLE` ([#830](https://github.com/databrickslabs/remorph/issues/830)). In this commit, the open-source library's `LogicalPlanGenerator` class has been updated to include a new method, `tableSample`, which generates SQL representations of table sampling operations. Previously, the class only handled `INSERT`, `DELETE`, and `CREATE TABLE` commands. With this enhancement, the generator can now produce SQL statements using the `TABLESAMPLE` clause, allowing for the selection of a sample of data from a table based on various sampling methods and a seed value for repeatable sampling. The newly supported sampling methods include row-based probabilistic, row-based fixed amount, and block-based sampling. Additionally, a new test case has been added for the `LogicalPlanGenerator` related to the `TableSample` class, validating the correct transpilation of named tables and fixed row sampling into the `TABLESAMPLE` clause with specified parameters. This improvement ensures that the generated SQL code accurately represents the desired table sampling settings. + +Dependency updates: + + * Bump sqlglot from 25.6.1 to 25.8.1 ([#749](https://github.com/databrickslabs/remorph/pull/749)). + +## 0.4.1 + +* Aggregate Queries Reconciliation ([#740](https://github.com/databrickslabs/remorph/issues/740)). This release introduces several changes to enhance the functionality of the project, including the implementation of Aggregate Queries Reconciliation, addressing issue [#503](https://github.com/databrickslabs/remorph/issues/503). A new property, `aggregates`, has been added to the base class of the query builder module to support aggregate queries reconciliation. A `generate_final_reconcile_aggregate_output` function has been added to generate the final reconcile output for aggregate queries. A new SQL file creates a table called `aggregate_details` to store details about aggregate reconciles, and a new column, `operation_name`, has been added to the `main` table in the `installation` reconciliation query. Additionally, new classes and methods have been introduced for handling aggregate queries and their reconciliation, and new SQL tables and columns have been created for storing and managing rules for aggregating data in the context of query reconciliation. Unit tests have been added to ensure the proper functioning of aggregate queries reconciliation and reconcile aggregate data in the context of missing records. +* Generate GROUP BY / PIVOT ([#747](https://github.com/databrickslabs/remorph/issues/747)). The LogicalPlanGenerator class in the remorph library has been updated to support generating GROUP BY and PIVOT clauses for SQL queries. A new private method, "aggregate", has been added to handle two types of aggregates: GroupBy and Pivot. For GroupBy, it generates a GROUP BY clause with specified grouping expressions. For Pivot, it generates a PIVOT clause where the specified column is used as the pivot column and the specified values are used as the pivot values, compatible with Spark SQL. If the aggregate type is unsupported, a TranspileException is thrown. Additionally, new test cases have been introduced for the LogicalPlanGenerator class in the com.databricks.labs.remorph.generators.sql package to support testing the transpilation of Aggregate expressions with GROUP BY and PIVOT clauses, ensuring proper handling and transpilation of these expressions. 
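+
+  As a sketch of the two aggregate shapes handled by the new private `aggregate` method described above, the helpers below show the SQL text each shape targets; the object and method names here are hypothetical, and the PIVOT form follows Spark SQL syntax as noted in the entry.
+
+  ```scala
+  object AggregateSqlSketch {
+    // GROUP BY: append the grouping expressions after the source query.
+    def groupBy(selectList: String, table: String, groupings: Seq[String]): String =
+      s"SELECT $selectList FROM $table GROUP BY ${groupings.mkString(", ")}"
+
+    // PIVOT (Spark SQL style): aggregate expression, pivot column, and pivot values.
+    def pivot(table: String, aggExpr: String, pivotColumn: String, values: Seq[String]): String =
+      s"SELECT * FROM $table PIVOT ($aggExpr FOR $pivotColumn IN (${values.mkString(", ")}))"
+  }
+  ```
+
+  For example, `AggregateSqlSketch.pivot("sales", "SUM(amount)", "quarter", Seq("'Q1'", "'Q2'"))` produces `SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2'))`.
+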
+* Implement error strategy for Snowflake parsing and use error strategy for all parser instances ([#760](https://github.com/databrickslabs/remorph/issues/760)). In this release, we have developed an error strategy specifically for Snowflake parsing that translates raw token names and parser rules into more user-friendly SQL error messages. This strategy is applied consistently across all parser instances, ensuring a unified error handling experience. Additionally, we have refined the DBL_DOLLAR rule in the SnowflakeLexer grammar to handle escaped dollar signs correctly. These updates improve the accuracy and readability of error messages for SQL authors, regardless of the parsing tool or transpiler used. Furthermore, we have updated the TSQL parsing error strategy to match the new Snowflake error strategy implementation, providing a consistent error handling experience across dialects. +* Incremental improvement to error messages - article selection ([#711](https://github.com/databrickslabs/remorph/issues/711)). In this release, we have implemented an incremental improvement to the error messages generated during T-SQL code parsing. This change introduces a new private method, `articleFor`, which determines whether to use `a` or `an` in the generated messages based on the first letter of the following word. The `generateMessage` method has been updated to use this new method when constructing the initial error message and subsequent messages when there are multiple expected tokens. This improvement ensures consistent use of articles `a` or `an` in the error messages, enhancing their readability for software engineers working with T-SQL code. +* TSQL: Adds tests and support for SELECT OPTION(...) generation ([#755](https://github.com/databrickslabs/remorph/issues/755)). In this release, we have added support for generating code for the TSQL `SELECT ... OPTION(...)` clause in the codebase. This new feature includes the ability to transpile any query hints supplied with a SELECT statement as comments in the output code, allowing for easier assessment of query performance after transpilation. The OPTION clause is now generated as comments, including MAXRECURSION, string options, boolean options, and auto options. Additionally, we have added new tests and updated the TSqlAstBuilderSpec test class with new and updated test cases to cover the new functionality. The implementation is focused on generating code for the OPTION clause, and does not affect the actual execution of the query. The changes are limited to the ExpressionGenerator class and its associated methods, and the TSqlRelationBuilder class, without affecting other parts of the codebase. +* TSQL: IR implementation of MERGE ([#719](https://github.com/databrickslabs/remorph/issues/719)). The open-source library has been updated to include a complete implementation of the TSQL MERGE statement's IR (Intermediate Representation), bringing it in line with Spark SQL. The `LogicalPlanGenerator` class now includes a `generateMerge` method, which generates the SQL code for the MERGE statement, taking a `MergeIntoTable` object containing the target and source tables, merge condition, and merge actions as input. The `MergeIntoTable` class has been added as a case class to represent the logical plan of the MERGE INTO command and extends the `Modification` trait. 
The `LogicalPlanGenerator` class also includes a new `generateWithOptions` method, which generates SQL code for the WITH OPTIONS clause, taking a `WithOptions` object containing the input and options as children. Additionally, the `TSqlRelationBuilder` class has been updated to handle the MERGE statement's parsing, introducing new methods and updating existing ones, such as `visitMerge`. The `TSqlToDatabricksTranspiler` class has been updated to include support for the TSQL MERGE statement, and the `ExpressionGenerator` class has new tests for options, columns, and arithmetic expressions. A new optimization rule, `TrapInsertDefaultsAction`, has been added to handle the behavior of the DEFAULT keyword during INSERT statements. The commit also includes test cases for the `MergeIntoTable` logical operator and the T-SQL merge statement in the `TSqlAstBuilderSpec`. + + +## 0.4.0 + +* Added TSql transpiler ([#734](https://github.com/databrickslabs/remorph/issues/734)). In this release, we have developed a new open-source library feature that enhances the transpilation of T-SQL code to Databricks-compatible code. The new TSqlToDatabricksTranspiler class has been added, which extends the Transpiler abstract class and defines the transpile method. This method converts T-SQL code to Databricks-compatible code by creating a lexer, token stream, parser, and parsed tree from the input string using TSqlLexer, CommonTokenStream, and tSqlFile. The parsed tree is then passed to the TSqlAstBuilder's visit method to generate a logical plan, which is optimized using an optimizer object with rules such as PullLimitUpwards and TopPercentToLimitSubquery. The optimized plan is then passed to the LogicalPlanGenerator's generate method to generate the final transpiled code. Additionally, a new class, IsTranspiledFromTSqlQueryRunner, has been added to the QueryRunner object to transpile T-SQL queries into the format expected by the Databricks query runner using the new TSqlToDatabricksTranspiler. The AcceptanceTestRunner class in the com.databricks.labs.remorph.coverage package has been updated to replace TSqlAstBuilder with IsTranspiledFromTSqlQueryRunner in the TSqlAcceptanceSuite class, indicating a change in the code responsible for handling TSql queries. This new feature aims to provide a smooth and efficient way to convert T-SQL code to Databricks-compatible code for further processing and execution. +* Added the missing info for the reconciliation documentation ([#520](https://github.com/databrickslabs/remorph/issues/520)). In this release, we have made significant improvements to the reconciliation feature of our open-source library. We have added a configuration folder and provided a template for creating configuration files for specific table sources. The config files will contain necessary configurations for table-specific reconciliation. Additionally, we have included a note in the transformation section detailing the usage of user-defined functions (UDFs) in transformation expressions, with an example UDF called `sort_array_input()` provided. The reconcile configuration sample documentation has also been added, with a JSON configuration for reconciling tables using various operations like drop, join, transformation, threshold, filter, and JDBC ReaderOptions. 
The commit also includes examples of source and target tables, data overviews, and reconciliation configurations for various scenarios, such as basic config and column mapping, user transformations, explicit select, explicit drop, filters, and thresholds comparison, among others. These changes aim to make it easier for users to set up and execute the reconciliation process for specific table sources and provide clear and concise information about using UDFs in transformation expressions for reconciliation. This commit is co-authored by Vijay Pavan Nissankararao and SundarShankar89. +* Bump sigstore/gh-action-sigstore-python from 2.1.1 to 3.0.0 ([#555](https://github.com/databrickslabs/remorph/issues/555)). In this pull request, the sigstore/gh-action-sigstore-python dependency is being updated from version 2.1.1 to 3.0.0. This new version includes several changes and improvements, such as the addition of recursive globbing with ** to the inputs and the removal of certain settings like fulcio-url, rekor-url, ctfe, and rekor-root-pubkey. The signature, certificate, and bundle output settings have also been removed. Furthermore, the way inputs are parsed has been changed, and they are now made optional under certain conditions. The default suffix has been updated to .sigstore.json. The 3.0.0 version also resolves various deprecations present in sigstore-python's 2.x series and supports CI runners that use PEP 668 to constrain global package prefixes. +* Bump sqlglot from 25.1.0 to 25.5.1 ([#534](https://github.com/databrickslabs/remorph/issues/534)). In the latest release, the `sqlglot` package has been updated from version 25.1.0 to 25.5.1, which includes bug fixes, breaking changes, new features, and refactors for parsing, analyzing, and rewriting SQL queries. The new version introduces optimizations for coalesced USING columns, preserves EXTRACT(date_part FROM datetime) calls, decouples NVL() from COALESCE(), and supports FROM CHANGES in Snowflake. It also provides configurable transpilation of Snowflake VARIANT, and supports view schema binding options for Spark and Databricks. The update addresses several issues, such as the use of timestamp with time zone over timestamptz, switch off table alias columns generation, and parse rhs of x::varchar(max) into a type. Additionally, the update cleans up CurrentTimestamp generation logic for Teradata. The `sqlglot` dependency has also been updated in the 'experimental.py' file of the 'databricks/labs/remorph/snow' module, along with the addition of a new private method `_parse_json` to the `Generator` class for parsing JSON data. Software engineers should review the changes and update their code accordingly, as conflicts with existing code will be resolved automatically by Dependabot, as long as the pull request is not altered manually. +* Bump sqlglot from 25.5.1 to 25.6.1 ([#585](https://github.com/databrickslabs/remorph/issues/585)). In this release, the `sqlglot` dependency is updated from version 25.5.1 to 25.6.1 in the 'pyproject.toml' file. This update includes bug fixes, breaking changes, new features, and improvements. Breaking changes consist of updates to the QUALIFY clause in queries and the canonicalization of struct and array inline constructor. New features include support for ORDER BY ALL, FROM ROWS FROM (...), RPAD & LPAD functions, and exp.TimestampAdd. Bug fixes address issues related to the QUALIFY clause in queries, expansion of SELECT * REPLACE, RENAME, transpiling UDFs from Databricks, and more. 
The pull request also includes a detailed changelog, commit history, instructions for triggering Dependabot actions, and commands for reference, with the exception of the compatibility score for the new version, which is not taken into account in this pull request. +* Feature/reconcile table mismatch threshold ([#550](https://github.com/databrickslabs/remorph/issues/550)). This commit enhances the reconciliation process in the open-source library with several new features, addressing issue [#504](https://github.com/databrickslabs/remorph/issues/504). A new `get_record_count` method is added to the `Reconcile` class, providing record count data for source and target tables, facilitating comprehensive analysis of table mismatches. A `CountQueryBuilder` class is introduced to build record count queries for different layers and SQL dialects, ensuring consistency in data processing. The `Thresholds` class is refactored into `ColumnThresholds` and `TableThresholds`, allowing for more granular control over comparisons and customizable threshold settings. New methods `_is_mismatch_within_threshold_limits` and `_insert_into_metrics_table` are added to the `recon_capture.py` file, improving fine-grained control over the reconciliation process and preventing false positives. Additionally, new classes, methods, and data structures have been implemented in the `execute` module to handle reconciliation queries and data more efficiently. These improvements contribute to a more accurate and robust reconciliation system. +* Feature: introduce core transpiler ([#715](https://github.com/databrickslabs/remorph/issues/715)). A new core transpiler, `SnowflakeToDatabricksTranspiler`, has been introduced to convert Snowflake queries into Databricks SQL, streamlining integration and compatibility between the two systems. This transpiler is integrated into the coverage test suites for thorough testing, and is used to convert various types of logical plans, handling cases such as `Batch`, `WithCTE`, `Project`, `NamedTable`, `Filter`, and `Join`. The `SnowflakeToDatabricksTranspiler` class tokenizes input Snowflake query strings, initializes a `SnowflakeParser` instance, parses the input Snowflake query, generates a logical plan, and applies the `LogicalPlanGenerator` to the logical plan to generate the equivalent Databricks SQL query. Additionally, the `SnowflakeAstBuilder` class has been updated to alter the way `Batch` logical plans are built and improve overall functionality of the transpiler. +* Fixed LEFT and RIGHT JOIN syntax in Snowflake ANTLR grammar ([#526](https://github.com/databrickslabs/remorph/issues/526)). A fix has been implemented to address issues with the Snowflake ANTLR grammar related to the proper parsing of LEFT and RIGHT JOIN statements. Previously, the keywords LEFT and RIGHT were incorrectly allowed as identifiers, but they are hard keywords that must be escaped to be used as column names. This change updates the grammar to escape these keywords in JOIN statements, improving the overall parsing of queries that include LEFT and RIGHT JOINs. Additionally, semantic predicates have been suggested to handle cases where LEFT or RIGHT are used as column names without escaping, although this is not yet implemented. To ensure the correctness of the updated grammar, new tests have been added to the SnowflakeAstBuilderSpec for LEFT and RIGHT JOINs, which check that the Abstract Syntax Tree (AST) is built correctly for these queries. 
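+
+  A small, hedged sketch of the escaping rule described above: because `LEFT` and `RIGHT` are hard keywords in the JOIN grammar, they must be escaped (double-quoted in Snowflake) when used as column names. The helper below is purely illustrative and not part of remorph.
+
+  ```scala
+  object JoinKeywordEscapingSketch {
+    // Hard keywords that would otherwise be read as part of a JOIN clause.
+    private val hardKeywords = Set("LEFT", "RIGHT")
+
+    // Double-quote the identifier when it collides with a hard keyword.
+    def columnRef(name: String): String =
+      if (hardKeywords.contains(name.toUpperCase)) "\"" + name + "\"" else name
+  }
+  ```
+
+  For example, `columnRef("left")` yields the quoted identifier `"left"`, while `columnRef("amount")` passes through unchanged.
+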
+* Fixed Snowflake Acceptance Testcases Failures ([#531](https://github.com/databrickslabs/remorph/issues/531)). In this release, updates have been made to the acceptance testcases for various SQL functions in the open-source library. The DENSE RANK function's testcase has been updated with a window specification and ORDER BY clause in both Snowflake and Databricks SQL syntaxes, ensuring accurate test results. The LAG function's testcase now includes a PARTITION BY and ORDER BY clause, as well as the NULLS LAST keyword in Databricks SQL, for improved accuracy and consistency. The SQL queries in the Snowflake testcase for the `last_value` function have been updated with a window specification, ORDER BY clause, and NULLS LAST directive for Databricks SQL. Test case failures in the Snowflake acceptance testsuite have been addressed with updates to the LEAD function, MONTH_NAME to MONTHNAME renaming, and DATE_FORMAT to TO_DATE conversion, improving reliability and consistency. The ntile function's testcase has been updated with PARTITION BY and ORDER BY clauses, and the NULLS LAST keyword has been added to the Databricks SQL query. The SQL query for null-safe equality comparison has been updated with a conditional expression compatible with Snowflake. The ranking function's testcase has been improved with the appropriate partition and order by clauses, and the NULLS LAST keyword has been added to the Databricks SQL query, enhancing accuracy and consistency. Lastly, updates have been made to the ROW_NUMBER function's testcase, ensuring accurate and consistent row numbering for both Snowflake and Databricks. +* Fixed TSQL transpiler ([#735](https://github.com/databrickslabs/remorph/issues/735)). In this release, we have implemented a fix for the TSQL transpiler, addressing the issue [#7](https://github.com/databrickslabs/remorph/issues/7). This enhancement allows the library to accurately convert TSQL code into an equivalent format that is compatible with other databases. The fix resolves reported bugs related to incorrect syntax interpretation, thereby improving the overall reliability and functionality of the transpiler. Software engineers and developers relying on TSQL compatibility for cross-database operations will benefit from this improvement. We encourage users to test and provide feedback on this updated feature. +* Fixed `SELECT TOP X PERCENT` IR translation for TSQL ([#733](https://github.com/databrickslabs/remorph/issues/733)). In this release, we have made several enhancements to the open-source library to improve compatibility with T-SQL and Catalyst. We have added a new dependency, `pprint_${scala.binary.version}` version 0.8.1 from the `com.lihaoyi` group, to provide advanced pretty-printing functionality for Scala. We have also fixed the translation of the TSQL `SELECT TOP X PERCENT` feature in the parser for intermediate expressions, addressing the difference in syntax between TSQL and SQL for limiting the number of rows returned by a query. Additionally, we have modified the implementation of the `WITH` clause and added a new expression for the `SELECT TOP` clause in T-SQL, improving the compatibility of the codebase with T-SQL and aligning it with Catalyst. We also introduced a new abstract class `Rule` and a case class `Rules` in the `com.databricks.labs.remorph.parsers.intermediate` package to fix the `SELECT TOP X PERCENT` IR translation for TSQL by adding new rules. 
Furthermore, we have added a new Scala file, `subqueries.scala`, containing abstract class `SubqueryExpression` and two case classes that extend it, and made changes to the `trees.scala` file to improve the tree string representation for better readability and consistency in the codebase. These changes aim to improve the overall functionality of the library and make it easier for new users to understand and adopt the project. +* Fixed invalid null constraint and FQN ([#517](https://github.com/databrickslabs/remorph/issues/517)). In this change, we addressed issues [#516](https://github.com/databrickslabs/remorph/issues/516) and [#517](https://github.com/databrickslabs/remorph/issues/517), which involved resolving an invalid null constraint and correcting a fully qualified name (FQN) in our open-source library. The `read_data` function in the `databricks.py` file was updated to improve null constraint handling and ensure the FQN is valid. Previously, the catalog was always appended to the table name, potentially resulting in an invalid FQN and null constraint issues. Now, the code checks if the catalog exists before appending it to the table name, and if not provided, the schema and table name are concatenated directly. Additionally, we removed the NOT NULL constraint from the catalog field in the source_table and target_table structs in the main SQL file, allowing null values for this field. These changes maintain backward compatibility and enhance the overall functionality and robustness of the project, ensuring accurate query results and avoiding potential errors. +* Generate SQL for arithmetic operators ([#726](https://github.com/databrickslabs/remorph/issues/726)). In this release, we have introduced a new private method `arithmetic` to the `ExpressionGenerator` class that generates SQL for arithmetic operations, including unary minus, unary plus, multiplication, division, modulo, addition, and subtraction. This improves the readability and maintainability of the code by separating concerns and making the functionality more explicit. Additionally, we have introduced a new trait named `Arithmetic` to group arithmetic expressions together, which enables easier manipulation and identification of arithmetic expressions in the code. A new test suite has also been added for arithmetic operations in the `ExpressionGenerator` class, which improves test coverage and ensures the correct SQL is generated for these operations. These changes provide a welcome addition for developers looking to extend the functionality of the `ExpressionGenerator` class for arithmetic operations. +* Generate SQL for bitwise operators. The ExpressionGenerator class in the remorph project has been updated to support generating SQL for bitwise operators (OR, AND, XOR, NOT) through the addition of a new private method `bitwise` that converts bitwise operations to equivalent SQL expressions. The `expression` method has also been updated to utilize the new `bitwise` method for any input ir.Expression that is a bitwise operation. To facilitate this change, a new trait called `Bitwise` and updated case classes for bitwise operations, including `BitwiseNot`, `BitwiseAnd`, `BitwiseOr`, and `BitwiseXor`, have been implemented. The updated case classes extend the new `Bitwise` trait and include the `dataType` override method to return the data type of the left expression. 
A new test case in ExpressionGeneratorTest for bitwise operators has been added to validate the functionality, and the `expression` method in ExpressionGenerator now utilizes the GeneratorContext() instead of new GeneratorContext(). These changes enable the ExpressionGenerator to generate SQL code for bitwise operations, expanding its capabilities. +* Generate `.. LIKE ..` ([#723](https://github.com/databrickslabs/remorph/issues/723)). In this commit, the `ExpressionGenerator` class has been enhanced with a new method, `like`, which generates SQL `LIKE` expressions. The method takes a `GeneratorContext` object and an `ir.Like` object as arguments, and returns a string representation of the `LIKE` expression. It uses the `expression` method to generate the left and right sides of the `LIKE` operator, and also handles the optional escape character. Additionally, the `timestampLiteral` and `dateLiteral` methods have been updated to take an `ir.Literal` object and better handle `NULL` values. A new test case has also been added for the `ExpressionGenerator` class, which checks the `like` function and includes examples for basic usage and usage with an escape character. This commit improves the functionality of the `ExpressionGenerator` class, allowing it to handle `LIKE` expressions and better handle `NULL` values for timestamps and dates. +* Generate `DISTINCT`, `*` ([#739](https://github.com/databrickslabs/remorph/issues/739)). In this release, we've enhanced the ExpressionGenerator class to support generating `DISTINCT` and `*` (star) expressions in SQL queries. Previously, the class did not handle these cases, resulting in incomplete or incorrect SQL queries. With the introduction of the `distinct` method, the class can now generate the `DISTINCT` keyword followed by the expression to be applied to, and the `star` method produces the `*` symbol, optionally followed by the name of the object (table or subquery) to which it applies. These improvements make the ExpressionGenerator class more robust and compatible with various SQL dialects, resulting in more accurate query outcomes. We've also added new test cases for the `ExpressionGenerator` class to ensure that `DISTINCT` and `*` expressions are generated correctly. Additionally, support for generating SQL `DISTINCT` and `*` (wildcard) has been added to the transpilation of Logical Plans to SQL, specifically in the `ir.Project` class. This ensures that the correct SQL `SELECT * FROM table` syntax is generated when a wildcard is used in the expression list. These enhancements significantly improve the functionality and compatibility of our open-source library. +* Generate `LIMIT` ([#732](https://github.com/databrickslabs/remorph/issues/732)). A new method has been added to generate a SQL LIMIT clause for a logical plan in the data transformation tool. A new `case` branch has been implemented in the `generate` method of the `LogicalPlanGenerator` class to handle the `ir.Limit` case, which generates the SQL LIMIT clause. If a percentage limit is specified, the tool will throw an exception as it is not currently supported. The `generate` method of the `ExpressionGenerator` class has been replaced with a new `generate` method in the `ir.Project`, `ir.Filter`, and new `ir.Limit` cases to ensure consistent expression generation. 
A new case class `Limit` has been added to the `com.databricks.labs.remorph.parsers.intermediate.relations` package, which extends the `UnaryNode` class and has four parameters: `input`, `limit`, `is_percentage`, and `with_ties`. This new class enables limiting the number of rows returned by a query, with the ability to specify a percentage of rows or include ties in the result set. Additionally, a new test case has been added to the `LogicalPlanGeneratorTest` class to verify the transpilation of a `Limit` node to its SQL equivalent, ensuring that the `Limit` node is correctly handled during transpilation. +* Generate `OFFSET` SQL clauses ([#736](https://github.com/databrickslabs/remorph/issues/736)). The Remorph project's latest update introduces a new OFFSET clause generation feature for SQL queries in the LogicalPlanGenerator class. This change adds support for skipping a specified number of rows before returning results in a SQL query, enhancing the library's query generation capabilities. The implementation includes a new case in the match statement of the generate function to handle ir.Offset nodes, creating a string representation of the OFFSET clause using the provided offset expression. Additionally, the commit includes a new test case in the LogicalPlanGeneratorTest class to validate the OFFSET clause generation, ensuring that the LogicalPlanGenerator can translate Offset AST nodes into corresponding SQL statements. Overall, this update enables the generation of more comprehensive SQL queries with OFFSET support, providing software engineers with greater flexibility for pagination and other data processing tasks. +* Generate `ORDER BY` SQL clauses ([#737](https://github.com/databrickslabs/remorph/issues/737)). This commit introduces new classes and enumerations for sort direction and null ordering, as well as an updated SortOrder case class, enabling the generation of ORDER BY SQL clauses. The LogicalPlanGenerator and SnowflakeExpressionBuilder classes have been modified to utilize these changes, allowing for more flexible and customizable sorting and null ordering when generating SQL queries. Additionally, the TSqlRelationBuilderSpec test suite has been updated to reflect these changes, and new test cases have been added to ensure the correct transpilation of ORDER BY clauses. Overall, these improvements enhance the Remorph project's capability to parse and generate SQL expressions with various sorting scenarios, providing a more robust and maintainable codebase. +* Generate `UNION` / `EXCEPT` / `INTERSECT` ([#731](https://github.com/databrickslabs/remorph/issues/731)). In this release, we have introduced support for the `UNION`, `EXCEPT`, and `INTERSECT` set operations in our data processing system's generator. A new `unknown` method has been added to the `Generator` trait to return a `TranspileException` when encountering unsupported operations, allowing for better error handling and more informative error messages. The `LogicalPlanGenerator` class in the remorph project has been extended to support generating `UNION`, `EXCEPT`, and `INTERSECT` SQL operations with the addition of a new parameter, `explicitDistinct`, to enable explicit specification of `DISTINCT` for these set operations. A new test suite has been added to the `LogicalPlanGenerator` to test the generation of these operations using the `SetOperation` class, which now has four possible set operations: `UnionSetOp`, `IntersectSetOp`, `ExceptSetOp`, and `UnspecifiedSetOp`. 
With these changes, our system can handle a wider range of input and provide meaningful error messages for unsupported operations, making it more versatile in handling complex SQL queries. +* Generate `VALUES` SQL clauses ([#738](https://github.com/databrickslabs/remorph/issues/738)). The latest commit introduces a new feature to generate `VALUES` SQL clauses in the context of the logical plan generator. A new case branch has been implemented in the `generate` method to manage `ir.Values` expressions, converting input data (lists of lists of expressions) into a string representation compatible with `VALUES` clauses. The existing functionality remains unchanged. Additionally, a new test case has been added for the `LogicalPlanGenerator` class, which checks the correct transpilation to `VALUES` SQL clauses. This test case ensures that the `ir.Values` method, which takes a sequence of sequences of literals, generates the corresponding `VALUES` SQL clause, specifically checking the input `Seq(Seq(ir.Literal(1), ir.Literal(2)), Seq(ir.Literal(3), ir.Literal(4)))` against the SQL clause `"VALUES (1,2), (3,4)"`. This change enables testing the functionality of generating `VALUES` SQL clauses using the `LogicalPlanGenerator` class. +* Generate predicate expressions. This commit introduces the generation of predicate expressions as part of the SQL ExpressionGenerator in the `com.databricks.labs.remorph.generators.sql` package, enabling the creation of more complex SQL expressions. The changes include the addition of a new private method, `predicate(ctx: GeneratorContext, expr: Expression)`, to handle predicate expressions, and the introduction of two new predicate expression types, LessThan and LessThanOrEqual, for comparing the relative ordering of two expressions. Existing predicate expression types have been updated with consistent naming. Additionally, the commit incorporates improvements to the handling of comparison operators in the SnowflakeExpressionBuilder and TSqlExpressionBuilder classes, addressing bugs and ensuring precise predicate expression generation. The `ParserTestCommon` trait has also been updated to reorder certain operators in a logical plan, maintaining consistent comparison results in tests. New test cases have been added to several test suites to ensure the correct interpretation and generation of predicate expressions involving different data types and search conditions. Overall, these enhancements provide more fine-grained comparison of expressions, enable more nuanced condition checking, and improve the robustness and accuracy of the SQL expression generation process. +* Merge remote-tracking branch 'origin/main'. In this update, the `ExpressionGenerator` class in the `com.databricks.labs.remorph.generators.sql` package has been enhanced with two new private methods: `dateLiteral` and `timestampLiteral`. These methods are designed to generate the SQL literal representation of `DateType` and `TimestampType` expressions, respectively. The introduction of these methods addresses the previous limitations of formatting date and timestamp values directly, which lacked extensibility and required duplicated code for handling null values. By extracting the formatting logic into separate methods, this commit significantly improves code maintainability and reusability, enhancing the overall readability and understandability of the `ExpressionGenerator` class for developers. 
The `dateLiteral` method handles `DateType` values by formatting them using the `dateFormat` `SimpleDateFormat` instance, returning `NULL` if the value is missing. Likewise, the `timestampLiteral` method formats `TimestampType` values using the `timeFormat` `SimpleDateFormat` instance, returning `NULL` if the value is missing. These methods will enable developers to grasp the code's functionality more easily and make future enhancements to the class. +* Modified dataclass for table threshold and added documentation ([#714](https://github.com/databrickslabs/remorph/issues/714)). A series of modifications have been implemented to enhance the threshold configuration and validation for table reconciliation in the open-source library. The `TableThresholds` dataclass has been updated to accept a string for the `model` attribute, replacing the previously used `TableThresholdModel` Enum. Additionally, a new `validate_threshold_model` method has been added to `TableThresholds` to ensure proper validation of the `model` attribute. A new exception class, `InvalidModelForTableThreshold`, has been introduced to handle invalid settings. Column-specific thresholds can now be set using the `ColumnThresholds` configuration option. The `recon_capture.py` and `recon_config.py` files have been updated accordingly, and the documentation has been revised to clarify these changes. These improvements offer greater flexibility and control for users configuring thresholds while also refining validation and error handling. +* Support CTAS in TSQL Grammar and add more comparison operators ([#545](https://github.com/databrickslabs/remorph/issues/545)). In this release, we have added support for the CTAS (CREATE TABLE AS) statement in the TSQL (T-SQL) grammar, as well as introduced new comparison operators: !=, !<, and !>. The CTAS statement allows for the creation of a new table by selecting from an existing table or query, potentially improving code readability and performance. The new comparison operators provide alternative ways of expressing inequalities, increasing flexibility for developers. The keyword `REPLICATE` has also been added for creating a full copy of a database or availability group. These changes enhance the overall functionality of the TSQL grammar and improve the user's ability to express various operations in TSQL. The CTAS statement is implemented as a new rule, and the new comparison operators are added as methods in the TSqlExpressionBuilder class. These changes provide increased capability and flexibility for TSQL parsing and query handling. The new methods are not adding any new external dependencies, and the project remains self-contained. The additions have been tested with the TSqlExpressionBuilderSpec test suite, ensuring the functionality and compatibility of the TSQL parser. +* Support translation of TSQL IGNORE NULLS clause in windowing functions ([#511](https://github.com/databrickslabs/remorph/issues/511)). The latest change introduces support for translating the TSQL IGNORE NULLS and RESPECT NULLS clauses in windowing functions to their equivalents in Databricks SQL. In TSQL, these clauses appear after the function name and before the OVER clause, affecting how the functions handle null values. Databricks SQL represents this functionality with an optional trailing boolean parameter for specific windowing functions. 
With this update, when the IGNORE NULLS clause is specified in TSQL, a boolean option is appended to the corresponding Databricks SQL windowing functions, with RESPECT NULLS as the default. This enhancement is facilitated by a new private method, `buildNullIgnore`, which adds the boolean parameter to the original expression when IGNORE NULLS is specified in the OVER clause. The alteration is exemplified in new test examples for the TSqlFunctionSpec, which include testing the LEAD function with and without the IGNORE NULLS clause, and updates to the translation of functions with non-standard syntax. +* TSQL: Implement TSQL UPDATE/DELETE statements ([#540](https://github.com/databrickslabs/remorph/issues/540)). In this release, we have added support for TSQL UPDATE and DELETE statements in all syntactical forms, including UDF column transformations, in the TSqlParser.g4 file. The implementation of both statements is done in a single file and they share many common clauses. We have also introduced two new case classes, UpdateTable and MergeTables, in the extensions.scala file to implement the TSQL UPDATE and DELETE statements, respectively. Additionally, we have added new methods to handle various clauses and elements associated with these statements in the TSqlErrorStrategy class. A series of tests have been included to ensure the correct translation of various UPDATE and DELETE queries to their respective Abstract Syntax Trees (ASTs). These changes bring TSQL UPDATE and DELETE statement functionality to the project and allow for their use in a variety of contexts, providing developers with more flexibility and control when working with TSQL UPDATE and DELETE statements in the parser. +* TSQL: Implement translation of INSERT statement ([#515](https://github.com/databrickslabs/remorph/issues/515)). In this release, we have implemented the TSQL INSERT statement in its entirety, including all target options, optional clauses, and Common Table Expressions (CTEs) in our open-source library. The change includes updates to the TSqlParser.g4 file to support the INSERT statement's various clauses, such as TOP, INTO, WITH TABLE HINTS, outputClause, and optionClause. We have also added new case classes to the TSQL AST to handle various aspects of the INSERT statement, including LocalVarTable, Output, InsertIntoTable, DerivedRows, DefaultValues, and Default. The TSqlExpressionBuilder and TSqlRelationBuilder classes have been updated to support the new INSERT statement, including handling output column lists, aliases, and JSON clauses. We have added specification tests to TSqlAstBuilderSpec.scala to demonstrate the various ways that the INSERT statement can be written and to ensure the correct translation of TSQL INSERT statements into the remorph project. +* TSQL: Remove the SIGN fragment from numerical tokens ([#547](https://github.com/databrickslabs/remorph/issues/547)). In this release, we have made changes to the TSQL expression builder in the remorph project that affect how negative and positive numerical literals are parsed. Negative literals, such as -2, will now be parsed as UMinus(Literal(2)) instead of Literal(-2), and positive literals, such as +1, will be parsed as UPlus(Literal(1)) instead of Literal(1). This change was made to address issue [#546](https://github.com/databrickslabs/remorph/issues/546), but it is not an ideal solution as it may cause inconvenience in downstream processes. 
The affected numerical tokens include INT, HEX, FLOAT, REAL, and MONEY, which have been simplified by removing the SIGN fragment. We have updated the buildPrimitive method to handle INT, REAL, and FLOAT token types and ensure that numerical tokens are parsed correctly. We intend to keep issue [#546](https://github.com/databrickslabs/remorph/issues/546) open for further exploration of a better solution. The tests have been updated to reflect these changes. +* TSQL: Simplifies named table tableSource, implements columnAlias list ([#512](https://github.com/databrickslabs/remorph/issues/512)). This change introduces significant updates to the TSqlParser's grammar for tableSource, simplifying and consolidating rules related to table aliases and column alias lists. A new Relation called TableWithHints has been added to collect and process table hints, some of which have direct counterparts in the Catalyst optimizer or can be used as comments for migration purposes. The TSQLExpressionBuilder and TSqlRelationBuilder classes have been modified to handle table hints and column aliases, and the TSqlAstBuilderSpec test suite has been updated to include new tests for table hints in T-SQL SELECT statements. These changes aim to improve parsing, handling, and optimization of table sources, table hints, and column aliases in TSQL queries. +* TSQL: Support generic FOR options ([#525](https://github.com/databrickslabs/remorph/issues/525)). In this release, we have added support for parsing T-SQL (Transact-SQL) options that contain the keyword `FOR` using the standard syntax `[FOR]`. This change is necessary as `FOR` cannot be used directly as an identifier without escaping, as it would otherwise be seen as a table alias in a `SELECT` statement. The ANTLR rule for parsing generic options has been expanded to handle these special cases correctly, allowing for the proper parsing of options such as `OPTIMIZE FOR UNKNOWN` in a `SELECT` statement. Additionally, a new case has been added to the OptionBuilder class to handle the `FOR` keyword and elide it, converting `OPTIMIZE FOR UNKNOWN` to `OPTIMIZE` with an id of 'UNKNOWN'. This ensures the proper handling of options containing `FOR` and avoids any conflicts with the `FOR` clause in T-SQL statements. This change was implemented by Valentin Kasas and involves adding a new alternative to the `genericOption` rule in the TSqlParser.g4 file, but no new methods have been added. +* Updated Dialect Variable Name ([#535](https://github.com/databrickslabs/remorph/issues/535)). In this release, the `source` variable name in the `QueryBuilder` class, which refers to the `Dialect` instance, has been updated to `engine` to accurately reflect its meaning as referring to either `Source` or `Target`. This change includes updating the usage of `source` to `engine` in the `build_query`, `_get_with_clause`, and `build_threshold_query` methods, as well as removing unnecessary parentheses in a list. These changes improve the code's clarity, accuracy, and readability, while maintaining the overall functionality of the affected methods. +* Use Oracle library only if the recon source is Oracle ([#532](https://github.com/databrickslabs/remorph/issues/532)). In this release, we have added a new `ReconcileConfig` configuration object and a `SourceType` enumeration in the `databricks.labs.remorph.config` and `databricks.labs.remorph.reconcile.constants` modules, respectively. 
These objects are introduced to determine whether to include the Oracle JDBC driver library in a reconciliation job's task libraries. The `deploy_job` method and the `_job_recon_task` method have been updated to use the `_recon_config` attribute to decide whether to include the Oracle JDBC driver library. Additionally, the `_deploy_reconcile_job` method in the `install.py` file has been modified to include a new parameter called `reconcile`, which is passed as an argument from the `_config` object. This change enhances the flexibility and customization of the reconcile job deployment. Furthermore, new fixtures `oracle_recon_config` and `snowflake_reconcile_config` have been introduced for `ReconcileConfig` objects with Oracle and Snowflake specific configurations, respectively. These fixtures are used in the test functions for deploying jobs, ensuring that the tests are more focused and better reflect the actual behavior of the code. +* [chore] Make singletons for the relevant `DataType` instances ([#705](https://github.com/databrickslabs/remorph/issues/705)). This commit introduces case objects for various data types, such as NullType, StringType, and others, effectively making them singletons. This change simplifies the creation of data type instances and ensures that each type has a single instance throughout the application. The new case objects are utilized in building SQL expressions, affecting functions such as Cast and TRY_CAST in the TSqlExpressionBuilder class. Additionally, the test file TSqlExpressionBuilderSpec.scala has been updated to include the new case objects and handle errors by returning null. The data types tested include integer types, decimal types, date and time types, string types, binary types, and JSON. The primary goal of this change is to improve the management and identification of data types in the codebase, as well as to enhance code readability and maintainability. + +Dependency updates: + + * Bump sqlglot from 25.1.0 to 25.5.1 ([#534](https://github.com/databrickslabs/remorph/pull/534)). + * Bump sigstore/gh-action-sigstore-python from 2.1.1 to 3.0.0 ([#555](https://github.com/databrickslabs/remorph/pull/555)). + * Bump sqlglot from 25.5.1 to 25.6.1 ([#585](https://github.com/databrickslabs/remorph/pull/585)). + +## 0.3.0 + +* Added Oracle ojdbc8 dependent library during reconcile Installation ([#474](https://github.com/databrickslabs/remorph/issues/474)). In this release, the `deployment.py` file in the `databricks/labs/remorph/helpers` directory has been updated to add the `ojdbc8` library as a `MavenLibrary` in the `_job_recon_task` function, enabling the reconciliation process to access the Oracle Data source and pull data for reconciliation between Oracle and Databricks. The `JDBCReaderMixin` class in the `jdbc_reader.py` file has also been updated to include the Oracle ojdbc8 dependent library for reconciliation during the `reconcile` process. This involves installing the `com.oracle.database.jdbc:ojdbc8:23.4.0.24.05` jar as a dependent library and updating the driver class to `oracle.jdbc.driver.OracleDriver` from `oracle`. A new dictionary `driver_class` has been added, which maps the driver name to the corresponding class name, allowing for dynamic driver class selection during the `_get_jdbc_reader` method call. 
The `test_read_data_with_options` unit test has been updated to test the Oracle connector for reading data with specific options, including the use of the correct driver class and specifying the database table for data retrieval, improving the accuracy and reliability of the reconciliation process. +* Added TSQL coverage tests in the generated report artifact ([#452](https://github.com/databrickslabs/remorph/issues/452)). In this release, we have added new TSQL coverage tests and Snowflake coverage tests to the generated report artifact in the CI/CD pipeline. These tests are executed using Maven with the updated command "mvn --update-snapshots -B test -pl coverage --file pom.xml --fail-at-end" and "mvn --update-snapshots -B exec:java -pl coverage --file pom.xml --fail-at-end -Dexec.args="-i tests/resources/functional/snowflake -o coverage-result.json" respectively, and the "continue-on-error: true" option is added to allow the pipeline to proceed even if the tests fail. Additionally, we have introduced a new constructor to the `CommentBasedQueryExtractor` class, which accepts a `dialect` parameter and allows for easier configuration of the start and end comments for different SQL dialects. We have also updated the CommentBasedQueryExtractor for Snowflake and added two TSQL coverage tests to the generated report artifact to ensure that the `QueryExtractor` is working correctly for TSQL queries. These changes will help ensure thorough testing and identification of TSQL and Snowflake queries during the CI/CD process. +* Added full support for analytical windowing functions ([#401](https://github.com/databrickslabs/remorph/issues/401)). In this release, full support for analytical windowing functions has been implemented, addressing issue [#401](https://github.com/databrickslabs/remorph/issues/401). The functions were previously specified in the parser grammar but have been moved to the standard function lookup table for more consistent handling. This enhancement allows for the use of analytical aggregate functions, such as FIRST_VALUE and PERCENTILE_CONT, with a `WITHIN GROUP` syntax and an `OVER` clause, enabling more complex queries and data analysis. The `FixedArity` and `VariableArity` classes have been updated with new methods for the supported functions, and appropriate examples have been provided to demonstrate their usage in SQL. +* Added parsing for STRPOS in presto ([#462](https://github.com/databrickslabs/remorph/issues/462)). A new feature has been added to the remorph/snow package's presto module to parse the STRPOS function in SQL code. This has been achieved by importing the locate_to_strposition function from sqlglot.dialects.dialect and incorporating it into the FUNCTIONS dictionary in the Parser class. This change enables the parsing of the STRPOS function, which returns the position of the first occurrence of a substring in a string. The implementation has been tested with a SQL file containing two queries for Presto SQL using STRPOS and Databricks SQL using LOCATE, both aimed at finding the position of the letter `l` in the string 'Hello world', starting the search from the second position. This feature is particularly relevant for software engineers working on data processing and analytics projects involving both Presto and Databricks SQL, as it ensures compatibility and consistent behavior between the two for string manipulation functions. 
The commit is part of issue [#462](https://github.com/databrickslabs/remorph/issues/462), and the diff provided includes a new SQL file with test cases for the STRPOS function in Presto and LOCATE function in Databricks SQL. The test cases confirm whether the `hello` string is present in the greeting_message column of the greetings_table. This feature allows users to utilize the STRPOS function in Presto to determine if a specific substring is present in a string. +* Added validation for join columns for all query builders and limiting rows for reports ([#413](https://github.com/databrickslabs/remorph/issues/413)). In this release, we've added validation for join columns in all query builders, ensuring consistent and accurate data joins. A limit on the number of rows displayed for reports has been implemented with a default of 50. The `compare.py` and `execute.py` files have been updated to include validation, and the `QueryBuilder` and `HashQueryBuilder` classes have new methods for validating join columns. The `SamplingQueryBuilder`, `ThresholdQueryBuilder`, and `recon_capture.py` files have similar updates for validation and limiting rows for reports. The `recon_config.py` file now has a new return type for the `get_join_columns` method, and a new method `test_no_join_columns_raise_exception()` has been added in the `test_threshold_query.py` file. These changes aim to enhance data consistency, accuracy, and efficiency for software engineers. +* Adds more coverage tests for functions to TSQL coverage ([#420](https://github.com/databrickslabs/remorph/issues/420)). This commit adds new coverage tests for various TSQL functions, focusing on the COUNT, MAX, MIN, STDEV, STDEVP, SUM, and VARP functions, which are identical in Databricks SQL. The tests include cases with and without the DISTINCT keyword to ensure consistent behavior between TSQL and Databricks. For the GROUPING and GROUPING_ID functions, which have some differences, tests and examples of TSQL and Databricks SQL code are provided. The CHECKSUM_AGG function, not directly supported in Databricks SQL, is tested using MD5 and CONCAT_WS for equivalence. The CUME_DIST function, identical in both systems, is also tested. Additionally, a new test file for the STDEV function and updated tests for the VAR function are introduced, enhancing the reliability and robustness of TSQL conversions in the project. +* Catalog, Schema Permission checks ([#492](https://github.com/databrickslabs/remorph/issues/492)). This release introduces enhancements to the Catalog and Schema functionality, with the addition of permission checks that raise explicit `Permission Denied` exceptions. The logger messages have been updated for clarity and a new variable, README_RECON_REPO, has been created to reference the readme file for the recon_config repository. The ReconcileUtils class has been modified to handle scenarios where the recon_config file is not found or corrupted during loading, providing clear error messages and guidance for users. The unit tests for the install feature have been updated with permission checks for Catalog and Schema operations, ensuring robust handling of permission denied errors. These changes improve the system's error handling and provide clearer guidance for users encountering permission issues. +* Changing the secret name acc to install script ([#432](https://github.com/databrickslabs/remorph/issues/432)).
In this release, the `recon` function in the `execute.py` file of the `databricks.labs.remorph.reconcile` package has been updated to dynamically generate the secret name instead of hardcoding it as "secret_scope". This change utilizes the new `get_key_form_dialect` function to create a secret name specific to the source dialect being used in the reconciliation process. The `get_dialect` function, along with `DatabaseConfig`, `TableRecon`, and the newly added `get_key_form_dialect`, have been imported from `databricks.labs.remorph.config`. This enhancement improves the security and flexibility of the reconciliation process by generating dynamic and dialect-specific secret names. +* Feature/recon documentation ([#395](https://github.com/databrickslabs/remorph/issues/395)). This commit introduces a new reconciliation process, enhancing data consistency between sources, co-authored by Ganesh Dogiparthi, ganeshdogiparthi-db, and SundarShankar89. The README.md file provides detailed documentation for the reconciliation process. A new binary file, docs/transpile-install.gif, offers installation instructions or visual aids, while a mermaid flowchart in `report_types_visualisation.md` illustrates report generation for data, rows, schema, and overall reconciliation. No existing functionality was modified, ensuring the addition of valuable features for software engineers adopting this project. +* Fixing issues in sample query builder to handle Null's and zero ([#457](https://github.com/databrickslabs/remorph/issues/457)). This commit introduces improvements to the sample query builder's handling of Nulls and zeroes, addressing bug [#450](https://github.com/databrickslabs/remorph/issues/450). The changes include updated SQL queries in the test threshold query file with COALESCE and TRIM functions to replace Null values with a specified string, ensuring consistent comparison of datasets. The query store in test_execute.py has also been enhanced to handle NULL and zero values using COALESCE, improving overall robustness and consistency. Additionally, new methods such as build_join_clause, trim, and coalesce have been added to enhance null handling in the query builder. The commit also introduces the MockDataSource class, a likely test implementation of a data source, and updates the log_and_throw_exception function for clearer error messaging. +* Implement Lakeview Dashboard Publisher ([#405](https://github.com/databrickslabs/remorph/issues/405)). In this release, we've introduced the `DashboardPublisher` class in the `dashboard_publisher.py` module to streamline the process of creating and publishing dashboards in Databricks Workspace. This class simplifies dashboard creation by accepting an instance of `WorkspaceClient` and `Installation` and providing methods for creating and publishing dashboards with optional parameter substitution. Additionally, we've added a new JSON file, 'Remorph-Reconciliation-Substituted.lvdash.json', which contains a dashboard definition for a data reconciliation feature. This dashboard includes various widgets for filtering and displaying reconciliation results. We've also added a test file for the Lakeview Dashboard Publisher feature, which includes tests to ensure that the `DashboardPublisher` can create dashboards using specified file paths and parameters. These new features and enhancements are aimed at improving the user experience and streamlining the process of creating and publishing dashboards in Databricks Workspace. 
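+
+As a small illustration of the NULL and whitespace handling described in [#457](https://github.com/databrickslabs/remorph/issues/457) above, the sketch below shows the general `COALESCE`/`TRIM` normalization idea; the helper name and the sentinel string are assumptions made for this example, not the project's literal implementation.
+
+```python
+def normalized(column: str, sentinel: str = "_null_recon_") -> str:
+    # Trim string values and map NULL to a fixed sentinel (assumed name) so that
+    # NULL, '' and padded values do not cause spurious mismatches when compared.
+    return f"COALESCE(TRIM({column}), '{sentinel}')"
+
+
+# e.g. a comparison condition built from the normalized column expressions
+condition = f"{normalized('src.city')} = {normalized('tgt.city')}"
+print(condition)  # COALESCE(TRIM(src.city), '_null_recon_') = COALESCE(TRIM(tgt.city), '_null_recon_')
+```
+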
+* Integrate recon metadata reconcile cli ([#444](https://github.com/databrickslabs/remorph/issues/444)). A new CLI command, `databricks labs remorph reconcile`, has been added to initiate the Data Reconciliation process, loading `reconcile.yml` and `recon_config.json` configuration files from the Databricks Workspace. If these files are missing, the user is prompted to reinstall the `reconcile` module and exit the command. The command then triggers the `Remorph_Reconciliation_Job` based on the Job ID stored in the `reconcile.yml` file. This simplifies the reconcile execution process, requiring users to first configure the `reconcile` module and generate the `recon_config_.json` file using `databricks labs remorph install` and `databricks labs remorph generate-recon-config` commands. The new CLI command has been manually tested and includes unit tests. Integration tests and verification on the staging environment are pending. This feature was co-authored by Bishwajit, Ganesh Dogiparthi, and SundarShankar89. +* Introduce coverage tests ([#382](https://github.com/databrickslabs/remorph/issues/382)). This commit introduces coverage tests and updates the GitHub Actions workflow to use Java 11 with Corretto distribution, improving testing and coverage analysis for the project. Coverage tests are added as part of the remorph project with the introduction of a new module for coverage and updating the artifact version to 0.2.0-SNAPSHOT. The pom.xml file is modified to change the parent project version to 0.2.0-SNAPSHOT, ensuring accurate assessment and maintenance of code coverage during development. In addition, a new Main object within the com.databricks.labs.remorph.coverage package is implemented for running coverage tests using command-line arguments, along with the addition of a new file QueryRunner.scala and case classes for ReportEntryHeader, ReportEntryReport, and ReportEntry for capturing and reporting on the status and results of parsing and transpilation processes. The `Cache Maven packages` step is removed and replaced with two new steps: `Run Unit Tests with Maven` and "Run Coverage Tests with Maven." The former executes unit tests and generates a test coverage report, while the latter downloads remorph-core jars as artifacts, executes coverage tests with Maven, and uploads coverage tests results as json artifacts. The `coverage-tests` job runs after the `test-core` job and uses the same environment, checking out the code with full history, setting up Java 11 with Corretto distribution, downloading remorph-core-jars artifacts, and running coverage tests with Maven, even if there are errors. The JUnit report is also published, and the coverage tests results are uploaded as json artifacts, providing better test coverage and more reliable code for software engineers adopting the project. +* Presto approx percentile func fix ([#411](https://github.com/databrickslabs/remorph/issues/411)). The remorph library has been updated to support the Presto database system, with a new module added to the config.py file to enable robust and maintainable interaction. An `APPROX_PERCENTILE` function has been implemented in the `presto.py` file of the `sqlglot.dialects.presto` package, allowing for approximate percentile calculations in Presto and Databricks SQL. A test file has been included for both SQL dialects, with queries calculating the approximate median of the height column in the people table. 
The new functionality enhances the compatibility and versatility of the remorph library in working with Presto databases and improves overall project functionality. Additionally, a new test file for Presto in the snowflakedriver project has been introduced to test expected exceptions, further ensuring robustness and reliability. +* Raise exception if reconciliation fails for any table ([#412](https://github.com/databrickslabs/remorph/issues/412)). In this release, we have implemented significant changes to improve exception handling and raise meaningful exceptions when reconciliation fails for any table in our open-source library. A new exception class, `ReconciliationException`, has been added as a child of the `Exception` class, which takes two optional parameters in its constructor, `message` and `reconcile_output`. The `ReconcileOutput` property has been created for accessing the reconcile output object. The `InvalidInputException` class now inherits from `ValueError`, making the code more explicit with the type of errors being handled. A new method, `_verify_successful_reconciliation`, has been introduced to check the reconciliation output status and raise a `ReconciliationException` if any table fails reconciliation. The `test_execute.py` file has been updated to raise a `ReconciliationException` if reconciliation for a specific report type fails, and new tests have been added to the test suite to ensure the correct behavior of the `reconcile` function with and without raising exceptions. +* Removed USE catalog/schema statement as lsql has added the feature ([#465](https://github.com/databrickslabs/remorph/issues/465)). In this release, the usage of `USE` statements for selecting a catalog and schema has been removed in the `get_sql_backend` function, thanks to the new feature provided by the lsql library. This enhancement improves code readability, maintainability, and enables better integration with the SQL backend. The commit also includes changes to the installation process for reconciliation metadata tables, providing more clarity and simplicity in the code. Additionally, several test functions have been added or modified to ensure the proper functioning of the `get_sql_backend` function in various scenarios, including cases where a warehouse ID is not provided or when executing SQL statements in a notebook environment. An error simulation test has also been added for handling `DatabricksError` exceptions when executing SQL statements using the `DatabricksConnectBackend` class. +* Sampling with clause query to have `from dual` in from clause for oracle source ([#464](https://github.com/databrickslabs/remorph/issues/464)). In this release, we've added the `get_key_from_dialect` function, replacing the previous `get_key_form_dialect` function, to retrieve the key associated with a given dialect object, serving as a unique identifier for the dialect. This improvement enhances the flexibility and readability of the codebase, making it easier to locate and manipulate dialect objects. Additionally, we've modified the 'sampling_query.py' file to include `from dual` in the `from` clause for Oracle sources in a sampling query with a clause, enabling sampling from Oracle databases. The `_insert_into_main_table` method in the `recon_capture.py` file of the `databricks.labs.remorph.reconcile` module has been updated to ensure accurate key retrieval for the specified dialect, thereby improving the reconciliation process. 
These changes resolve issues [#458](https://github.com/databrickslabs/remorph/issues/458) and [#464](https://github.com/databrickslabs/remorph/issues/464), enhancing the functionality of the sampling query builder and providing better support for various databases. +* Support function translation to Databricks SQL in TSql and Snowflake ([#414](https://github.com/databrickslabs/remorph/issues/414)). This commit introduces a dialect-aware FunctionBuilder system and a ConversionStrategy system to enable seamless translation of SQL functions between TSQL, Snowflake, and Databricks SQL IR. The new FunctionBuilder system can handle both simple name translations and more complex conversions when there is no direct equivalent. For instance, TSQL's ISNULL function translates to IFNULL in Databricks SQL, while Snowflake's ISNULL remains unchanged. The commit also includes updates to the TSqlExpressionBuilder and new methods for building and visiting various contexts, enhancing compatibility and expanding the range of supported SQL dialects. Additionally, new tests have been added in the FunctionBuilderSpec to ensure the correct arity and function type for various SQL functions. +* TSQL: Create coverage tests for TSQL -> Databricks functions ([#415](https://github.com/databrickslabs/remorph/issues/415)). This commit introduces coverage tests for T-SQL functions and their equivalent Databricks SQL implementations, focusing on the DATEADD function's `yy` keyword. The DATEADD function is translated to the ADD_MONTHS function in Databricks SQL, with the number of months multiplied by 12. This ensures functional equivalence between T-SQL and Databricks SQL for date addition involving years. The tests are written as SQL scripts and are located in the `tests/resources/functional/tsql/functions` directory, covering various scenarios and possible engine differences between T-SQL and Databricks SQL. The conversion process is documented, and future automation of this documentation is considered. +* TSQL: Implement WITH CTE ([#443](https://github.com/databrickslabs/remorph/issues/443)). With this commit, we have extended the TSQL functionality by adding support for Common Table Expressions (CTEs). CTEs are temporary result sets that can be defined within a single execution of a SELECT, INSERT, UPDATE, DELETE, or CREATE VIEW statement, allowing for more complex and efficient queries. The implementation includes the ability to create a CTE with an optional name and a column list, followed by a SELECT statement that defines the CTE. CTEs can be self-referential and can be used to simplify complex queries, improving code readability and performance. This feature is particularly useful for cases where multiple queries rely on the same intermediate result set, as it enables reusing the results without having to repeat the query. +* TSQL: Implement functions with specialized syntax ([#430](https://github.com/databrickslabs/remorph/issues/430)). This commit introduces new data type conversion functions and JSON manipulation capabilities to T-SQL, addressing issue [#430](https://github.com/databrickslabs/remorph/issues/430). The newly implemented features include `NEXT VALUE FOR sequence`, `CAST(col TO sometype)`, `TRY_CAST(col TO sometype)`, `JSON_ARRAY`, and `JSON_OBJECT`. These functions support specialized syntax for handling data type conversions and JSON operations, including NULL value handling using `NULL ON NULL` and `ABSENT ON NULL` syntax. 
The `TSqlFunctionBuilder` class has been updated to accommodate these changes, and new test cases have been added to the `TSqlFunctionSpec` test class in Scala. This enhancement enables SQL-based querying and data manipulation with increased functionality for T-SQL parser and function evaluations. +* TSQL: Support DISTINCT in SELECT list and aggregate functions ([#400](https://github.com/databrickslabs/remorph/issues/400)). This commit adds support for the `DISTINCT` keyword in T-SQL for use in the `SELECT` list and aggregate functions such as `COUNT`. When used in the `SELECT` list, `DISTINCT` ensures unique values of the specified expression are returned, and in aggregate functions like `COUNT`, it considers only distinct values of the specified argument. This change aligns with the SQL standard and enhances the functionality of the T-SQL parser, providing developers with greater flexibility and control when using `DISTINCT` in complex queries and aggregate functions. The default behavior in SQL, `ALL`, remains unchanged, and the parser has been updated to accommodate these improvements. +* TSQL: Update the SELECT statement to support XML workspaces ([#451](https://github.com/databrickslabs/remorph/issues/451)). This release introduces updates to the TSQL Select statement grammar to correctly support XMLWORKSPACES in accordance with the latest specification. Although Databricks SQL does not currently support XMLWORKSPACES, this change is a syntax-only update to enable compatibility with other platforms that do support it. Newly added components include 'xmlNamespaces', 'xmlDeclaration', 'xmlSchemaCollection', 'xmlTypeDefinition', 'createXmlSchemaCollection', 'xmlIndexOptions', 'xmlIndexOption', 'openXml', 'xmlCommonDirectives', and 'xmlColumnDefinition'. These additions enable the creation, configuration, and usage of XML schemas and indexes, as well as the specification of XML namespaces and directives. A new test file for functional tests has been included to demonstrate the use of XMLWORKSPACES in TSQL and its equivalent syntax in Databricks SQL. While this update does not affect the existing codebase's functionality, it does enable support for XMLWORKSPACES syntax in TSQL, facilitating easier integration with other platforms that support it. Please note that Databricks SQL does not currently support XML workspaces. +* Test merge queue ([#424](https://github.com/databrickslabs/remorph/issues/424)). In this release, the Scalafmt configuration has been updated to version 3.8.0, with changes to the formatting of Scala code. The `danglingParentheses` preset option has been set to "false", removing dangling parentheses from the code. Additionally, the `configStyleArguments` option has been set to `false` under "optIn". These modifications to the configuration file are likely to affect the formatting and style of the Scala code in the project, ensuring consistent and organized code. This change aims to enhance the readability and maintainability of the codebase. +* Updated bug and feature yml to support reconcile ([#390](https://github.com/databrickslabs/remorph/issues/390)). The open-source library has been updated to improve issue and feature categorization. In the `.github/ISSUE_TEMPLATE/bug.yml` file, new options for TranspileParserError, TranspileValidationError, and TranspileLateralColumnAliasError have been added to the `label: Category of Bug / Issue` field. Additionally, a new option for ReconcileError has been included. 
The `feature.yml` file in the `.github/ISSUE_TEMPLATE` directory has also been updated, introducing a required dropdown menu labeled "Category of feature request." This dropdown offers options for Transpile, Reconcile, and Other categories, ensuring accurate classification and organization of incoming feature requests. The modifications aim to enhance clarity for maintainers in reviewing and prioritizing issue resolutions and feature implementations related to reconciliation functionality. +* Updated the documentation with json config examples ([#486](https://github.com/databrickslabs/remorph/issues/486)). In this release, the Remorph Reconciliation tool on Databricks has been updated to include JSON config examples for various config elements such as jdbc_reader_options, column_mapping, transformations, thresholds, and filters. These config elements enable users to define source and target data, join columns, JDBC reader options, select and drop columns, column mappings, transformations, thresholds, and filters. The update also provides examples in both Python and JSON formats, as well as instructions for installing the necessary Oracle JDBC library on a Databricks cluster. This update enhances the tool's functionality, making it easier for software engineers to reconcile source data with target data on Databricks. +* Updated uninstall flow ([#476](https://github.com/databrickslabs/remorph/issues/476)). In this release, the `uninstall` functionality of the `databricks labs remorph` tool has been updated to align with the latest changes made to the `install` refactoring. The `uninstall` flow now utilizes a new `MockInstallation` class, which handles the uninstallation process and takes a dictionary of configuration files and their corresponding contents as input. The `uninstall` function has been modified to return `False` in two cases, either when there is no remorph directory or when the user decides not to uninstall. A `MockInstallation` object is created for the reconcile.yml file, and appropriate exceptions are raised in the aforementioned cases. The `uninstall` function now uses a `WorkspaceUnInstallation` or `WorkspaceUnInstaller` object, depending on the input arguments, to handle the uninstallation process. Additionally, the `MockPrompts` class is used to prompt the user for confirmation before uninstalling remorph. +* Updates to developer documentation and add grammar formatting to maven ([#490](https://github.com/databrickslabs/remorph/issues/490)). The developer documentation has been updated to include grammar formatting instructions and support for dialects other than Snowflake. The Maven build cycle has been modified to format grammars before ANTLR processes them, enhancing readability and easing conflict resolution during maintenance. The TSqlLexer.g4 file has been updated with formatting instructions and added dialect recognition. These changes ensure that grammars are consistently formatted and easily resolvable during merges. Engineers adopting this project should reformat the grammar file before each commit, following the provided formatting instructions and reference link. Grammar modifications in the TSqlParser.g4 file, such as alterations in partitionFunction and freetextFunction rules, improve structure and readability. +* Upgrade sqlglot from 23.13.7 to 25.1.0 ([#473](https://github.com/databrickslabs/remorph/issues/473)). 
In the latest release, the sqlglot package has been upgraded from version 23.13.7 to 25.1.0, offering potential new features, bug fixes, and performance improvements for SQL processing. The package dependency for numpy has been updated to version 1.26.4, which may introduce new functionality, improve existing features, or fix numpy integration issues. Furthermore, the addition of the types-pytz package as a dependency provides type hints for pytz, enhancing codebase type checking and static analysis capabilities. Specific modifications to the test_sql_transpiler.py file include updating the expected result in the test_parse_query function and removing unnecessary whitespaces in the transpiled_sql assertion in the test_procedure_conversion function. Although the find_root_tables function remains unchanged, the upgrade to sqlglot promises overall functionality enhancements, which software engineers can leverage in their projects. +* Use default_factory in recon_config.py ([#431](https://github.com/databrickslabs/remorph/issues/431)). In this release, the default value handling for the `status` field in the `DataReconcileOutput` and `ReconcileTableOutput` classes has been improved to comply with Python 3.11. Previously, a mutable default value was used, causing a `ValueError` issue. This has been addressed by implementing the `default_factory` argument in the `field` function to ensure a new instance of `StatusOutput` is created for each class. Additionally, `MismatchOutput` and `ThresholdOutput` classes now also utilize `default_factory` for consistent and robust default value handling, enhancing the overall code quality and preventing potential issues arising from mutable default values. +* edit distance ([#501](https://github.com/databrickslabs/remorph/issues/501)). In this release, we have implemented an `edit distance` feature for calculating the difference between two strings using the LEVENSHTEIN function. This has been achieved by adding a new method, `anonymous_sql`, to the `Generator` class in the `databricks.py` file. The method takes expressions of the `Anonymous` type as arguments and calls the `LEVENSHTEIN` function if the `this` attribute of the expression is equal to "EDITDISTANCE". Additionally, a new test file has been introduced for `Anonymous` expressions in the functional snowflake test suite to ensure the accurate calculation of string similarity using the EDITDISTANCE function. This change includes examples of using the EDITDISTANCE function with different parameters and compares it with the LEVENSHTEIN function available in Databricks. It addresses issue [#500](https://github.com/databrickslabs/remorph/issues/500), which was related to testing the edit distance functionality. + + +## 0.2.0 + +* Capture Reconcile metadata in delta tables for dashboards ([#369](https://github.com/databrickslabs/remorph/issues/369)). In this release, changes have been made to improve version control management, reduce repository size, and enhance build times. A new directory, "spark-warehouse/", has been added to the Git ignore file to prevent unnecessary files from being tracked and included in the project. The `WriteToTableException` class has been added to the `exception.py` file to raise an error when a runtime exception occurs while writing data to a table. A new `ReconCapture` class has been implemented in the `reconcile` package to capture and persist reconciliation metadata in delta tables.
The `recon` function has been updated to initialize this new class, passing in the required parameters. Additionally, a new file, `recon_capture.py`, has been added to the reconcile package, which implements the `ReconCapture` class responsible for capturing metadata related to data reconciliation. The `recon_config.py` file has been modified to introduce a new class, `ReconcileProcessDuration`, and restructure the classes `ReconcileOutput`, `MismatchOutput`, and `ThresholdOutput`. The commit also captures reconcile metadata in delta tables for dashboards in the context of unit tests in the `test_execute.py` file and includes a new file, `test_recon_capture.py`, to test the reconcile capture functionality of the `ReconCapture` class. +* Expand translation of Snowflake `expr` ([#351](https://github.com/databrickslabs/remorph/issues/351)). In this release, the translation of the `expr` category in the Snowflake language has been significantly expanded, addressing uncovered grammar areas, incorrect interpretations, and duplicates. The `subquery` is now excluded as a valid `expr`, and new case classes such as `NextValue`, `ArrayAccess`, `JsonAccess`, `Collate`, and `Iff` have been added to the `Expression` class. These changes improve the comprehensiveness and accuracy of the Snowflake parser, allowing for a more flexible and accurate translation of various operations. Additionally, the `SnowflakeExpressionBuilder` class has been updated to handle previously unsupported cases, enhancing the parser's ability to parse Snowflake SQL expressions. +* Fixed Oracle missing datatypes ([#333](https://github.com/databrickslabs/remorph/issues/333)). In the latest release, the Oracle class of the Tokenizer in the open-source library has undergone a fix to address missing datatypes. Previously, the KEYWORDS mapping did not require Tokens for keys, which led to unsupported Oracle datatypes. This issue has been resolved by modifying the test_schema_compare.py file to ensure that all Oracle datatypes, including LONG, NCLOB, ROWID, UROWID, ANYTYPE, ANYDATA, ANYDATASET, XMLTYPE, SDO_GEOMETRY, SDO_TOPO_GEOMETRY, and SDO_GEORASTER, are now mapped to the TEXT TokenType. This improvement enhances the compatibility of the code with Oracle datatypes and increases the reliability of the schema comparison functionality, as demonstrated by the test function test_schema_compare, which now returns is_valid as True and a count of 0 for is_valid = `false` in the resulting dataframe. +* Fixed the recon_config functions to handle null values ([#399](https://github.com/databrickslabs/remorph/issues/399)). In this release, the recon_config functions have been enhanced to manage null values and provide more flexible column mapping for reconciliation purposes. A `__post_init__` method has been added to certain classes to convert specified attributes to lowercase and handle null values. A new helper method, `_get_is_string`, has been introduced to determine if a column is of string type. Additionally, new functions such as `get_tgt_to_src_col_mapping_list`, `get_layer_tgt_to_src_col_mapping`, `get_src_to_tgt_col_mapping_list`, and `get_layer_src_to_tgt_col_mapping` have been added to retrieve column mappings, enhancing the overall functionality and robustness of the reconciliation process. These improvements will benefit software engineers by ensuring more accurate and reliable configuration handling, as well as providing more flexibility in mapping source and target columns during reconciliation.
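+
+For the `recon_config` changes above ([#399](https://github.com/databrickslabs/remorph/issues/399) and [#431](https://github.com/databrickslabs/remorph/issues/431)), here is a minimal sketch of the two patterns involved: `default_factory` for nested defaults and `__post_init__` for normalization. The field names are illustrative assumptions rather than the project's exact schema.
+
+```python
+from dataclasses import dataclass, field
+
+
+@dataclass
+class StatusOutput:
+    # Illustrative fields only; the real class may differ.
+    row: bool | None = None
+    column: bool | None = None
+    schema: bool | None = None
+
+
+@dataclass
+class ReconcileTableOutput:
+    # `status: StatusOutput = StatusOutput()` is treated as a mutable default and
+    # rejected on Python 3.11; default_factory builds a fresh StatusOutput per instance.
+    target_table_name: str | None = None
+    status: StatusOutput = field(default_factory=StatusOutput)
+
+    def __post_init__(self):
+        # Normalization in the spirit of #399: lower-case identifiers, tolerating None.
+        if self.target_table_name is not None:
+            self.target_table_name = self.target_table_name.lower()
+```
+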
+* Improve Exception handling ([#392](https://github.com/databrickslabs/remorph/issues/392)). The commit titled `Improve Exception Handling` enhances error handling in the project, addressing issues [#388](https://github.com/databrickslabs/remorph/issues/388) and [#392](https://github.com/databrickslabs/remorph/issues/392). Changes include refactoring the `create_adapter` method in the `DataSourceAdapter` class, updating method arguments in test functions, and adding new methods in the `test_execute.py` file for better test doubles. The `DataSourceAdapter` class is replaced with the `create_adapter` function, which takes the same arguments and returns an instance of the appropriate `DataSource` subclass based on the provided `engine` parameter. The diff also modifies the behavior of certain test methods to raise more specific and accurate exceptions. Overall, these changes improve exception handling, streamline the codebase, and provide clearer error messages for software engineers. +* Introduced morph_sql and morph_column_expr functions for inline transpilation and validation ([#328](https://github.com/databrickslabs/remorph/issues/328)). Two new classes, TranspilationResult and ValidationResult, have been added to the config module of the remorph package to store the results of transpilation and validation. The morph_sql and morph_column_exp functions have been introduced to support inline transpilation and validation of SQL code and column expressions. A new class, Validator, has been added to the validation module to handle validation, and the validate_format_result method within this class has been updated to return a ValidationResult object. The _query method has also been added to the class, which executes a given SQL query and returns a tuple containing a boolean indicating success, any exception message, and the result of the query. Unit tests for these new functions have been updated to ensure proper functionality. +* Output for the reconcile function ([#389](https://github.com/databrickslabs/remorph/issues/389)). A new function `get_key_form_dialect` has been added to the `config.py` module, which takes a `Dialect` object and returns the corresponding key used in the `SQLGLOT_DIALECTS` dictionary. Additionally, the `MorphConfig` dataclass has been updated to include a new attribute `__file__`, which sets the filename to "config.yml". The `get_dialect` function remains unchanged. Two new exceptions, `WriteToTableException` and `InvalidInputException`, have been introduced, and the existing `DataSourceRuntimeException` has been modified in the same module to improve error handling. The `execute.py` file's reconcile function has undergone several changes, including adding imports for `InvalidInputException`, `ReconCapture`, and `generate_final_reconcile_output` from `recon_exception` and `recon_capture` modules, and modifying the `ReconcileOutput` type. The `hash_query.py` file's reconcile function has been updated to include a new `_get_with_clause` method, which returns a `Select` object for a given DataFrame, and the `build_query` method has been updated to include a new query construction step using the `with_clause` object. The `threshold_query.py` file's reconcile function's output has been updated to include query and logger statements, a new method for allowing user transformations on threshold aliases, and the dialect specified in the sql method. 
A new `generate_final_reconcile_output` function has been added to the `recon_capture.py` file, which generates a reconcile output given a recon_id and a SparkSession. New classes and dataclasses, including `SchemaReconcileOutput`, `ReconcileProcessDuration`, `StatusOutput`, `ReconcileTableOutput`, and `ReconcileOutput`, have been introduced in the `reconcile/recon_config.py` file. The `tests/unit/reconcile/test_execute.py` file has been updated to include new test cases for the `recon` function, including tests for different report types and scenarios, such as data, schema, and all report types, exceptions, and incorrect report types. A new test case, `test_initialise_data_source`, has been added to test the `initialise_data_source` function, and the `test_recon_for_wrong_report_type` test case has been updated to expect an `InvalidInputException` when an incorrect report type is passed to the `recon` function. The `test_reconcile_data_with_threshold_and_row_report_type` test case has been added to test the `reconcile_data` method of the `Reconciliation` class with a row report type and threshold options. Overall, these changes improve the functionality and robustness of the reconcile process by providing more fine-grained control over the generation of the final reconcile output and better handling of exceptions and errors. +* Threshold Source and Target query builder ([#348](https://github.com/databrickslabs/remorph/issues/348)). In this release, we've introduced a new method, `build_threshold_query`, that constructs a customizable threshold query based on a table's partition, join, and threshold columns configuration. The method identifies necessary columns, applies specified transformations, and includes a WHERE clause based on the filter defined in the table configuration. The resulting query is then converted to a SQL string using the dialect of the source database. Additionally, we've updated the test file for the threshold query builder in the reconcile package, including refactoring of function names and updated assertions for query comparison. We've added two new test methods: `test_build_threshold_query_with_single_threshold` and `test_build_threshold_query_with_multiple_thresholds`. These changes enhance the library's functionality, providing a more robust and customizable threshold query builder, and improve test coverage for various configurations and scenarios. +* Unpack nested alias ([#336](https://github.com/databrickslabs/remorph/issues/336)). This release introduces a significant update to the 'lca_utils.py' file, addressing the limitation of not handling nested aliases in window expressions and where clauses, which resolves issue [#334](https://github.com/databrickslabs/remorph/issues/334). The `unalias_lca_in_select` method has been implemented to recursively parse nested selects and unalias lateral column aliases, thereby identifying and handling unsupported lateral column aliases. This method is utilized in the `check_for_unsupported_lca` method to handle unsupported lateral column aliases in the input SQL string. Furthermore, the 'test_lca_utils.py' file has undergone changes, impacting several test functions and introducing two new ones, `test_fix_nested_lca` and 'test_fix_nested_lca_with_no_scope', to ensure the code's reliability and accuracy by preventing unnecessary assumptions and hallucinations. These updates demonstrate our commitment to improving the library's functionality and test coverage. 
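+
+As a hand-written illustration of the nested lateral column alias (LCA) case targeted by [#336](https://github.com/databrickslabs/remorph/issues/336) above (the query and its rewrite are examples, not output captured from `lca_utils`):
+
+```python
+# `c` is defined from the sibling alias `b`, and the WHERE clause then refers to `c`.
+original = """
+SELECT a + 1 AS b,
+       b * 2 AS c
+FROM t
+WHERE c > 10
+"""
+
+# Unaliasing substitutes each alias with its defining expression, so dialects and
+# validation steps that cannot resolve lateral column aliases see only base columns.
+unaliased = """
+SELECT a + 1 AS b,
+       (a + 1) * 2 AS c
+FROM t
+WHERE (a + 1) * 2 > 10
+"""
+```
+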
+ + +## 0.1.7 + +* Added `Configure Secrets` support to `databricks labs remorph configure-secrets` cli command ([#254](https://github.com/databrickslabs/remorph/issues/254)). The `Configure Secrets` feature has been implemented in the `databricks labs remorph` CLI command, specifically for the new `configure-secrets` command. This addition allows users to establish Scope and Secrets within their Databricks Workspace, enhancing security and control over resource access. The implementation includes a new `recon_config_utils.py` file in the `databricks/labs/remorph/helpers` directory, which contains classes and methods for managing Databricks Workspace secrets. Furthermore, the `ReconConfigPrompts` helper class has been updated to handle prompts for selecting sources, entering secret scope names, and handling overwrites. The CLI command has also been updated with a new `configure_secrets` function and corresponding tests to ensure correct functionality. +* Added handling for invalid alias usage by manipulating the AST ([#219](https://github.com/databrickslabs/remorph/issues/219)). The recent commit addresses the issue of invalid alias usage in SQL queries by manipulating the Abstract Syntax Tree (AST). It introduces a new method, `unalias_lca_in_select`, which unaliases Lateral Column Aliases (LCA) in the SELECT clause of a query. The AliasInfo class is added to manage aliases more effectively, with attributes for the name, expression, and a flag indicating if the alias name is the same as a column. Additionally, the execute.py file is modified to check for unsupported LCA using the `lca_utils.check_for_unsupported_lca` method, improving the system's robustness when handling invalid aliases. Test cases are also added in the new file, test_lca_utils.py, to validate the behavior of the `check_for_unsupported_lca` function, ensuring that SQL queries are correctly formatted for Snowflake dialect and avoiding errors due to invalid alias usage. +* Added support for `databricks labs remorph generate-lineage` CLI command ([#238](https://github.com/databrickslabs/remorph/issues/238)). A new CLI command, `databricks labs remorph generate-lineage`, has been added to generate lineage for input SQL files, taking the source dialect, input, and output directories as arguments. The command uses existing logic to generate a directed acyclic graph (DAG) and then creates a DOT file in the output directory using the DAG. The new command is supported by new functions `_generate_dot_file_contents`, `lineage_generator`, and methods in the `RootTableIdentifier` and `DAG` classes. The command has been manually tested and includes unit tests, with plans for adding integration tests in the future. The commit also includes a new method `temp_dirs_for_lineage` and updates to the `configure_secrets_databricks` method to handle a new source type "databricks". The command handles invalid input and raises appropriate exceptions. +* Custom oracle tokenizer ([#316](https://github.com/databrickslabs/remorph/issues/316)). In this release, the remorph library has been updated to enhance its handling of Oracle databases. A custom Oracle tokenizer has been developed to map the `LONG` datatype to text (string) in the tokenizer, allowing for more precise parsing and manipulation of `LONG` columns in Oracle databases. The Oracle dialect in the configuration file has also been updated to utilize the new custom Oracle tokenizer. 
Additionally, the Oracle class from the snow module has been imported and integrated into the Oracle dialect. These improvements will enable the remorph library to manage Oracle databases more efficiently, with a particular focus on improving the handling of the `LONG` datatype. The commit also includes updates to test files in the functional/oracle/test_long_datatype directory, which ensure the proper conversion of the `LONG` datatype to text. Furthermore, a new test file has been added to the tests/unit/snow directory, which checks for compatibility with Oracle's long data type. These changes enhance the library's compatibility with Oracle databases, ensuring accurate handling and manipulation of the `LONG` datatype in Oracle SQL and Databricks SQL. +* Removed strict source dialect checks ([#284](https://github.com/databrickslabs/remorph/issues/284)). In the latest release, the `transpile` and `generate_lineage` functions in `cli.py` have undergone changes to allow for greater flexibility in source dialect selection. Previously, only `snowflake` or `tsql` dialects were supported, but now any source dialect supported by SQLGLOT can be used, controlled by the `SQLGLOT_DIALECTS` dictionary. Providing an unsupported source dialect will result in a validation error. Additionally, the input and output folder paths for the `generate_lineage` function are now validated against the file system to ensure their existence and validity. In the `install.py` file of the `databricks/labs/remorph` package, the source dialect selection has been updated to use `SQLGLOT_DIALECTS.keys()`, replacing the previous hardcoded list. This change allows for more flexibility in selecting the source dialect. Furthermore, recent updates to various test functions in the `test_install.py` file suggest that the source selection process has been modified, possibly indicating the addition of new sources or a change in source identification. These modifications provide greater flexibility in testing and potentially in the actual application. +* Set Catalog, Schema from default Config ([#312](https://github.com/databrickslabs/remorph/issues/312)). A new feature has been added to our open-source library that allows users to specify the `catalog` and `schema` configuration options as part of the `transpile` command-line interface (CLI). If these options are not provided, the `transpile` function in the `cli.py` file will now set them to the values specified in `default_config`. This ensures that a default catalog and schema are used if they are not explicitly set by the user. The `labs.yml` file has been updated to reflect these changes, with the addition of the `catalog-name` and `schema-name` options to the `commands` object. The `default` property of the `validation` object has also been updated to `true`, indicating that the validation step will be skipped by default. These changes provide increased flexibility and ease-of-use for users of the `transpile` functionality. +* Support for Null safe equality join for databricks generator ([#280](https://github.com/databrickslabs/remorph/issues/280)). In this release, we have implemented support for a null-safe equality join in the Databricks generator, addressing issue [#280](https://github.com/databrickslabs/remorph/issues/280). This feature introduces the use of the " <=> " operator in the generated SQL code instead of the `is not distinct from` syntax to ensure accurate comparisons when NULL values are present in the columns being joined. 
The Generator class has been updated with a new method, NullSafeEQ, which takes in an expression and returns the binary version of the expression using the " <=> " operator. The preprocess method in the Generator class has also been modified to include this new functionality. It is important to note that this change may require users to update their existing code to align with the new syntax in the Databricks environment. With this enhancement, the Databricks generator is now capable of performing null-safe equality joins, resulting in consistent results regardless of the presence of NULL values in the join conditions. + + +## 0.1.6 + +* Added serverless validation using lsql library ([#176](https://github.com/databrickslabs/remorph/issues/176)). The `WorkspaceClient` object is created with the `product` name and `product_version`, and the corresponding `cluster_id` or `warehouse_id` is passed as `sdk_config` in the `MorphConfig` object. +* Enhanced install script to enforce usage of a warehouse or cluster when `skip-validation` is set to `False` ([#213](https://github.com/databrickslabs/remorph/issues/213)). In this release, the installation process has been enhanced to mandate the use of a warehouse or cluster when the `skip-validation` parameter is set to `False`. This change has been implemented across various components, including the install script, `transpile` function, and `get_sql_backend` function. Additionally, new pytest fixtures and methods have been added to improve test configuration and resource management during testing. Unit tests have been updated to enforce usage of a warehouse or cluster when the `skip-validation` flag is set to `False`, ensuring proper resource allocation and validation process improvement. This development focuses on promoting a proper setup and usage of the system, guiding new users towards a correct configuration and improving the overall reliability of the tool. +* Patch subquery with json column access ([#190](https://github.com/databrickslabs/remorph/issues/190)). The open-source library has been updated with new functionality to modify how subqueries with JSON column access are handled in the `snowflake.py` file. This change includes the addition of a check for an opening parenthesis after the `FROM` keyword to detect and break loops when a subquery is found, as opposed to a table name. This improvement enhances the handling of complex subqueries and JSON column access, making the code more robust and adaptable to different query structures. Additionally, a new test method, `test_nested_query_with_json`, has been introduced to the `tests/unit/snow/test_databricks.py` file to test the behavior of nested queries involving JSON column access when using a Snowflake dialect. This new method validates the expected output of a specific nested query when it is transpiled to Snowflake's SQL dialect, allowing for more comprehensive testing of JSON column access and type casting in Snowflake dialects. The existing `test_delete_from_keyword` method remains unchanged. +* Snowflake `UPDATE FROM` to Databricks `MERGE INTO` implementation ([#198](https://github.com/databrickslabs/remorph/issues/198)). +* Use Runtime SQL backend in Notebooks ([#211](https://github.com/databrickslabs/remorph/issues/211)). In this update, the `db_sql.py` file in the `databricks/labs/remorph/helpers` directory has been modified to support the use of the Runtime SQL backend in Notebooks.
This change includes the addition of a new `RuntimeBackend` class in the `backends` module and an import statement for `os`. The `get_sql_backend` function now returns a `RuntimeBackend` instance when the `DATABRICKS_RUNTIME_VERSION` environment variable is present, allowing for more efficient and secure SQL statement execution in Databricks notebooks. Additionally, a new test case for the `get_sql_backend` function has been added to ensure the correct behavior of the function in various runtime environments. These enhancements improve SQL execution performance and security in Databricks notebooks and increase the project's versatility for different use cases. +* Added Issue Templates for bugs, feature and config ([#194](https://github.com/databrickslabs/remorph/issues/194)). Two new issue templates have been added to the project's GitHub repository to improve issue creation and management. The first template, located in `.github/ISSUE_TEMPLATE/bug.yml`, is for reporting bugs and prompts users to provide detailed information about the issue, including the current and expected behavior, steps to reproduce, relevant log output, and sample query. The second template, added under the path `.github/ISSUE_TEMPLATE/config.yml`, is for configuration-related issues and includes support contact links for general Databricks questions and Remorph documentation, as well as fields for specifying the operating system and software version. A new issue template for feature requests, named "Feature Request", has also been added, providing a structured format for users to submit requests for new functionality for the Remorph project. These templates will help streamline the issue creation process, improve the quality of information provided, and make it easier for the development team to quickly identify and address bugs and feature requests. +* Added Databricks Source Adapter ([#185](https://github.com/databrickslabs/remorph/issues/185)). In this release, the project has been enhanced with several new features for the Databricks Source Adapter. A new `engine` parameter has been added to the `DataSource` class, replacing the original `source` parameter. The `_get_secrets` and `_get_table_or_query` methods have been updated to use the `engine` parameter for key naming and handling queries with a `select` statement differently, respectively. A Databricks Source Adapter for Oracle databases has been introduced, which includes a new `OracleDataSource` class that provides functionality to connect to an Oracle database using JDBC. A Databricks Source Adapter for Snowflake has also been added, featuring the `SnowflakeDataSource` class that handles data reading and schema retrieval from Snowflake. The `DatabricksDataSource` class has been updated to handle data reading and schema retrieval from Databricks, including a new `get_schema_query` method that generates the query to fetch the schema based on the provided catalog and table name. Exception handling for reading data and fetching schema has been implemented for all new classes. These changes provide increased flexibility for working with various data sources, improved code maintainability, and better support for different use cases. +* Added Threshold Query Builder ([#188](https://github.com/databrickslabs/remorph/issues/188)). In this release, the open-source library has added a Threshold Query Builder feature, which includes several changes to the existing functionality in the data source connector. 
A new import statement adds the `re` module for regular expressions, and new parameters have been added to the `read_data` and `get_schema` abstract methods. The `_get_jdbc_reader_options` method has been updated to accept a `options` parameter of type "JdbcReaderOptions", and a new static method, "_get_table_or_query", has been added to construct the table or query string based on provided parameters. Additionally, a new class, "QueryConfig", has been introduced in the "databricks.labs.remorph.reconcile" package to configure queries for data reconciliation tasks. A new abstract base class QueryBuilder has been added to the query_builder.py file, along with HashQueryBuilder and ThresholdQueryBuilder classes to construct SQL queries for generating hash values and selecting columns based on threshold values, transformation rules, and filtering conditions. These changes aim to enhance the functionality of the data source connector, add modularity, customizability, and reusability to the query builder, and improve data reconciliation tasks. +* Added snowflake connector code ([#177](https://github.com/databrickslabs/remorph/issues/177)). In this release, the open-source library has been updated to add a Snowflake connector for data extraction and schema manipulation. The changes include the addition of the SnowflakeDataSource class, which is used to read data from Snowflake using PySpark, and has methods for getting the JDBC URL, reading data with and without JDBC reader options, getting the schema, and handling exceptions. A new constant, SNOWFLAKE, has been added to the SourceDriver enum in constants.py, which represents the Snowflake JDBC driver class. The code modifications include updating the constructor of the DataSource abstract base class to include a new parameter 'scope', and updating the `_get_secrets` method to accept a `key_name` parameter instead of 'key'. Additionally, a test file 'test_snowflake.py' has been added to test the functionality of the SnowflakeDataSource class. This release also updates the pyproject.toml file to version lock the dependencies like black, ruff, and isort, and modifies the coverage report configuration to exclude certain files and lines from coverage checks. These changes were completed by Ravikumar Thangaraj and SundarShankar89. +* `remorph reconcile` baseline for Query Builder and Source Adapter for oracle as source ([#150](https://github.com/databrickslabs/remorph/issues/150)). + +Dependency updates: + + * Bump sqlglot from 22.4.0 to 22.5.0 ([#175](https://github.com/databrickslabs/remorph/pull/175)). + * Updated databricks-sdk requirement from <0.22,>=0.18 to >=0.18,<0.23 ([#178](https://github.com/databrickslabs/remorph/pull/178)). + * Updated databricks-sdk requirement from <0.23,>=0.18 to >=0.18,<0.24 ([#189](https://github.com/databrickslabs/remorph/pull/189)). + * Bump actions/checkout from 3 to 4 ([#203](https://github.com/databrickslabs/remorph/pull/203)). + * Bump actions/setup-python from 4 to 5 ([#201](https://github.com/databrickslabs/remorph/pull/201)). + * Bump codecov/codecov-action from 1 to 4 ([#202](https://github.com/databrickslabs/remorph/pull/202)). + * Bump softprops/action-gh-release from 1 to 2 ([#204](https://github.com/databrickslabs/remorph/pull/204)). + +## 0.1.5 + +* Added Pylint Checker ([#149](https://github.com/databrickslabs/remorph/issues/149)). 
This diff adds a Pylint checker to the project, which is used to enforce a consistent code style, identify potential bugs, and check for errors in the Python code. The configuration for Pylint includes various settings, such as a line length limit, the maximum number of arguments for a function, and the maximum number of lines in a module. Additionally, several plugins have been specified to load, which add additional checks and features to Pylint. The configuration also includes settings that customize the behavior of Pylint's naming conventions checks and handle various types of code constructs, such as exceptions, logging statements, and import statements. By using Pylint, the project can help ensure that its code is of high quality, easy to understand, and free of bugs. This diff includes changes to various files, such as cli.py, morph_status.py, validate.py, and several SQL-related files, to ensure that they adhere to the desired Pylint configuration and best practices for code quality and organization. +* Fixed edge case where column name is same as alias name ([#164](https://github.com/databrickslabs/remorph/issues/164)). A recent commit has introduced fixes for edge cases related to conflicts between column names and alias names in SQL queries, addressing issues [#164](https://github.com/databrickslabs/remorph/issues/164) and [#130](https://github.com/databrickslabs/remorph/issues/130). The `check_for_unsupported_lca` function has been updated with two helper functions `_find_aliases_in_select` and `_find_invalid_lca_in_window` to detect aliases with the same name as a column in a SELECT expression and identify invalid lateral column aliases (LCAs) in window functions, respectively. The `find_windows_in_select` function has been refactored and renamed to `_find_windows_in_select` for improved code readability. The `transpile` and `parse` functions in the `sql_transpiler.py` file have been updated with try-except blocks to handle cases where a column name matches the alias name, preventing errors or exceptions such as `ParseError`, `TokenError`, and `UnsupportedError`. A new unit test, "test_query_with_same_alias_and_column_name", has been added to verify the fix, passing a SQL query with a subquery having a column alias `ca_zip` which is also used as a column name in the same query, confirming that the function correctly handles the scenario where a column name conflicts with an alias name. +* `TO_NUMBER` without `format` edge case ([#172](https://github.com/databrickslabs/remorph/issues/172)). The `TO_NUMBER without format edge case` commit introduces changes to address an unsupported usage of the `TO_NUMBER` function in the Databricks SQL dialect when the `format` parameter is not provided. The new implementation introduces constants `PRECISION_CONST` and `SCALE_CONST` (set to 38 and 0 respectively) as default values for `precision` and `scale` parameters. These changes ensure Databricks SQL dialect requirements are met by modifying the `_to_number` method to incorporate these constants. An `UnsupportedError` will now be raised when `TO_NUMBER` is called without a `format` parameter, improving error handling and ensuring users are aware of the required `format` parameter. Test cases have been added for `TO_DECIMAL`, `TO_NUMERIC`, and `TO_NUMBER` functions with format strings, covering cases where the format is taken from table columns. The commit also ensures that an error is raised when `TO_DECIMAL` is called without a format parameter.
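+
+For the `TO_NUMBER` change in [#172](https://github.com/databrickslabs/remorph/issues/172) above, the sketch below illustrates how defaults of precision 38 and scale 0 can back a conversion when `precision` and `scale` are omitted. It shows the role of the constants only; it is not the project's `_to_number` implementation, and the generated `CAST` form is an assumption for this example.
+
+```python
+PRECISION_CONST = 38  # default precision when none is supplied
+SCALE_CONST = 0       # default scale when none is supplied
+
+
+def to_decimal_cast(column: str, precision: int | None = None, scale: int | None = None) -> str:
+    # Fall back to the documented defaults so the target decimal type is always well-defined.
+    precision = PRECISION_CONST if precision is None else precision
+    scale = SCALE_CONST if scale is None else scale
+    return f"CAST({column} AS DECIMAL({precision}, {scale}))"
+
+
+print(to_decimal_cast("col1"))         # CAST(col1 AS DECIMAL(38, 0))
+print(to_decimal_cast("col1", 10, 2))  # CAST(col1 AS DECIMAL(10, 2))
+```
+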
+ +Dependency updates: + + * Bump sqlglot from 21.2.1 to 22.0.1 ([#152](https://github.com/databrickslabs/remorph/pull/152)). + * Bump sqlglot from 22.0.1 to 22.1.1 ([#159](https://github.com/databrickslabs/remorph/pull/159)). + * Updated databricks-labs-blueprint[yaml] requirement from ~=0.2.3 to >=0.2.3,<0.4.0 ([#162](https://github.com/databrickslabs/remorph/pull/162)). + * Bump sqlglot from 22.1.1 to 22.2.0 ([#161](https://github.com/databrickslabs/remorph/pull/161)). + * Bump sqlglot from 22.2.0 to 22.2.1 ([#163](https://github.com/databrickslabs/remorph/pull/163)). + * Updated databricks-sdk requirement from <0.21,>=0.18 to >=0.18,<0.22 ([#168](https://github.com/databrickslabs/remorph/pull/168)). + * Bump sqlglot from 22.2.1 to 22.3.1 ([#170](https://github.com/databrickslabs/remorph/pull/170)). + * Updated databricks-labs-blueprint[yaml] requirement from <0.4.0,>=0.2.3 to >=0.2.3,<0.5.0 ([#171](https://github.com/databrickslabs/remorph/pull/171)). + * Bump sqlglot from 22.3.1 to 22.4.0 ([#173](https://github.com/databrickslabs/remorph/pull/173)). + +## 0.1.4 + +* Added conversion logic for Try_to_Decimal without format ([#142](https://github.com/databrickslabs/remorph/pull/142)). +* Identify Root Table for folder containing SQLs ([#124](https://github.com/databrickslabs/remorph/pull/124)). +* Install Script ([#106](https://github.com/databrickslabs/remorph/pull/106)). +* Integration Test Suite ([#145](https://github.com/databrickslabs/remorph/pull/145)). + +Dependency updates: + + * Updated databricks-sdk requirement from <0.20,>=0.18 to >=0.18,<0.21 ([#143](https://github.com/databrickslabs/remorph/pull/143)). + * Bump sqlglot from 21.0.0 to 21.1.2 ([#137](https://github.com/databrickslabs/remorph/pull/137)). + * Bump sqlglot from 21.1.2 to 21.2.0 ([#147](https://github.com/databrickslabs/remorph/pull/147)). + * Bump sqlglot from 21.2.0 to 21.2.1 ([#148](https://github.com/databrickslabs/remorph/pull/148)). + +## 0.1.3 + +* Added support for WITHIN GROUP for ARRAY_AGG and LISTAGG functions ([#133](https://github.com/databrickslabs/remorph/pull/133)). +* Fixed Merge "INTO" for delete from syntax ([#129](https://github.com/databrickslabs/remorph/pull/129)). +* Fixed `DATE TRUNC` parse errors ([#131](https://github.com/databrickslabs/remorph/pull/131)). +* Patched Logger function call during wheel file ([#135](https://github.com/databrickslabs/remorph/pull/135)). +* Patched extra call to root path ([#126](https://github.com/databrickslabs/remorph/pull/126)). + +Dependency updates: + + * Updated databricks-sdk requirement from ~=0.18.0 to >=0.18,<0.20 ([#134](https://github.com/databrickslabs/remorph/pull/134)). + +## 0.1.2 + +* Fixed duplicate LCA warnings ([#108](https://github.com/databrickslabs/remorph/pull/108)). +* Fixed invalid flagging of LCA usage ([#117](https://github.com/databrickslabs/remorph/pull/117)). + +Dependency updates: + + * Bump sqlglot from 20.10.0 to 20.11.0 ([#95](https://github.com/databrickslabs/remorph/pull/95)). + * Bump sqlglot from 20.11.0 to 21.0.0 ([#122](https://github.com/databrickslabs/remorph/pull/122)). + +## 0.1.1 + +* Added test_approx_percentile and test_trunc Testcases ([#98](https://github.com/databrickslabs/remorph/pull/98)). 
+* Updated contributing/developer guide ([#97](https://github.com/databrickslabs/remorph/pull/97)). + + +## 0.1.0 + +* Added baseline for Databricks CLI frontend ([#60](https://github.com/databrickslabs/remorph/pull/60)). +* Added custom Databricks dialect test cases and lateral struct parsing ([#77](https://github.com/databrickslabs/remorph/pull/77)). +* Extended Snowflake to Databricks functions coverage ([#72](https://github.com/databrickslabs/remorph/pull/72), [#69](https://github.com/databrickslabs/remorph/pull/69)). +* Added `databricks labs remorph transpile` documentation for installation and usage ([#73](https://github.com/databrickslabs/remorph/pull/73)). + +Dependency updates: + + * Bump sqlglot from 20.8.0 to 20.9.0 ([#83](https://github.com/databrickslabs/remorph/pull/83)). + * Updated databricks-sdk requirement from ~=0.17.0 to ~=0.18.0 ([#90](https://github.com/databrickslabs/remorph/pull/90)). + * Bump sqlglot from 20.9.0 to 20.10.0 ([#91](https://github.com/databrickslabs/remorph/pull/91)). + +## 0.0.1 + +Initial commit diff --git a/CODEOWNERS b/CODEOWNERS index c811905ffa..d724c297f2 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -* @databrickslabs/remorph-write \ No newline at end of file +* @databrickslabs/role-labs-remorph-write diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cd904ea6e5..f81c1e60ab 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,29 +2,62 @@ ## First Principles -We must use the [Databricks SDK for Python](https://databricks-sdk-py.readthedocs.io/) in this project. It is a toolkit for our project. -If something doesn't naturally belong to the `WorkspaceClient`, it must go through a "mixin" process before it can be used with the SDK. -Imagine the `WorkspaceClient` as the main control center and the "mixin" process as a way to adapt other things to work with it. -You can find an example of how mixins are used with `StatementExecutionExt`. There's a specific example of how to make something -work with the WorkspaceClient using `StatementExecutionExt`. This example can help you understand how mixins work in practice. - -Favoring standard libraries over external dependencies, especially in specific contexts like Databricks, is a best practice in software -development. +Favoring standard libraries over external dependencies, especially in specific contexts like Databricks, +is a best practice in software development. There are several reasons why this approach is encouraged: -- Standard libraries are typically well-vetted, thoroughly tested, and maintained by the official maintainers of the programming language or platform. This ensures a higher level of stability and reliability. -- External dependencies, especially lesser-known or unmaintained ones, can introduce bugs, security vulnerabilities, or compatibility issues that can be challenging to resolve. Adding external dependencies increases the complexity of your codebase. -- Each dependency may have its own set of dependencies, potentially leading to a complex web of dependencies that can be difficult to manage. This complexity can lead to maintenance challenges, increased risk, and longer build times. -- External dependencies can pose security risks. If a library or package has known security vulnerabilities and is widely used, it becomes an attractive target for attackers. Minimizing external dependencies reduces the potential attack surface and makes it easier to keep your code secure. 
-- Relying on standard libraries enhances code portability. It ensures your code can run on different platforms and environments without being tightly coupled to specific external dependencies. This is particularly important in settings like Databricks, where you may need to run your code on different clusters or setups. -- External dependencies may have their versioning schemes and compatibility issues. When using standard libraries, you have more control over versioning and can avoid conflicts between different dependencies in your project. -- Fewer external dependencies mean faster build and deployment times. Downloading, installing, and managing external packages can slow down these processes, especially in large-scale projects or distributed computing environments like Databricks. -- External dependencies can be abandoned or go unmaintained over time. This can lead to situations where your project relies on outdated or unsupported code. When you depend on standard libraries, you have confidence that the core functionality you rely on will continue to be maintained and improved. - -While minimizing external dependencies is essential, exceptions can be made case-by-case. There are situations where external dependencies are -justified, such as when a well-established and actively maintained library provides significant benefits, like time savings, performance improvements, +- Standard libraries are typically well-vetted, thoroughly tested, and maintained by the official maintainers of the programming language or platform. This ensures a higher level of stability and reliability. +- External dependencies, especially lesser-known or unmaintained ones, can introduce bugs, security vulnerabilities, or compatibility issues that can be challenging to resolve. Adding external dependencies increases the complexity of your codebase. +- Each dependency may have its own set of dependencies, potentially leading to a complex web of dependencies that can be difficult to manage. This complexity can lead to maintenance challenges, increased risk, and longer build times. +- External dependencies can pose security risks. If a library or package has known security vulnerabilities and is widely used, it becomes an attractive target for attackers. Minimizing external dependencies reduces the potential attack surface and makes it easier to keep your code secure. +- Relying on standard libraries enhances code portability. It ensures your code can run on different platforms and environments without being tightly coupled to specific external dependencies. This is particularly important in settings like Databricks, where you may need to run your code on different clusters or setups. +- External dependencies may have their versioning schemes and compatibility issues. When using standard libraries, you have more control over versioning and can avoid conflicts between different dependencies in your project. +- Fewer external dependencies mean faster build and deployment times. Downloading, installing, and managing external packages can slow down these processes, especially in large-scale projects or distributed computing environments like Databricks. +- External dependencies can be abandoned or go unmaintained over time. This can lead to situations where your project relies on outdated or unsupported code. When you depend on standard libraries, you have confidence that the core functionality you rely on will continue to be maintained and improved. 
+ +While minimizing external dependencies is essential, exceptions can be made case-by-case. There are situations where external dependencies are +justified, such as when a well-established and actively maintained library provides significant benefits, like time savings, performance improvements, or specialized functionality unavailable in standard libraries. +## GPG signing +The Remorph project requires any commit to be signed-off using GPG signing. +Before you submit any commit, please make sure you are properly setup, as follows. + +If you don't already have one, create a GPG key: + - on MacOS, install the GPG Suite from https://gpgtools.org/ + - from the Applications folder, launch the GPG Keychain app + - create a new GPG key, using your Databricks email + - Right-click on the created key and select Export, to save the key + - Check the key using TextEdit, it should start with -----BEGIN PGP PUBLIC KEY BLOCK----- + +Register your PGP key in GitHub: + - In GitHub, select Settings from your picture at the top-right + - select SSH and PGP key + - click on New key, and paste the text content of the exported key + - select Emails + - if your databricks email is not registered, register it + - complete the verification before the next steps + +Tell local git to signoff your commits using your PGP key + - see full instructions here https://docs.github.com/en/authentication/managing-commit-signature-verification/telling-git-about-your-signing-key + - in short, you need to run the following commands from a terminal: + - git config --global --unset gpg.format + - gpg --list-secret-keys --keyid-format=long + - git config --global user.signingkey + - git config --global commit.gpgsign true + +Once all this is done, you can verify it's correct as follows: + - create a branch and use it + - create a file with some content + - git add + - git commit -m "test PGP" + - git verify-commit +The last command should display something like the following: +`gpg: Signature made Tue Nov 26 11:34:23 2024 CET +gpg: using RSA key FD4D754BB2B1D4F09F2BF658F4B0C73DFC65A17B +gpg: Good signature from "GitHub " [ultimate] +` + ## Change management When you introduce a change in the code, specifically a deeply technical one, please ensure that the change provides same or improved set of capabilities. @@ -40,95 +73,78 @@ _Keep API components simple._ In the components responsible for API interactions Refrain from overloading them with complex logic; instead, focus on making API calls and handling the data from those calls. _Inject Business Logic._ If you need to use business logic in your API-calling components, don't build it directly there. -Instead, inject (or pass in) the business logic components into your API components. This way, you can keep your API components +Instead, inject (or pass in) the business logic components into your API components. This way, you can keep your API components clean and flexible, while the business logic remains separate and reusable. -_Test your Business Logic._ It's essential to test your business logic to ensure it works correctly and thoroughly. When writing -unit tests, avoid making actual API calls - unit tests are executed for every pull request, and **_take seconds to complete_**. -For calling any external services, including Databricks Connect, Databricks Platform, or even Apache Spark, unit tests have -to use "mocks" or fake versions of the APIs to simulate their behavior. 
This makes testing your code more manageable and catching any -issues without relying on external systems. Focus on testing the edge cases of the logic, especially the scenarios where +_Test your Business Logic._ It's essential to test your business logic to ensure it works correctly and thoroughly. When writing +unit tests, avoid making actual API calls - unit tests are executed for every pull request, and **_take seconds to complete_**. +For calling any external services, including Databricks Connect, Databricks Platform, or even Apache Spark, unit tests have +to use "mocks" or fake versions of the APIs to simulate their behavior. This makes testing your code more manageable and catching any +issues without relying on external systems. Focus on testing the edge cases of the logic, especially the scenarios where things may fail. See [this example](https://github.com/databricks/databricks-sdk-py/pull/295) as a reference of an extensive unit test coverage suite and the clear difference between _unit tests_ and _integration tests_. -## Integration Testing Infrastructure - -Integration tests must accompany all new code additions. Integration tests help us validate that various parts of -our application work correctly when they interact with each other or external systems. This practice ensures that our -software _**functions as a cohesive whole**_. Integration tests run every night and take approximately 15 minutes -for the entire test suite to complete. - -We encourage using predefined test infrastructure provided through environment variables for integration tests. -These fixtures are set up in advance to simulate specific scenarios, making it easier to test different use cases. These -predefined fixtures enhance test consistency and reliability and point to the real infrastructure used by integration -testing. See [Unified Authentication Documentation](https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html) -for the latest reference of environment variables related to authentication. - -- `CLOUD_ENV`: This environment variable specifies the cloud environment where Databricks is hosted. The values typically - indicate the cloud provider being used, such as "aws" for Amazon Web Services and "azure" for Microsoft Azure. -- `DATABRICKS_ACCOUNT_ID`: This variable stores the unique identifier for your Databricks account. -- `DATABRICKS_HOST`: This variable contains the URL of your Databricks workspace. It is the web address you use to access - your Databricks environment and typically looks like "https://dbc-....cloud.databricks.com." -- `TEST_DEFAULT_CLUSTER_ID`: This variable holds the identifier for the default cluster used in testing. The value - resembles a unique cluster ID, like "0824-163015-tdtagl1h." -- `TEST_DEFAULT_WAREHOUSE_DATASOURCE_ID`: This environment variable stores the identifier for the default warehouse data - source used in testing. The value is a unique identifier for the data source, such as "3c0fef12-ff6c-...". -- `TEST_DEFAULT_WAREHOUSE_ID`: This variable contains the identifier for the default warehouse used in testing. The value - resembles a unique warehouse ID, like "49134b80d2...". -- `TEST_INSTANCE_POOL_ID`: This environment variable stores the identifier for the instance pool used in testing. - You must utilise existing instance pools as much as possible for cluster startup time and cost reduction. - The value is a unique instance pool ID, like "0824-113319-...". 
-- `TEST_LEGACY_TABLE_ACL_CLUSTER_ID`: This variable holds the identifier for the cluster used in testing legacy table - access control. The value is a unique cluster ID, like "0824-161440-...". -- `TEST_USER_ISOLATION_CLUSTER_ID`: This environment variable contains the identifier for the cluster used in testing - user isolation. The value is a unique cluster ID, like "0825-164947-...". - -Use the following command to run the integration tests: +## JVM Proxy -```shell -make integration -``` +In order to use this, you have to install `remorph` on any workspace via `databricks labs install .`, +so that `.databricks-login.json` file gets created with the following contents: -We'd like to encourage you to leverage the extensive set of [pytest fixtures](https://docs.pytest.org/en/latest/explanation/fixtures.html#about-fixtures). -These fixtures follow a consistent naming pattern, starting with "make_". These functions can be called multiple -times to _create and clean up objects as needed_ for your tests. Reusing these fixtures helps maintain clean and consistent -test setups across the codebase. In cases where your tests require unique fixture setups, keeping the wall -clock time of fixture initialization under one second is crucial. Fast fixture initialization ensures that tests run quickly, reducing -development cycle times and allowing for more immediate feedback during development. +``` +{ + "workspace_profile": "labs-azure-tool", + "cluster_id": "0708-200540-wcwi4i9e" +} +``` -```python -from databricks.sdk.service.workspace import AclPermission -from databricks.labs.ucx.mixins.fixtures import * # noqa: F403 +then run `make dev-cli` to collect classpath information. And then invoke commands, +like `databricks labs remorph debug-script --name file`. Add `--debug` flag to recompile project each run. -def test_secret_scope_acl(make_secret_scope, make_secret_scope_acl, make_group): - scope_name = make_secret_scope() - make_secret_scope_acl(scope=scope_name, principal=make_group().display_name, permission=AclPermission.WRITE) +Example output is: +```text +databricks labs remorph debug-script --name foo +21:57:42 INFO [databricks.sdk] Using Azure CLI authentication with AAD tokens +21:57:42 WARN [databricks.sdk] azure_workspace_resource_id field not provided. It is recommended to specify this field in the Databricks configuration to avoid authentication errors. +Debugging script... +Map(log_level -> disabled, name -> foo) ``` -Each integration test _must be debuggable within the free [IntelliJ IDEA (Community Edition)](https://www.jetbrains.com/idea/download) -with the [Python plugin (Community Edition)](https://plugins.jetbrains.com/plugin/7322-python-community-edition). If it works within -IntelliJ CE, then it would work in PyCharm. Debugging capabilities are essential for troubleshooting and diagnosing issues during -development. Please make sure that your test setup allows for easy debugging by following best practices. +## Local Setup -![debugging tests](docs/debugging-tests.gif) +This section provides a step-by-step guide to set up and start working on the project. These steps will help you set up your project environment and dependencies for efficient development. -Adhering to these guidelines ensures that our integration tests are robust, efficient, and easily maintainable. This, -in turn, contributes to the overall reliability and quality of our software. 
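
To make the unit-testing guidance above concrete, here is a minimal pytest-style sketch (hypothetical, not taken from this repo) that exercises business logic against a mocked Databricks `WorkspaceClient` instead of calling a real workspace:

```python
from unittest.mock import MagicMock


def current_user_display(ws) -> str:
    # Business logic receives an injected API client rather than creating one itself.
    return f"Logged in as {ws.current_user.me().user_name}"


def test_current_user_display():
    ws = MagicMock()  # stands in for databricks.sdk.WorkspaceClient
    ws.current_user.me.return_value.user_name = "jane.doe@example.com"
    assert current_user_display(ws) == "Logged in as jane.doe@example.com"
```

Because the client is injected and mocked, the test needs no credentials and completes in milliseconds, which keeps pull request checks fast.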
+
 
-Currently, VSCode IDE is not supported, as it does not offer interactive debugging single integration tests.
-However, it's possible that this limitation may be addressed in the future.
 
 ## Local Setup
 
 This section provides a step-by-step guide to set up and start working on the project. These steps will help you set up your project environment and dependencies for efficient development.
 
-To begin, run `make dev` to install [Hatch](https://github.com/pypa/hatch), create the default environment and install development dependencies, assuming you've already cloned the github repo.
+To begin, install the prerequisites:
+
+`wget` is required by the maven installer
+```shell
+brew install wget
+```
+
+`maven` is the dependency manager for JVM-based languages
+```shell
+brew install maven
+```
+
+`jdk11` is the JDK used by remorph.
+Download it from [OpenJDK11](https://www.openlogic.com/openjdk-downloads?field_java_parent_version_target_id=406&field_operating_system_target_id=431&field_architecture_target_id=391&field_java_package_target_id=396) and install it.
+
+`python` (version 3.10 or later) is required for the Python parts of remorph
+```shell
+brew install python@3.10
+```
+
+`hatch` is a Python project manager
+```shell
+pip install hatch
+```
+
+Then run the project-specific install scripts:
+
+`make dev` creates the default environment and installs development dependencies, assuming you've already cloned the github repo.
 ```shell
 make dev
 ```
 
 Verify installation with
 ```shell
 make test
 ```
@@ -138,8 +154,14 @@ To ensure your integrated development environment (IDE) uses the newly created v
 hatch run python -c "import sys; print(sys.executable)"
 ```
 
-Configure your IDE to use this Python path so that you work within the virtual environment when developing the project:
-![IDE Setup](docs/hatch-intellij.gif)
+As of writing, we only support IntelliJ IDEA CE 2024.1. Development using more recent versions doesn't work (yet!).
+Download and install [IntelliJ IDEA](https://www.jetbrains.com/idea/download/other.html).
+
+Configure your IDE to:
+ - use OpenJDK11 as the SDK for the project
+ - install the IntelliJ Scala plugin version 2024.1.25 (do not use more recent versions; they don't work)
+ - use this Python venv path so that you work within the virtual environment when developing the project:
+![IDE Setup](docs/img/remorph_intellij.gif)
 
 Before every commit, apply the consistent formatting of the code, as we want our codebase to look consistent:
 ```shell
@@ -147,16 +169,34 @@ make fmt
 ```
 
 Before every commit, run the automated bug detector (`make lint`) and unit tests (`make test`) to ensure that automated
 pull request checks do pass before your code is reviewed by others:
 ```shell
 make lint test
 ```
 
+## IDE plugins
+
+If you will be working with the ANTLR grammars, you should install the ANTLR plugin for your IDE. There
+is a plugin for VS Code, but it does not have as many checks as the one for IntelliJ IDEA.
+
+While the ANTLR tool, run at build time, will warn (and the build will stop on warnings) about things like
+tokens that are used in the parser grammar but not defined in the lexer grammar, the IntelliJ IDEA plugin
+provides a few extra tools, such as identifying unused rules and a visual representation of parse trees.
+
+Please read the documentation for the plugin so that you can make the most of it and have it generate
+the lexer and parser in a temp directory to tell you about things like undefined tokens.
+
+If you intend to make changes to the ANTLR-defined syntax, please read the README.md under ./core before
+doing so. Changes to ANTLR grammars can have big knock-on effects on the rest of the codebase, and must
+be carefully reviewed and tested - all the way from parse to code generation. Such changes are generally
+not suited to beginners.
+
 ## First contribution
 
 Here are the example steps to submit your first contribution:
 
-1. Make a Fork from ucx repo (if you really want to contribute)
+1. Make a Fork from the remorph repo (if you really want to contribute)
 2. `git clone`
 3. `git checkout main` (or `gcm` if you're using [ohmyzsh](https://ohmyz.sh/)).
 4. `git pull` (or `gl` if you're using [ohmyzsh](https://ohmyz.sh/)).
@@ -169,70 +209,12 @@ Here are the example steps to submit your first contribution:
 11. .. fix if any
 12. `git commit -a`. Make sure to enter a meaningful commit message title.
 13. `git push origin FEATURENAME`
 14. Go to GitHub UI and create PR. Alternatively, `gh pr create` (if you have [GitHub CLI](https://cli.github.com/) installed).
    Use a meaningful pull request title because it'll appear in the release notes. Use `Resolves #NUMBER` in pull request description
    to [automatically link it](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue)
    to an existing issue.
 15. announce PR for the review
 
 ## Troubleshooting
 
 If you encounter any package dependency errors after `git pull`, run `make clean`
-
-### Environment Issues
-
-Sometimes, when dependencies are updated via `dependabot` for example, the environment may report the following error:
-
-```sh
-$ hatch run unit:test-cov-report
-ERROR: Cannot install databricks-labs-ucx[test]==0.0.3 and databricks-sdk~=0.8.0 because these package versions have conflicting dependencies.
-ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts -``` - -The easiest fix is to remove the environment and have the re-run recreate it: - -```sh -$ hatch env show - Standalone -┏━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ -┃ Name ┃ Type ┃ Dependencies ┃ Scripts ┃ -┡━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ -│ default │ virtual │ │ │ -├─────────────┼─────────┼────────────────────────────────┼─────────────────┤ -│ unit │ virtual │ databricks-labs-ucx[test] │ test │ -│ │ │ delta-spark<3.0.0,>=2.4.0 │ test-cov-report │ -│ │ │ pyspark<=3.5.0,>=3.4.0 │ │ -├─────────────┼─────────┼────────────────────────────────┼─────────────────┤ -│ integration │ virtual │ databricks-labs-ucx[dbconnect] │ test │ -│ │ │ databricks-labs-ucx[test] │ │ -│ │ │ delta-spark<3.0.0,>=2.4.0 │ │ -├─────────────┼─────────┼────────────────────────────────┼─────────────────┤ -│ lint │ virtual │ black>=23.1.0 │ fmt │ -│ │ │ isort>=2.5.0 │ verify │ -│ │ │ ruff>=0.0.243 │ │ -└─────────────┴─────────┴────────────────────────────────┴─────────────────┘ - -$ hatch env remove unit -$ hatch run unit:test-cov-report -========================================================================================== test session starts =========================================================================================== -platform darwin -- Python 3.11.4, pytest-7.4.1, pluggy-1.3.0 -- /Users/lars.george/Library/Application Support/hatch/env/virtual/databricks-labs-ucx/H6b8Oom-/unit/bin/python -cachedir: .pytest_cache -rootdir: /Users/lars.george/projects/work/databricks/ucx -configfile: pyproject.toml -plugins: cov-4.1.0, mock-3.11.1 -collected 103 items - -tests/unit/test_config.py::test_initialization PASSED -tests/unit/test_config.py::test_reader PASSED -... -tests/unit/test_tables.py::test_uc_sql[table1-CREATE VIEW IF NOT EXISTS new_catalog.db.view AS SELECT * FROM table;] PASSED -tests/unit/test_tables.py::test_uc_sql[table2-CREATE TABLE IF NOT EXISTS new_catalog.db.external_table LIKE catalog.db.external_table COPY LOCATION;ALTER TABLE catalog.db.external_table SET TBLPROPERTIES ('upgraded_to' = 'new_catalog.db.external_table');] PASSED - ----------- coverage: platform darwin, python 3.11.4-final-0 ---------- -Coverage HTML written to dir htmlcov - -========================================================================================== 103 passed in 12.61s ========================================================================================== -$ -``` - -Note: The initial `hatch env show` is just to list the environments managed by Hatch and is not needed. diff --git a/LICENSE b/LICENSE index e02a93e6da..c8d0d24aec 100644 --- a/LICENSE +++ b/LICENSE @@ -1,25 +1,69 @@ -DB license + Databricks License + Copyright (2024) Databricks, Inc. -Copyright (2023) Databricks, Inc. + Definitions. + + Agreement: The agreement between Databricks, Inc., and you governing + the use of the Databricks Services, as that term is defined in + the Master Cloud Services Agreement (MCSA) located at + www.databricks.com/legal/mcsa. + + Licensed Materials: The source code, object code, data, and/or other + works to which this license applies. -Definitions. + Scope of Use. You may not use the Licensed Materials except in + connection with your use of the Databricks Services pursuant to + the Agreement. 
Your use of the Licensed Materials must comply at all + times with any restrictions applicable to the Databricks Services, + generally, and must be used in accordance with any applicable + documentation. You may view, use, copy, modify, publish, and/or + distribute the Licensed Materials solely for the purposes of using + the Licensed Materials within or connecting to the Databricks Services. + If you do not agree to these terms, you may not view, use, copy, + modify, publish, and/or distribute the Licensed Materials. + + Redistribution. You may redistribute and sublicense the Licensed + Materials so long as all use is in compliance with these terms. + In addition: + + - You must give any other recipients a copy of this License; + - You must cause any modified files to carry prominent notices + stating that you changed the files; + - You must retain, in any derivative works that you distribute, + all copyright, patent, trademark, and attribution notices, + excluding those notices that do not pertain to any part of + the derivative works; and + - If a "NOTICE" text file is provided as part of its + distribution, then any derivative works that you distribute + must include a readable copy of the attribution notices + contained within such NOTICE file, excluding those notices + that do not pertain to any part of the derivative works. -Agreement: The agreement between Databricks, Inc., and you governing the use of the Databricks Services, which shall be, with respect to Databricks, the Databricks Terms of Service located at www.databricks.com/termsofservice, and with respect to Databricks Community Edition, the Community Edition Terms of Service located at www.databricks.com/ce-termsofuse, in each case unless you have entered into a separate written agreement with Databricks governing the use of the applicable Databricks Services. + You may add your own copyright statement to your modifications and may + provide additional license terms and conditions for use, reproduction, + or distribution of your modifications, or for any such derivative works + as a whole, provided your use, reproduction, and distribution of + the Licensed Materials otherwise complies with the conditions stated + in this License. -Software: The source code and object code to which this license applies. + Termination. This license terminates automatically upon your breach of + these terms or upon the termination of your Agreement. Additionally, + Databricks may terminate this license at any time on notice. Upon + termination, you must permanently delete the Licensed Materials and + all copies thereof. -Scope of Use. You may not use this Software except in connection with your use of the Databricks Services pursuant to the Agreement. Your use of the Software must comply at all times with any restrictions applicable to the Databricks Services, generally, and must be used in accordance with any applicable documentation. You may view, use, copy, modify, publish, and/or distribute the Software solely for the purposes of using the code within or connecting to the Databricks Services. If you do not agree to these terms, you may not view, use, copy, modify, publish, and/or distribute the Software. + DISCLAIMER; LIMITATION OF LIABILITY. -Redistribution. You may redistribute and sublicense the Software so long as all use is in compliance with these terms. 
In addition: - -You must give any other recipients a copy of this License; -You must cause any modified files to carry prominent notices stating that you changed the files; -You must retain, in the source code form of any derivative works that you distribute, all copyright, patent, trademark, and attribution notices from the source code form, excluding those notices that do not pertain to any part of the derivative works; and -If the source code form includes a "NOTICE" text file as part of its distribution, then any derivative works that you distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the derivative works. -You may add your own copyright statement to your modifications and may provide additional license terms and conditions for use, reproduction, or distribution of your modifications, or for any such derivative works as a whole, provided your use, reproduction, and distribution of the Software otherwise complies with the conditions stated in this License. - -Termination. This license terminates automatically upon your breach of these terms or upon the termination of your Agreement. Additionally, Databricks may terminate this license at any time on notice. Upon termination, you must permanently delete the Software and all copies thereof. - -DISCLAIMER; LIMITATION OF LIABILITY. - -THE SOFTWARE IS PROVIDED “AS-IS” AND WITH ALL FAULTS. DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY DISCLAIMS ALL WARRANTIES RELATING TO THE SOURCE CODE, EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. DATABRICKS AND ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE SOURCE CODE SHALL BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file + THE LICENSED MATERIALS ARE PROVIDED “AS-IS” AND WITH ALL FAULTS. + DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY + DISCLAIMS ALL WARRANTIES RELATING TO THE LICENSED MATERIALS, EXPRESS + AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, + CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR + FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. DATABRICKS AND + ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF + YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE LICENSED MATERIALS SHALL + BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS OR + THE USE OR OTHER DEALINGS IN THE LICENSED MATERIALS. 
diff --git a/Makefile b/Makefile index 77e3c69b38..ec078c88af 100644 --- a/Makefile +++ b/Makefile @@ -1,27 +1,64 @@ -all: clean lint fmt test +all: clean dev fmt lint test clean: - rm -fr htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml - hatch env remove unit + rm -fr .venv clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml dev: - pip install hatch hatch env create hatch run pip install -e '.[test]' hatch run which python lint: - hatch run lint:verify + hatch run verify -fmt: - hatch run lint:fmt +fmt: fmt-python fmt-scala -test: - hatch run unit:test +fmt-python: + hatch run fmt + +fmt-scala: + mvn validate -Pformat + +test: test-python test-scala + +setup_spark_remote: + .github/scripts/setup_spark_remote.sh + +test-python: setup_spark_remote + hatch run test + +test-scala: + mvn test -f pom.xml integration: - hatch run integration:test + hatch run integration coverage: - hatch run unit:test-cov-report && open htmlcov/index.html + hatch run coverage && open htmlcov/index.html + +build_core_jar: dev-cli + mvn --file pom.xml -pl core package + +clean_coverage_dir: + rm -fr ${OUTPUT_DIR} + +python_coverage_report: + hatch run python src/databricks/labs/remorph/coverage/remorph_snow_transpilation_coverage.py + hatch run pip install --upgrade sqlglot + hatch -e sqlglot-latest run python src/databricks/labs/remorph/coverage/sqlglot_snow_transpilation_coverage.py + hatch -e sqlglot-latest run python src/databricks/labs/remorph/coverage/sqlglot_tsql_transpilation_coverage.py + +antlr_coverage_report: build_core_jar + java -jar $(wildcard core/target/remorph-core-*-SNAPSHOT.jar) '{"command": "debug-coverage", "flags":{"src": "$(abspath ${INPUT_DIR_PARENT})", "dst":"$(abspath ${OUTPUT_DIR})", "extractor": "full"}}' + +dialect_coverage_report: clean_coverage_dir antlr_coverage_report python_coverage_report + hatch run python src/databricks/labs/remorph/coverage/local_report.py + +antlr-lint: + mvn compile -DskipTests exec:java -pl linter --file pom.xml -Dexec.args="-i core/src/main/antlr4 -o .venv/linter/grammar -c true" + +dev-cli: + mvn -f core/pom.xml dependency:build-classpath -Dmdep.outputFile=target/classpath.txt +estimate-coverage: build_core_jar + databricks labs remorph debug-estimate --dst $(abspath ${OUTPUT_DIR}) --dialect snowflake --console-output true diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000000..bc73bfe034 --- /dev/null +++ b/NOTICE @@ -0,0 +1,91 @@ +Copyright (2024) Databricks, Inc. + +This software includes software developed at Databricks (https://www.databricks.com/) and its use is subject to the included LICENSE file. + +This software contains code from the following open source projects, licensed under the MIT license: + +SQL Glot - https://github.com/tobymao/sqlglot +Copyright 2023 Toby Mao +License - https://github.com/tobymao/sqlglot/blob/main/LICENSE + +T-SQL (Transact-SQL, MSSQL) grammar - https://github.com/antlr/grammars-v4/tree/master/sql/tsql +Copyright (c) 2017, Mark Adams (madams51703@gmail.com) +Copyright (c) 2015-2017, Ivan Kochurkin (kvanttt@gmail.com), Positive Technologies. +Copyright (c) 2016, Scott Ure (scott@redstormsoftware.com). +Copyright (c) 2016, Rui Zhang (ruizhang.ccs@gmail.com). +Copyright (c) 2016, Marcus Henriksson (kuseman80@gmail.com). + +Snowflake Database grammar - https://github.com/antlr/grammars-v4/tree/master/sql/snowflake +Copyright (c) 2022, Michał Lorek. 
+ +SLF4J - https://slf4j.org +Copyright (c) 2004-2023 QOS.ch +License - https://www.slf4j.org/license.html + +Microsoft JDBC Driver for SQL Server - https://github.com/microsoft/mssql-jdbc +Copyright (c) Microsoft Corporation +License - https://github.com/microsoft/mssql-jdbc/blob/main/LICENSE + +sql-formatter - https://github.com/vertical-blank/sql-formatter +Copyright (c) 2019 Yohei Yamana +License - https://github.com/vertical-blank/sql-formatter/blob/master/LICENSE + +This software contains code from the following open source projects, licensed under the Apache 2.0 license: + +DataComPy - https://github.com/capitalone/datacompy +Copyright 2018 Capital One Services, LLC +License - https://github.com/capitalone/datacompy/blob/develop/LICENSE + +Apache Spark - https://github.com/apache/spark +Copyright 2018 The Apache Software Foundation +License - https://github.com/apache/spark/blob/master/LICENSE + +Databricks SDK for Python - https://github.com/databricks/databricks-sdk-py +Copyright 2023 Databricks, Inc. All rights reserved. +License - https://github.com/databricks/databricks-sdk-py/blob/main/LICENSE + +Apache Log4j - https://github.com/apache/logging-log4j2 +Copyright 1999-2024 Apache Software Foundation +License - https://github.com/apache/logging-log4j2/blob/2.x/LICENSE.txt + +Scala Logging - https://github.com/lightbend-labs/scala-logging +Copyright 2014-2021 Lightbend, Inc. +License - https://github.com/lightbend-labs/scala-logging?tab=Apache-2.0-1-ov-file#readme + +Snowflake JDBC Driver - https://github.com/snowflakedb/snowflake-jdbc +Copyright 2012-2023 Snowflake Computing, Inc. +License - https://github.com/snowflakedb/snowflake-jdbc/blob/master/LICENSE.txt + +scala-csv - https://github.com/tototoshi/scala-csv +Copyright 2013-2015 Toshiyuki Takahashi +License - https://github.com/tototoshi/scala-csv/blob/master/LICENSE.txt + +cryptography - https://github.com/pyca/cryptography +Copyright 2013-2023 The Python Cryptographic Authority and individual contributors. +License - https://github.com/pyca/cryptography/blob/main/LICENSE + +circe - https://github.com/circe/circe +Copyright (c) 2015, Ephox Pty Ltd, Mark Hibberd, Sean Parsons, Travis Brown, and other contributors. All rights reserved. +https://github.com/circe/circe/blob/series/0.14.x/LICENSE + +circe-generic-extras - https://github.com/circe/circe-generic-extras +https://github.com/circe/circe-generic-extras/blob/main/LICENSE + +circe-jackson - https://github.com/circe/circe-jackson +https://github.com/circe/circe-jackson/blob/main/LICENSE + +This software contains code from the following open source projects, licensed under the BSD license: + +ANTLR v4 - https://github.com/antlr/antlr4 +Copyright (c) 2012-2022 The ANTLR Project. All rights reserved. +https://github.com/antlr/antlr4/blob/dev/LICENSE.txt + +This software contains code from the following publicly available projects, licensed under the Databricks license: + +Databricks Labs Blueprint - https://github.com/databrickslabs/blueprint +Copyright (2023) Databricks, Inc. +https://github.com/databrickslabs/blueprint/blob/main/LICENSE + +Databricks Connect - https://pypi.org/project/databricks-connect/ +Copyright (2019) Databricks Inc. 
+ diff --git a/README.md b/README.md index d5f291c504..2e2c5ddde4 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,481 @@ -# remorph +Databricks Labs Remorph +--- +![Databricks Labs Remorph](docs/img/remorph-logo.svg) + +[![lines of code](https://tokei.rs/b1/github/databrickslabs/remorph)]([https://codecov.io/github/databrickslabs/remorph](https://github.com/databrickslabs/remorph)) ----- -**Table of Contents** +# Table of Contents + +* [Introduction](#introduction) + * [Remorph](#remorph) + * [Transpile](#transpile) + * [Reconcile](#reconcile) +* [Environment Setup](#environment-setup) +* [How to use Transpile](#how-to-use-transpile) + * [Installation](#installation) + * [Verify Installation](#verify-installation) + * [Execution Pre-Set Up](#execution-pre-set-up) + * [Execution](#execution) +* [How to use Reconcile](#how-to-use-reconcile) + * [Installation](#installation-1) + * [Verify Installation](#verify-installation-1) + * [Execution Pre-Set Up](#execution-pre-set-up-1) + * [Execution](#execution-1) +* [Benchmarks](#benchmarks) + * [Transpile](#Transpile-benchmarks) + * [Reconcile](#Reconcile-benchmarks) +* [Frequently Asked Questions](#frequently-asked-questions) + * [Transpile](#Transpile-faq) + * [Reconcile](#Reconcile-faq) +* [Common Error Codes](#common-error-codes) +* [Project Support](#project-support) + +---- +# Introduction + +## Remorph +Remorph stands as a comprehensive toolkit meticulously crafted to facilitate seamless migrations to Databricks. +This suite of tools is dedicated to simplifying and optimizing the entire migration process, offering two distinctive functionalities – Transpile and Reconcile. Whether you are navigating code translation or resolving potential conflicts, Remorph ensures a smooth journey for any migration project. With Remorph as your trusted ally, +the migration experience becomes not only efficient but also well-managed, setting the stage for a successful transition to the Databricks platform. + +## Transpile +Transpile is a self-contained SQL parser, transpiler, and validator designed to interpret a diverse range of SQL inputs and generate syntactically and semantically correct SQL in the Databricks SQL dialect. This tool serves as an automated solution, named Transpile, specifically crafted for migrating and translating SQL scripts from various sources to the Databricks SQL format. Currently, it exclusively supports Snowflake as a source platform, leveraging the open-source SQLglot. + +Transpile stands out as a comprehensive and versatile SQL transpiler, boasting a robust test suite to ensure reliability. Developed entirely in Python, it not only demonstrates high performance but also highlights syntax errors and provides warnings or raises alerts for dialect incompatibilities based on configurations. + +### Transpiler Design Flow: +```mermaid +flowchart TD + A(Transpile CLI) --> |Directory| B[Transpile All Files In Directory]; + A --> |File| C[Transpile Single File] ; + B --> D[List Files]; + C --> E("Sqlglot(transpile)"); + D --> E + E --> |Parse Error| F(Failed Queries) + E --> G{Skip Validations} + G --> |Yes| H(Save Output) + G --> |No| I{Validate} + I --> |Success| H + I --> |Fail| J(Flag, Capture) + J --> H +``` + +## Reconcile +Reconcile is an automated tool designed to streamline the reconciliation process between source data and target data residing on Databricks. Currently, the platform exclusively offers support for Snowflake, Oracle and other Databricks tables as the primary data source. 
This tool empowers users to efficiently identify discrepancies and variations in data when comparing the source with the Databricks target. + +---- + +# Environment Setup + +1. `Databricks CLI` - Ensure that you have the Databricks Command-Line Interface (CLI) installed on your machine. Refer to the installation instructions provided for Linux, MacOS, and Windows, available [here](https://docs.databricks.com/en/dev-tools/cli/install.html#install-or-update-the-databricks-cli). + +2. `Databricks Connect` - Set up the Databricks workspace configuration file by following the instructions provided [here](https://docs.databricks.com/en/dev-tools/auth/index.html#databricks-configuration-profiles). Note that Databricks labs use 'DEFAULT' as the default profile for establishing connections to Databricks. + +3. `Python` - Verify that your machine has Python version 3.10 or later installed to meet the required dependencies for seamless operation. + - `Windows` - Install python from [here](https://www.python.org/downloads/). Your Windows computer will need a shell environment ([GitBash](https://www.git-scm.com/downloads) or [WSL](https://learn.microsoft.com/en-us/windows/wsl/about)) + - `MacOS/Unix` - Use [brew](https://formulae.brew.sh/formula/python@3.10) to install python in macOS/Unix machines +#### Installing Databricks CLI on macOS +![macos-databricks-cli-install](docs/img/macos-databricks-cli-install.gif) + +#### Install Databricks CLI via curl on Windows +![windows-databricks-cli-install](docs/img/windows-databricks-cli-install.gif) + +#### Check Python version on Windows, macOS, and Unix -- [Installation](#installation) -- [License](#license) +![check-python-version](docs/img/check-python-version.gif) -## Installation +[[back to top](#table-of-contents)] -```console +---- + + +# How to Use Transpile + +### Installation + +Upon completing the environment setup, install Remorph by executing the following command: +```bash databricks labs install remorph ``` +![transpile install](docs/img/transpile-install.gif) + +[[back to top](#table-of-contents)] + +---- + +### Verify Installation +Verify the successful installation by executing the provided command; confirmation of a successful installation is indicated when the displayed output aligns with the example screenshot provided: +```bash + databricks labs remorph transpile --help + ``` +![transpile-help](docs/img/transpile-help.png) + +### Execution Pre-Set Up +1. Transpile necessitates input in the form of either a directory containing SQL files or a single SQL file. +2. The SQL file should encompass scripts intended for migration to Databricks SQL. + +Below is the detailed explanation on the arguments required for Transpile. +- `input-source [Required]` - The path to the SQL file or directory containing SQL files to be transpiled. +- `source-dialect [Required]` - The source platform of the SQL scripts. Currently, only Snowflake is supported. +- `output-folder [Optional]` - The path to the output folder where the transpiled SQL files will be stored. If not specified, the transpiled SQL files will be stored in the same directory as the input SQL file. +- `skip-validation [Optional]` - The default value is True. If set to False, the transpiler will validate the transpiled SQL scripts against the Databricks catalog and schema provided by user. +- `catalog-name [Optional]` - The name of the catalog in Databricks. If not specified, the default catalog `transpiler_test` will be used. +- `schema-name [Optional]` - The name of the schema in Databricks. 
If not specified, the default schema `convertor_test` will be used. + +### Execution +Execute the below command to intialize the transpile process. +```bash + databricks labs remorph transpile --input-source --source-dialect --output-folder --skip-validation --catalog-name --schema-name +``` + +![transpile run](docs/img/transpile-run.gif) + +[[back to top](#table-of-contents)] + +---- +# How to Use Reconcile + +### Installation + +Install Reconciliation with databricks labs cli. + +```commandline +databricks labs install remorph +``` + +![reconcile install](docs/img/recon-install.gif) + +### Verify Installation +Verify the successful installation by executing the provided command; confirmation of a successful installation is indicated when the displayed output aligns with the example screenshot provided: +```bash + databricks labs remorph reconcile --help + ``` +![reconcile-help](docs/img/reconcile-help.png) + +### Execution Pre-Set Up +>1. Setup the configuration file: + +Once the installation is done, a folder named **.remorph** will be created in the user workspace's home folder. +To process the reconciliation for specific table sources, we must create a config file that gives the detailed required configurations for the table-specific ones. +The file name should be in the format as below and created inside the **.remorph** folder. +``` +recon_config___.json + +Note: For CATALOG_OR_SCHEMA , if CATALOG exists then CATALOG else SCHEMA +``` + +eg: + +| source_type | catalog_or_schema | report_type | file_name | +|-------------|-------------------|-------------|---------------------------------------| +| databricks | tpch | all | recon_config_databricks_tpch_all.json | +| source1 | tpch | row | recon_config_source1_tpch_row.json | +| source2 | tpch | schema | recon_config_source2_tpch_schema.json | + +#### Refer to [Reconcile Configuration Guide][def] for detailed instructions and [example configurations][config] + +[def]: docs/recon_configurations/README.md +[config]: docs/recon_configurations/reconcile_config_samples.md + +> 2. Setup the connection properties + +Remorph-Reconcile manages connection properties by utilizing secrets stored in the Databricks workspace. +Below is the default secret naming convention for managing connection properties. + +**Note: When both the source and target are Databricks, a secret scope is not required.** + +**Default Secret Scope:** remorph_{data_source} + +| source | scope | +|--------|-------| +| snowflake | remorph_snowflake | +| oracle | remorph_oracle | +| databricks | remorph_databricks | + +Below are the connection properties required for each source: +``` +Snowflake: + +sfUrl = https://.snowflakecomputing.com +account = +sfUser = +sfPassword = +sfDatabase = +sfSchema = +sfWarehouse = +sfRole = +pem_private_key = + +Note: For Snowflake authentication, either sfPassword or pem_private_key is required. +Priority is given to pem_private_key, and if it is not found, sfPassword will be used. +If neither is available, an exception will be raised. + +``` + +``` +Oracle: + +user = +password = +host = +port = +database = +``` + + + + +### Execution +Execute the below command to initialize the reconcile process. 
+```bash + databricks labs remorph reconcile +``` +![reconcile-run](docs/img/recon-run.gif) + +[[back to top](#table-of-contents)] + +---- + +# Benchmarks + +## Transpile-benchmarks +TBD + +## Reconcile-benchmarks + +### tpch `1000GB` data details + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Table nameNo of rowsData SizeSet
supplier10M754.7 MiBSet 1
customer150M11.5 GiBSet 1
part200M5.8 GiBSet 1
partsupp800M39.9 GiBSet 1
orders1.5B62.4 GiBSet 1
lineitem6B217.6 GiBSet 2
+ +### Databricks to Databricks Recon + +The following benchmarks were conducted on various Databricks clusters. Please note that the reconciliation times listed below do not include cluster startup time. + +>Cluster1 -- 14.3 LTS (includes Apache Spark 3.5.0, Scala 2.12) `Photon Enabled` + + + + + + + + + + + + + + + + + + + + + + +
VMQuantityTotal CoresTotal RAM
Driver:**i3.xlarge****1****4 cores****30.5 GB**
Workers:**i3.xlarge****10****40 cores****305 GB**
+ +>Cluster2 -- 14.3 LTS (includes Apache Spark 3.5.0, Scala 2.12) `Photon Enabled` + + + + + + + + + + + + + + + + + + + + + + +
VMQuantityTotal CoresTotal RAM
Driver:**i3.2xlarge****1****8 cores****61 GB**
Workers:**i3.2xlarge****10****80 cores****610 GB**
+ +>Benchmark + +| Type | Data | Cluster1 | Cluster2 | +|----------------|------------------------| --- | --- | +| With Threshold | tpch (set 1 and set 2) | 1.46 hours | 50.12 minutes | +| Without Threshold | tpch (set 1 and set 2) | 1.34 hours | 45.58 minutes | + + +### Snowflake to Databricks Recon + +The following benchmarks were conducted on various Databricks clusters. Please note that the reconciliation times listed below do not include cluster startup time. + +>Snowflake cluster details + +Type: Standard + +Size: Large Cluster (8 nodes, 64 cores) + +>Cluster1 -- 13.3 LTS (includes Apache Spark 3.4.1, Scala 2.12) + + + + + + + + + + + + + + + + + + + + + + +
VMQuantityTotal CoresTotal RAM
Driver:**i3.xlarge****1****4 cores****30.5 GB**
Workers:**i3.xlarge****16****64 cores****488 GB**
+ +>Benchmark + +| Method | Configuration | Set | Time | +|-----------------| ------------- |-------------| ---- | +| Spark (deafult) | - | tpch - Set 1 | 32.01 minutes | +| Spark (deafult) | - | tpch - Set 2 | 1.01 hours | +| JDBC | number_partitions - 10 | tpch - Set 1 | 43.39 minutes | +| JDBC | number_partitions - 10 | tpch - Set 2 | 1.17 hours | +| JDBC | number_partitions - 64 | tpch - Set 1 | 25.95 minutes | +| JDBC | number_partitions - 64 | tpch - Set 2 | 40.30 minutes | +| JDBC | number_partitions - 100 | tpch - Set 2| 1.02 hours | + + +[[back to top](#table-of-contents)] + +---- + +# Frequently Asked Questions + +## Transpile-faq +TBD + +## Reconcile-faq + +
+Can we reconcile for Databricks without UC as a target? + +***The reconciliation target is always Databricks with UC enabled. Reconciler supports non-uc Databricks only as a +source.*** +
+ +
+What would happen if my dataset had duplicate records? + +***Duplicates are not handled in the reconciler. If run with duplicates, it would result in inconsistent output. We can +implement +some workarounds to handle the duplicates, and the solution varies from dataset to dataset.*** +
+ +
+Are User Transformations applicable for Schema Validations? + +***No. User Transformations are not applied for Schema Validation.Only select_columns,drop_columns and column_mapping is +valid for schema validation.*** +
+ +
+**Can we apply Aggregate or multi-column transformations as user transformations?**
+
+***No. Aggregate transformations and multi-column transformations are not supported.***
+
+
+**Does Reconciler support all complex data types?**
+
+***Not all complex data types are currently supported. The reconciler does support UDFs for complex data types. Please refer here
+for examples.***
+
+
+**Does Reconciler support `Column Threshold Validation` for the report type `row`?**
+
+***No. Column Threshold Validation is supported only for reports with the report type `data` or `all`, generally for tables with
+primary keys.***
+
+ +[[back to top](#table-of-contents)] + +---- + +## Common Error Codes: + +TBD + +---- -## Project Support +# Project Support Please note that all projects in the /databrickslabs github account are provided for your exploration only, and are not formally supported by Databricks with Service Level Agreements (SLAs). They are provided AS-IS and we do not make any guarantees of any kind. Please do not submit a support ticket relating to any issues arising from the use of these projects. Any issues discovered through the use of this project should be filed as GitHub Issues on the Repo. They will be reviewed as time permits, but there are no formal SLAs for support. diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000000..e289506a04 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,2 @@ +# We may add certain file to ignore in the codecov coverage report +ignore: diff --git a/core/README.md b/core/README.md new file mode 100644 index 0000000000..e6b4a64f41 --- /dev/null +++ b/core/README.md @@ -0,0 +1,444 @@ +# Implementing Snowflake and Other SQL Dialects AST -> Intermediate Representation + +Here's a guideline for incrementally improving the Snowflake -> IR translation in `SnowflakeAstBuilder`, et al., +and the equivalent builders for additional dialects. + +## Table of Contents +1. [Changing the ANTLR grammars](#changing-the-antlr-grammars) +2. [Conversion of AST to IR](#converting-to-ir-in-the-dialectthingbuilder-classes) + 1. [Step 1: add a test](#step-1-add-a-test) + 2. [Step 2: figure out the expected IR](#step-2-figure-out-the-expected-ir) + 3. [Step 3: run the test](#step-3-run-the-test) + 4. [Step 4: modify Builder](#step-4-modify-dialectthingbuilder) + 5. [Step 5: test, commit, improve](#step-5-test-commit-improve) + 6. [Caveat](#caveat) + +## Changing the ANTLR grammars + +Changing the ANTLR grammars is a specialized task, and only a few team members have permission to do so. +Do not attempt to change the grammars unless you have been given explicit permission to do so. + +It is unfortunately, easy to add parser rules without realizing the full implications of the change, +and this can lead to performance problems, incorrect IR and incorrect code gen. +Performance problems usually come from the fact that ANTLR will manage to make a working parser +out of almost any input specification. However, the resulting parser may then be very slow, +or may not be able to handle certain edge cases. + +After a grammar change is made, the .g4 files must be reformatted to stay in line with the guidelines. +We use (for now at least) the [antlr-format](https://github.com/mike-lischke/antlr-format) tool to do this. The tool is run as part of the +maven build process, if you are using the 'format' maven profile. You can also run it from the command line +via make: + +```bash +make fmt +``` + +or: + +```bash +make fmt-scala +``` + +Also, there is a Databricks specific ANTLR linter, which you MUST run before checking in. + +It can be run from the command line with: + + +```bash +make anltr-lint +``` + +And it will identify problems that the ANTLR tool does not consider problematic because it does +not have enough context. For instance, the ANTLR tool does not identify rules with no caller +as it must assume that they will be called from outside the grammar. + +## Checking ANTLR changes + +If you make changes to the ANTLR grammar, you should check the following things +(this is a good order to check them in): + + - Have you defined any new TOKENS used in the PARSER? 
+ - If so, did you define them in the LEXER? + - Have you eliminated the use of any tokens (this is generally a good thing - such as replacing + a long list of option keywords with an id rule, which is then checked later in the process)? + - If so, have you removed them from the lexer? + - Have you added any new rules? If so: + - have you checked that you have not duplicated some syntactic structure that is already defined + and could be reused? + - have you checked that they are used in the parser (the IntelliJ plugin will highlight unused rules)? + - Have you orphaned any rules? If so: + - did you mean to do so? + - have you removed them (the IntelliJ plugin will highlight unused rules)? + - are you sure that removing them is the right thing to do? If you can create + a shared rule that cuts down on the number of rules, that is generally a good thing. + +You must create tests at each stage for any syntax changes. IR generation tests, coverage tests, and +transpilation tests are all required. Make sure there are tests covering all variants of the syntax you have +added or changed. It is generally be a good idea to write the tests before changing the grammar. + +### Examples + +#### Adding new tokens +Let's say you need to add a new type to this parser rule: + +```antlr +distributionType: HASH LPAREN id (COMMA id)* RPAREN | ROUND_ROBIN | REPLICATE + ; +``` + +So you change the rule to: + +```antlrv4 +distributionType: HASH LPAREN id (COMMA id)* RPAREN | ROUND_ROBIN | REPLICATE | MAGIC + ; +``` + +The IDE plugin should be showing you that MAGIC is not defined as a token in the lexer by +underlining it. + +![ANTLR error](docs/img/antlrmissingtioken.png) + +You should then add a new token to the lexer: + +```antlrv4 +lexer grammar TSqlGrammar; +// ... +LOW : 'LOW'; +MAGIC : 'MAGIC'; +MANUAL : 'MANUAL'; +MARK : 'MARK'; +MASK : 'MASK'; +// ... +``` + +And the IDE should now be happy. Please keep the token definitions in alphabetical order of +the token name. + +#### Orphaned Rules +Let's say that you have had to refactor some rules because something that was previously a standalone +rule is now incorporated into the expression syntax. So you change: + +```antlrv4 +r1 : SELECT id jsonExtract? FROM t1; +jsonExtract : COLON id; +``` +To: + +```antlrv4 +r1 : SELECT expression FROM t1; +jsonExtract : COLON id; + +expression : + // ... + | expression COLON id + // ... +``` +You should now check to see if `jsonExtract` is used anywhere else for a number of reasons: + - it may be that it is no longer needed and can be removed + - it may be that it is used in other places and needs to be refactored to `expression` + everywhere. + - it may be that it is used in other places and needs to be left as is because it is + a different use case. + +Note that in general, if you can use `expression` instead of a specific rule, you should do so, +and resolve the expression type in the IR generation phase. We are generally looking to +parse correct input in this tool, and in any case a good grammar accepts almost anything +that _might_ be correct, and then performs semantic checks in the next phase after parsing, +as this gives better error output. + +Be careful though, as sometimes SQL syntax is ambiguous and you may need to restrict the +syntactic element to a specific type such as `id` in order to avoid ambiguity. 
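+
+As a sketch of what resolving things in the IR generation phase can look like, suppose the `expression COLON id`
+alternative above were given a `# jsonAccess` label. The visitor can then accept almost any expression at parse time
+and apply the semantic check itself. The label, the `ir.JsonAccess` node and the error handling below are illustrative
+assumptions for this example, not the current grammar or IR:
+
+```scala
+// Hypothetical sketch: assumes a `# jsonAccess` label on `expression COLON id`
+// and an ir.JsonAccess extension node; check the actual grammar and IR first.
+override def visitJsonAccess(ctx: JsonAccessContext): ir.Expression = {
+  val target = ctx.expression().accept(this) // the grammar accepted any expression here...
+  val field = ctx.id().getText
+  target match {
+    // ...and the builder decides what is actually legal
+    case c: ir.Column => ir.JsonAccess(c, field)
+    // a semantic error raised here, after parsing, gives a far better message than a parser failure
+    case other => throw new IllegalArgumentException(s"JSON access is not supported on: $other")
+  }
+}
+```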
+
+Note as well that if you are using the ANTLR plugin for IntelliJ IDEA, it will highlight rules
+that are orphaned (not used) in the parser like so:
+
+![ANTLR unused rule](docs/img/antlrorphanedrule.png)
+
+The underline in this case is easy to miss, but if you check the top right of the editor window
+you will see that there is a warning there:
+
+![ANTLR unused rule](docs/img/intellijproblem.png)
+
+Clicking the icon will switch to the problems view, where you can see the warnings:
+
+![ANTLR problem list](docs/img/antlrproblemlist.png)
+
+Note that as of writing, there are a number of orphaned rules in the grammar that have
+been left by previous authors. They will be gradually cleaned up, but do not add to the list.
+
+Finally, some rules will appear to be orphaned, but are actually the entry points for
+the external calls into the parser. There is currently no way to mark these as such in the
+plugin. You can generally assume they are external entry points if the rule ends in
+`EOF` to indicate that it parses to the end of the input.
+
+For instance:
+
+```antlrv4
+tSqlFile: batch? EOF
+    ;
+```
+#### Labels
+Use of labels where they are not needed is discouraged. They are overused in the current
+grammars. Use them where they make it much easier to check the ParserContext received by
+the visit method. This is because a label creates a new variable, getters and associated
+support code in the generated rule context.
+
+There is no need for this:
+
+```antlrv4
+r1: XXX name=id YYY name2=id (name3=id)?
+    ;
+```
+just so you can reference `ctx.name` and so on in the visitor. You can use `ctx.id(0)`,
+and so on:
+
+```antlrv4
+r1: XXX id YYY id id?
+    ;
+```
+
+You should use labels where they make both the rule and the visitor easier to read. For instance,
+consider a rule with two optional parse paths that share a common element. Let's assume that in the following
+artificial example, the YY clause means two different things depending on whether it
+comes before or after the ZZ clause:
+
+```antlrv4
+r1: XX (YY id )? ZZ id (YY id)?
+    ;
+```
+
+In the visitor, you will not know whether an `id` is the first or second `id` in the `YY` clause
+unless the size of ctx.id() is 3. So, you can use labels to make it clear:
+
+```antlrv4
+r1: XX (YY pre=id )? ZZ xx=id (YY post=id)?
+    ;
+```
+
+This means that in the visitor, you can check for `ctx.pre`, `ctx.xx` and `ctx.post` and know
+for certain which `id` you are looking at.
+
+#### Refactoring rules
+
+In some cases, you may see that rules repeat long sequences of tokens. This is generally a bad
+thing, and you should strive to merge them into a common rule. Two to four tokens are generally
+OK, but for anything much longer you should consider refactoring. As of writing, for instance,
+there are many TSQL grammar rules for `create` vs `alter` where the original author has
+repeated all the tokens and options. They will eventually be refactored into common rules.
+
+Do not separate longer sequences of tokens into separate rules unless there is in fact
+commonality between two or more rules. At runtime, each new rule generates a new method and
+incurs the setup and teardown time involved in calling that rule.
+
+If you see an opportunity to refactor, and it is because you have added a new rule or syntax,
+then do so. Refactoring outside the scope of your PR requires a separate PR, or raise an issue.
+ +In the current TSQL grammar, we can see that: + - there are two rules using a common sequence of tokens that we can refactor into a common rule + - that it is probably not worth doing that for `createLoginAzure` + - `labels=` have been used for no reason. + +```antlrv4 +alterLoginAzureSql + : ALTER LOGIN loginName = id ( + (ENABLE | DISABLE)? + | WITH ( + PASSWORD EQ password = STRING (OLD_PASSWORD EQ oldPassword = STRING)? + | NAME EQ loginName = id + ) + ) + ; + +createLoginAzureSql: CREATE LOGIN loginName = id WITH PASSWORD EQ STRING ( SID EQ sid = HEX)? + ; + +alterLoginAzureSqlDwAndPdw + : ALTER LOGIN loginName = id ( + (ENABLE | DISABLE)? + | WITH ( + PASSWORD EQ password = STRING ( + OLD_PASSWORD EQ oldPassword = STRING (MUST_CHANGE | UNLOCK)* + )? + | NAME EQ loginName = id + ) + ) + ; +``` + +So, a refactor would look like this (in stages - later you will see it all at once): + +```antlrv4 +alterLoginAzureSql + : ALTER LOGIN id passwordClause + ; + +createLoginAzureSql: CREATE LOGIN loginName = id WITH PASSWORD EQ STRING ( SID EQ sid = HEX)? + ; + +alterLoginAzureSqlDwAndPdw + : ALTER LOGIN id passwordClause + ; + +passWordClause: + (ENABLE | DISABLE)? + | WITH ( + PASSWORD EQ STRING ( + OLD_PASSWORD EQ STRING (MUST_CHANGE | UNLOCK)* + )? + | NAME EQ id + ) + ; +``` + +In passwdClause we know in the visitor that ctx.STRING(0) is the password, and +ctx.STRING(1) is the old password. We can also check if ctx.MUST_CHANGE() or ctx.UNLOCK() +are present and because they are optional, we can share the passwdClause between the +two rules even though `alterLoginAzureSql` does not support them. The visitor can check +for them and raise an error if it is present in the wrong context, or we can assume valid +input and ignore it. We can now use a specific buildPasswordClause method in the visitor to +return some common definition of a password clause. + +But... we now realize that `createLoginAzureSql` and `alterLoginAzureSqlDwAndPdw` are the same +and there is only need for one of them. So we can merge them into one and remove and extra +entry in the calling clause, which in the TSQL grammar is `ddlClause`, and hey presto we +have removed a rule, reduced parser complexity, and made the grammar easier to read. If there +is any IR difference between the two, we can handle that in the visitor. + +### Installing antlr-format + +The antlr-format tool will run as part of the maven build process and so there is no need to install it locally. +But you can do so using the instructions below. + +In order to run the tool, you have to install Node.js. You can download it from [here](https://nodejs.org/en/download/), +or more simply install it with `brew install node` if you are on a Mac. + +Once node is installed you can install the formatter with: + +```bash +npm install -g antlr-format +npm install -g antlr-format-cli +``` + +### Running antlr-format + +The formatter is trivial to run from the directory containing your changed grammar: + +```bash +~/databricks/remorph/core/src/main/antlr4/../parsers/tsql (feature/antlrformatdocs ✘)✹ ᐅ antlr-format *.g4 + +antlr-format, processing options... + +formatting 2 file(s)... + +done [82 ms] +``` + +Note that the formatting configuration is contained in the .g4 files themselves, so there is no need to +provide a configuration file. Please do not change the formatting rules. 
+
+### Caveat
+
+Some of the grammar definitions (`src/main/antlr4/com/databricks/labs/remorph/parsers//Parser.g4`)
+may still be works-in-progress and, as such, may contain rules that are either incorrect or simply
+_get in the way_ of implementing the -> IR translation.
+
+When stumbling upon such a case, one should:
+- materialize the problem in a (failing) test in `BuilderSpec`, flagged as `ignored` until the problem is solved
+- shallowly investigate the problem in the grammar and raise a GitHub issue with a problem statement
+- add a `TODO` comment on top of the failing test with a link to said issue
+- point out the issue to someone tasked with changing/fixing grammars
+- move on with implementing something else
+
+## Converting to IR in the Builder classes
+
+Here is the methodology used to effect changes to IR generation.
+
+### Step 1: add a test
+
+Let's say we want to add support for a new type of query; for the sake of simplicity we'll take
+`SELECT a FROM b` in Snowflake as an example (even though this type of query is already supported).
+
+The first thing to do is to add an example in `SnowflakeAstBuilderSpec`:
+
+```scala
+"translate my query" in { // find a better test name
+  example(
+    query = "SELECT a FROM b",
+    expectedAst = ir.Noop
+  )
+}
+```
+
+### Step 2: figure out the expected IR
+
+Next we need to figure out the intermediate AST we want to produce for this specific type of query.
+In this simple example, the expected output would be
+`Project(NamedTable("b", Map.empty, is_streaming = false), Seq(Column("a")))`
+so we update our test with:
+
+```scala
+// ...
+  expectedAst = Project(NamedTable("b", Map.empty, is_streaming = false), Seq(Column("a")))
+// ...
+```
+
+Less trivial cases may require careful exploration of the available data types in
+`com.databricks.labs.remorph.parsers.intermediate` to find the proper output structure.
+
+It may also happen that the desired structure is missing in `com.databricks.labs.remorph.parsers.intermediate`.
+In such a case, **we should not add/modify anything to/in the existing AST**, as it will eventually be generated
+from an external definition. Instead, we should add our new AST node as an extension in `src/main/scala/com/databricks/labs/remorph/parsers/intermediate/extensions.scala`.
+
+### Step 3: run the test
+
+Our new test is now ready to be run (we expect it to fail). But before running it, you may want to uncomment the
+`println(tree.toStringTree(parser))` line in the `parseString` method of `SnowflakeAstBuilderSpec`.
+
+It will print out the parser's output for your query in a LISP-like format:
+
+```
+(snowflake_file (batch (sql_command (dml_command (query_statement (select_statement (select_clause SELECT (select_list_no_top (select_list (select_list_elem (column_elem (column_name (id_ a))))))) (select_optional_clauses (from_clause FROM (table_sources (table_source (table_source_item_joined (object_ref (object_name (id_ b))))))))))))) )
+```
+
+This will be useful to know which methods in `SnowflakeAstBuilder` you need to override/modify.
+Note, however, that ANTLR4-generated parse tree visitors automatically call accept on child nodes,
+so you do not always need to override an intermediate method just to call accept() yourself.
+
+### Step 4: modify Builder
+
+Method names in `Builder` follow the names that appear in the parser's output above. For example,
+one would realize that the content of the `columnName` node is what will ultimately get translated as an IR `Column`.
+To do so, we therefore need to override the `visitColumnName` method of `SnowflakeExpressionBuilder`.
+
+Rather than have a single big visitor for every possible node in the parser's output, the nodes are handled by
+specialized visitors, for instance `TSqlExpressionBuilder` and `SnowflakeExpressionBuilder`. An instance of the `ExpressionBuilder`
+is injected into the `AstBuilder`, which in turn calls accept with the expression builder on any node that
+is an expression within a larger production, such as `visitSelect`.
+
+The builders have one `visit*` method for every node in the parser's output that the builder is responsible
+for handling. Methods that are not overridden simply return the result of visiting child nodes. So in our example,
+even though we haven't overridden the `visitColumnElem` method, our `visitColumnName` method will get called as expected
+because ANTLR creates a default implementation of `visitColumnElem` that calls `accept(this)` on the `columnName` node.
+
+However, note that if a rule can produce multiple children using the default `visit`, you will need to override the
+method that corresponds to that rule and produce the single IR node that represents that production.
+
+A rule of thumb for picking the right method to override is therefore to look for the narrowest node (in the parser's output)
+that contains all the information we need. Once you override a `visit` function, you are then responsible for either
+calling `accept()` on its child nodes or otherwise processing them using a specialized `build` method.
+
+Here, the `(id_ a)` node is "too narrow" as `id_` appears in many different places where it could be translated as
+something other than a `Column`, so we go for the parent `columnName` instead.
+
+Moving forward, we may realize that there are more ways to build a `Column` than we initially expected, so we may
+have to override `visitColumnElem` as well, but we will still be able to reuse our `visitColumnName` method, and
+have `visitColumnElem` call `accept(expressionBuilder)` on the `columnName` node.
+
+### Step 5: test, commit, improve
+
+At this point, we should have come up with an implementation of `Builder` that makes our new test pass.
+It is a good time to commit our changes.
+
+It isn't the end of the story though. We should add more tests with slight variations of our initial query
+(like `SELECT a AS aa FROM b` for example) and see how our new implementation behaves. This may in turn make us
+change our implementation, repeating the above steps a few more times.
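+
+To tie Steps 4 and 5 together, the override described above might look roughly like the sketch below. The base class,
+context and accessor names (`SnowflakeParserBaseVisitor`, `ColumnNameContext`, `id_()`) are assumptions about the
+generated parser, so check the generated code and the existing builder rather than treating this as the actual
+implementation:
+
+```scala
+// Minimal sketch of the Step 4 override; the generated class and accessor names are assumptions.
+class SnowflakeExpressionBuilder extends SnowflakeParserBaseVisitor[ir.Expression] {
+  override def visitColumnName(ctx: SnowflakeParser.ColumnNameContext): ir.Expression =
+    // columnName is the narrowest node that still unambiguously denotes a column reference
+    ir.Column(ctx.id_().getText)
+}
+```
+
+A natural Step 5 follow-up is an additional `example(...)` test for a variation such as `SELECT a AS aa FROM b`,
+using whichever alias node `com.databricks.labs.remorph.parsers.intermediate` already provides.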
diff --git a/core/docs/img/antlrmissingtioken.png b/core/docs/img/antlrmissingtioken.png new file mode 100644 index 0000000000..afeacdf92a Binary files /dev/null and b/core/docs/img/antlrmissingtioken.png differ diff --git a/core/docs/img/antlrorphanedrule.png b/core/docs/img/antlrorphanedrule.png new file mode 100644 index 0000000000..7139f65efe Binary files /dev/null and b/core/docs/img/antlrorphanedrule.png differ diff --git a/core/docs/img/antlrproblemlist.png b/core/docs/img/antlrproblemlist.png new file mode 100644 index 0000000000..986cab7c5f Binary files /dev/null and b/core/docs/img/antlrproblemlist.png differ diff --git a/core/docs/img/intellijproblem.png b/core/docs/img/intellijproblem.png new file mode 100644 index 0000000000..55adf3f5bd Binary files /dev/null and b/core/docs/img/intellijproblem.png differ diff --git a/core/pom.xml b/core/pom.xml new file mode 100644 index 0000000000..40fcf69a49 --- /dev/null +++ b/core/pom.xml @@ -0,0 +1,346 @@ + + + 4.0.0 + + com.databricks.labs + remorph + 0.2.0-SNAPSHOT + + remorph-core + jar + + 4.5.14 + 2.18.2 + 5.11.3 + 1.8 + 1.8 + + 4.11.0 + UTF-8 + 2.0.9 + 4.13.2 + 2.0.5 + 3.9.5 + 0.37.0 + + 15.1.0 + 0.14.2 + 12.8.0.jre8 + 3.20.0 + 0.10.1 + 2.0.0 + + + + + org.junit + junit-bom + ${junit-bom.version} + pom + import + + + + + + org.antlr + antlr4-runtime + ${antlr.version} + + + com.fasterxml.jackson.module + jackson-module-scala_${scala.binary.version} + ${jackson.version} + + + org.scala-lang + scala-library + ${scala.version} + + + com.databricks + databricks-sdk-java + ${databricks-sdk-java.version} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + 2.18.2 + + + com.databricks + databricks-connect + ${databricks-connect.version} + + + org.scala-lang + scala-reflect + + + org.json4s + json4s-scalap_${scala.binary.version} + + + + + org.apache.logging.log4j + log4j-slf4j2-impl + 2.23.1 + + + org.slf4j + slf4j-api + ${slf4j.version} + + + com.typesafe.scala-logging + scala-logging_${scala.binary.version} + ${scala-logging.version} + + + org.scalatest + scalatest_${scala.binary.version} + 3.3.0-SNAP4 + test + + + org.scalatestplus + mockito-4-11_${scala.binary.version} + 3.3.0.0-alpha.1 + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + com.lihaoyi + pprint_${scala.binary.version} + 0.9.0 + compile + + + io.circe + circe-core_${scala.binary.version} + ${circe.version} + + + io.circe + circe-generic_${scala.binary.version} + ${circe.version} + + + io.circe + circe-generic-extras_${scala.binary.version} + ${circe.version} + + + io.circe + circe-jackson215_${scala.binary.version} + ${circe.version} + + + com.lihaoyi + os-lib_${scala.binary.version} + ${os-lib.version} + + + com.github.vertical-blank + sql-formatter + ${sql-formatter.version} + + + net.snowflake + snowflake-jdbc + ${snowflake-jdbc.version} + + + com.github.tototoshi + scala-csv_${scala.binary.version} + ${scala-csv.version} + + + com.microsoft.sqlserver + mssql-jdbc + ${mssql-jdbc.version} + + + org.bouncycastle + bcprov-jdk18on + 1.78.1 + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.1.2 + + true + + + + org.scalatest + scalatest-maven-plugin + 2.2.0 + + ${project.build.directory}/surefire-reports + . 
+ tests-report.xml + + + + test + + test + + + + + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + true + false + src/main/antlr4 + true + ${project.basedir}/src/main/antlr4/com/databricks/labs/remorph/parsers/lib + ${project.build.directory}/generated-sources/antlr4 + + **/*.g4 + + + **/lib/*.g4 + **/basesnowflake.g4 + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.4.2 + + + + test-jar + + + + + + + + + format + + + + com.github.eirslett + frontend-maven-plugin + 1.15.1 + + v22.3.0 + ${project.build.directory} + + + + install node and npm + validate + + install-node-and-npm + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.4.1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + java + + + + + com.databricks.labs.remorph.Main + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.7.1 + + + jar-with-dependencies + + + + com.databricks.labs.remorph.Main + + + + + + package + + single + + + + + + + + + diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/README.md b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/README.md new file mode 100644 index 0000000000..7a9c64d837 --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/README.md @@ -0,0 +1,14 @@ +# ANTLR Grammar Library + +This directory contains ANTLR grammar files that are common to more than one SQL dialect. Such as the grammar that covers stored procedures, which all +dialects of SQL support in some form, and for which we have a universal grammar. + +ANTLR processes included grammars as pure text, in the same way that say the C pre-processor processes `#include` directives. +This means that you must be careful to ensure that: + - if you define new tokens in an included grammar, that they do not clash with tokens in the including grammar. + - if you define new rules in an included grammar, that they do not clash with rules in the including grammar. + In particular, you must avoid creating ambiguities in rule/token prediction, where ANTLR will try to create + a parser anyway, but generate code that performs extremely long token lookahead, and is therefore very slow. + +In other words, you cannot just arbitrarily throw together some common Lexer and Parser rules and expect them +to just work. diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/commonlex.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/commonlex.g4 new file mode 100644 index 0000000000..e289e8437f --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/commonlex.g4 @@ -0,0 +1,1482 @@ +// ================================================================================= +// Please reformat the grammr file before a change commit. 
See remorph/core/README.md +// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md + +// $antlr-format alignTrailingComments true +// $antlr-format columnLimit 150 +// $antlr-format maxEmptyLinesToKeep 1 +// $antlr-format reflowComments false +// $antlr-format useTab false +// $antlr-format allowShortRulesOnASingleLine true +// $antlr-format allowShortBlocksOnASingleLine true +// $antlr-format minEmptyLines 0 +// $antlr-format alignSemicolons ownLine +// $antlr-format alignColons trailing +// $antlr-format singleLineOverrulesHangingColon true +// $antlr-format alignLexerCommands true +// $antlr-format alignLabels true +// $antlr-format alignTrailers true +// ================================================================================= +lexer grammar commonlex; + +// TODO: Remove the use of DUMMY for unfinished Snoflake grammar productions +DUMMY: + 'DUMMY' +; // Dummy is not a keyword but rules reference it in unfinished Snowflake grammar - need to get rid + +ABORT : 'ABORT'; +ABORT_AFTER_WAIT : 'ABORT_AFTER_WAIT'; +ABORT_DETACHED_QUERY : 'ABORT_DETACHED_QUERY'; +ABORT_STATEMENT : 'ABORT_STATEMENT'; +ABSENT : 'ABSENT'; +ABSOLUTE : 'ABSOLUTE'; +ACCELERATED_DATABASE_RECOVERY : 'ACCELERATED_DATABASE_RECOVERY'; +ACCENT_SENSITIVITY : 'ACCENT_SENSITIVITY'; +ACCESS : 'ACCESS'; +ACCOUNT : 'ACCOUNT'; +ACCOUNTADMIN : 'ACCOUNTADMIN'; +ACCOUNTS : 'ACCOUNTS'; +ACTION : 'ACTION'; +ACTIVATION : 'ACTIVATION'; +ACTIVE : 'ACTIVE'; +ADD : 'ADD'; +ADDRESS : 'ADDRESS'; +ADMIN_NAME : 'ADMIN_NAME'; +ADMIN_PASSWORD : 'ADMIN_PASSWORD'; +ADMINISTER : 'ADMINISTER'; +AES : 'AES'; +AES_128 : 'AES_128'; +AES_192 : 'AES_192'; +AES_256 : 'AES_256'; +AFFINITY : 'AFFINITY'; +AFTER : 'AFTER'; +AGGREGATE : 'AGGREGATE'; +ALERT : 'ALERT'; +ALERTS : 'ALERTS'; +ALGORITHM : 'ALGORITHM'; +ALL : 'ALL'; +ALL_CONSTRAINTS : 'ALL_CONSTRAINTS'; +ALL_ERRORMSGS : 'ALL_ERRORMSGS'; +ALL_INDEXES : 'ALL_INDEXES'; +ALL_LEVELS : 'ALL_LEVELS'; +ALLOW_CLIENT_MFA_CACHING : 'ALLOW_CLIENT_MFA_CACHING'; +ALLOW_CONNECTIONS : 'ALLOW_CONNECTIONS'; +ALLOW_DUPLICATE : 'ALLOW_DUPLICATE'; +ALLOW_ENCRYPTED_VALUE_MODIFICATIONS : 'ALLOW_ENCRYPTED_VALUE_MODIFICATIONS'; +ALLOW_ID_TOKEN : 'ALLOW_ID_TOKEN'; +ALLOW_MULTIPLE_EVENT_LOSS : 'ALLOW_MULTIPLE_EVENT_LOSS'; +ALLOW_OVERLAPPING_EXECUTION : 'ALLOW_OVERLAPPING_EXECUTION'; +ALLOW_PAGE_LOCKS : 'ALLOW_PAGE_LOCKS'; +ALLOW_ROW_LOCKS : 'ALLOW_ROW_LOCKS'; +ALLOW_SINGLE_EVENT_LOSS : 'ALLOW_SINGLE_EVENT_LOSS'; +ALLOW_SNAPSHOT_ISOLATION : 'ALLOW_SNAPSHOT_ISOLATION'; +ALLOWED : 'ALLOWED'; +ALLOWED_ACCOUNTS : 'ALLOWED_ACCOUNTS'; +ALLOWED_DATABASES : 'ALLOWED_DATABASES'; +ALLOWED_INTEGRATION_TYPES : 'ALLOWED_INTEGRATION_TYPES'; +ALLOWED_IP_LIST : 'ALLOWED_IP_LIST'; +ALLOWED_SHARES : 'ALLOWED_SHARES'; +ALLOWED_VALUES : 'ALLOWED_VALUES'; +ALTER : 'ALTER'; +ALWAYS : 'ALWAYS'; +AND : 'AND'; +ANONYMOUS : 'ANONYMOUS'; +ANSI_DEFAULTS : 'ANSI_DEFAULTS'; +ANSI_NULL_DEFAULT : 'ANSI_NULL_DEFAULT'; +ANSI_NULL_DFLT_OFF : 'ANSI_NULL_DFLT_OFF'; +ANSI_NULL_DFLT_ON : 'ANSI_NULL_DFLT_ON'; +ANSI_NULLS : 'ANSI_NULLS'; +ANSI_PADDING : 'ANSI_PADDING'; +ANSI_WARNINGS : 'ANSI_WARNINGS'; +ANY : 'ANY'; +API : 'API'; +API_ALLOWED_PREFIXES : 'API_ALLOWED_PREFIXES'; +API_AWS_ROLE_ARN : 'API_AWS_ROLE_ARN'; +API_BLOCKED_PREFIXES : 'API_BLOCKED_PREFIXES'; +API_INTEGRATION : 'API_INTEGRATION'; +API_KEY : 'API_KEY'; +API_PROVIDER : 'API_PROVIDER'; +APPEND : 'APPEND'; +APPEND_ONLY : 'APPEND_ONLY'; +APPLICATION : 'APPLICATION'; +APPLICATION_LOG : 'APPLICATION_LOG'; +APPLY : 'APPLY'; +ARITHABORT : 
'ARITHABORT'; +ARITHIGNORE : 'ARITHIGNORE'; +ARRAY : 'ARRAY'; +ARRAY_AGG : 'ARRAY' '_'? 'AGG'; +AS : 'AS'; +ASC : 'ASC'; +ASSEMBLY : 'ASSEMBLY'; +ASYMMETRIC : 'ASYMMETRIC'; +ASYNCHRONOUS_COMMIT : 'ASYNCHRONOUS_COMMIT'; +AT_KEYWORD : 'AT'; +ATTACH : 'ATTACH'; +AUDIT : 'AUDIT'; +AUDIT_GUID : 'AUDIT_GUID'; +AUTHENTICATE : 'AUTHENTICATE'; +AUTHENTICATION : 'AUTHENTICATION'; +AUTHORIZATION : 'AUTHORIZATION'; +AUTHORIZATIONS : 'AUTHORIZATIONS'; +AUTO : 'AUTO'; +AUTO_CLEANUP : 'AUTO_CLEANUP'; +AUTO_CLOSE : 'AUTO_CLOSE'; +AUTO_COMPRESS : 'AUTO_COMPRESS'; +AUTO_CREATE_STATISTICS : 'AUTO_CREATE_STATISTICS'; +AUTO_DETECT : 'AUTO_DETECT'; +AUTO_DROP : 'AUTO_DROP'; +AUTO_INGEST : 'AUTO_INGEST'; +AUTO_REFRESH : 'AUTO_REFRESH'; +AUTO_RESUME : 'AUTO_RESUME'; +AUTO_SHRINK : 'AUTO_SHRINK'; +AUTO_SUSPEND : 'AUTO_SUSPEND'; +AUTO_UPDATE_STATISTICS : 'AUTO_UPDATE_STATISTICS'; +AUTO_UPDATE_STATISTICS_ASYNC : 'AUTO_UPDATE_STATISTICS_ASYNC'; +AUTOCOMMIT : 'AUTOCOMMIT'; +AUTOCOMMIT_API_SUPPORTED : 'AUTOCOMMIT_API_SUPPORTED'; +AUTOGROW_ALL_FILES : 'AUTOGROW_ALL_FILES'; +AUTOGROW_SINGLE_FILE : 'AUTOGROW_SINGLE_FILE'; +AUTOINCREMENT : 'AUTOINCREMENT'; +AUTOMATED_BACKUP_PREFERENCE : 'AUTOMATED_BACKUP_PREFERENCE'; +AUTOMATIC : 'AUTOMATIC'; +AVAILABILITY : 'AVAILABILITY'; +AVAILABILITY_MODE : 'AVAILABILITY_MODE'; +AVRO : 'AVRO'; +AWS_KEY_ID : 'AWS_KEY_ID'; +AWS_ROLE : 'AWS_ROLE'; +AWS_SECRET_KEY : 'AWS_SECRET_KEY'; +AWS_SNS : 'AWS_SNS'; +AWS_SNS_ROLE_ARN : 'AWS_SNS_ROLE_ARN'; +AWS_SNS_TOPIC : 'AWS_SNS_TOPIC'; +AWS_SNS_TOPIC_ARN : 'AWS_SNS_TOPIC_ARN'; +AWS_TOKEN : 'AWS_TOKEN'; +AZURE_AD_APPLICATION_ID : 'AZURE_AD_APPLICATION_ID'; +AZURE_EVENT_GRID : 'AZURE_EVENT_GRID'; +AZURE_EVENT_GRID_TOPIC_ENDPOINT : 'AZURE_EVENT_GRID_TOPIC_ENDPOINT'; +AZURE_SAS_TOKEN : 'AZURE_SAS_TOKEN'; +AZURE_STORAGE_QUEUE_PRIMARY_URI : 'AZURE_STORAGE_QUEUE_PRIMARY_URI'; +AZURE_TENANT_ID : 'AZURE_TENANT_ID'; +BACKUP : 'BACKUP'; +BACKUP_CLONEDB : 'BACKUP_CLONEDB'; +BACKUP_PRIORITY : 'BACKUP_PRIORITY'; +BEFORE : 'BEFORE'; +BEGIN : 'BEGIN'; +BEGIN_DIALOG : 'BEGIN_DIALOG'; +BERNOULLI : 'BERNOULLI'; +BETWEEN : 'BETWEEN'; +BINARY : 'BINARY'; +BINARY_AS_TEXT : 'BINARY_AS_TEXT'; +BINARY_FORMAT : 'BINARY_FORMAT'; +BINARY_INPUT_FORMAT : 'BINARY_INPUT_FORMAT'; +BINARY_OUTPUT_FORMAT : 'BINARY_OUTPUT_FORMAT'; +BINDING : 'BINDING'; +BLOB_STORAGE : 'BLOB_STORAGE'; +BLOCK : 'BLOCK'; +BLOCKED_IP_LIST : 'BLOCKED_IP_LIST'; +BLOCKED_ROLES_LIST : 'BLOCKED_ROLES_LIST'; +BLOCKERS : 'BLOCKERS'; +BLOCKSIZE : 'BLOCKSIZE'; +BODY : 'BODY'; +BREAK : 'BREAK'; +BROKER : 'BROKER'; +BROKER_INSTANCE : 'BROKER_INSTANCE'; +BROTLI : 'BROTLI'; +BROWSE : 'BROWSE'; +BUFFER : 'BUFFER'; +BUFFERCOUNT : 'BUFFERCOUNT'; +BULK : 'BULK'; +BULK_LOGGED : 'BULK_LOGGED'; +BUSINESS_CRITICAL : 'BUSINESS_CRITICAL'; +BY : 'BY'; +BZ2 : 'BZ2'; +CACHE : 'CACHE'; +CALL : 'CALL'; +CALLED : 'CALLED'; +CALLER : 'CALLER'; +CAP_CPU_PERCENT : 'CAP_CPU_PERCENT'; +CASCADE : 'CASCADE'; +CASE : 'CASE'; +CASE_INSENSITIVE : 'CASE_INSENSITIVE'; +CASE_SENSITIVE : 'CASE_SENSITIVE'; +CAST : 'CAST'; +CATALOG : 'CATALOG'; +CATCH : 'CATCH'; +CERTIFICATE : 'CERTIFICATE'; +CHANGE : 'CHANGE'; +CHANGE_RETENTION : 'CHANGE_RETENTION'; +CHANGE_TRACKING : 'CHANGE_TRACKING'; +CHANGES : 'CHANGES'; +CHANGETABLE : 'CHANGETABLE'; +CHANNELS : 'CHANNELS'; +CHARACTER : 'CHARACTER'; +CHECK : 'CHECK'; +CHECK_EXPIRATION : 'CHECK_EXPIRATION'; +CHECK_POLICY : 'CHECK_POLICY'; +CHECKALLOC : 'CHECKALLOC'; +CHECKCATALOG : 'CHECKCATALOG'; +CHECKCONSTRAINTS : 'CHECKCONSTRAINTS'; +CHECKDB : 'CHECKDB'; +CHECKFILEGROUP : 'CHECKFILEGROUP'; +CHECKPOINT : 
'CHECKPOINT'; +CHECKSUM : 'CHECKSUM'; +CHECKTABLE : 'CHECKTABLE'; +CLASSIFIER_FUNCTION : 'CLASSIFIER_FUNCTION'; +CLEANTABLE : 'CLEANTABLE'; +CLEANUP : 'CLEANUP'; +CLONE : 'CLONE'; +CLONEDATABASE : 'CLONEDATABASE'; +CLOSE : 'CLOSE'; +CLUSTER : 'CLUSTER'; +CLUSTERED : 'CLUSTERED'; +CLUSTERING : 'CLUSTERING'; +COLLATE : 'COLLATE'; +COLLECTION : 'COLLECTION'; +COLUMN : 'COLUMN'; +COLUMN_ENCRYPTION_KEY : 'COLUMN_ENCRYPTION_KEY'; +COLUMN_MASTER_KEY : 'COLUMN_MASTER_KEY'; +COLUMNS : 'COLUMNS'; +COLUMNSTORE : 'COLUMNSTORE'; +COLUMNSTORE_ARCHIVE : 'COLUMNSTORE_ARCHIVE'; +COMMENT : 'COMMENT'; +COMMIT : 'COMMIT'; +COMMITTED : 'COMMITTED'; +COMPATIBILITY_LEVEL : 'COMPATIBILITY_LEVEL'; +COMPRESS_ALL_ROW_GROUPS : 'COMPRESS_ALL_ROW_GROUPS'; +COMPRESSION : 'COMPRESSION'; +COMPRESSION_DELAY : 'COMPRESSION_DELAY'; +COMPUTE : 'COMPUTE'; +CONCAT : 'CONCAT'; +CONCAT_NULL_YIELDS_NULL : 'CONCAT_NULL_YIELDS_NULL'; +CONDITION : 'CONDITION'; +CONFIGURATION : 'CONFIGURATION'; +CONNECT : 'CONNECT'; +CONNECTION : 'CONNECTION'; +CONNECTIONS : 'CONNECTIONS'; +CONSTRAINT : 'CONSTRAINT'; +CONTAINMENT : 'CONTAINMENT'; +CONTAINS : 'CONTAINS'; +CONTAINSTABLE : 'CONTAINSTABLE'; +CONTENT : 'CONTENT'; +CONTEXT : 'CONTEXT'; +CONTEXT_HEADERS : 'CONTEXT_HEADERS'; +CONTINUE : 'CONTINUE'; +CONTINUE_AFTER_ERROR : 'CONTINUE_AFTER_ERROR'; +CONTRACT : 'CONTRACT'; +CONTRACT_NAME : 'CONTRACT_NAME'; +CONTROL : 'CONTROL'; +CONVERSATION : 'CONVERSATION'; +COOKIE : 'COOKIE'; +COPY : 'COPY'; +COPY_ONLY : 'COPY_ONLY'; +COPY_OPTIONS_ : 'COPY_OPTIONS'; +COUNTER : 'COUNTER'; +CPU : 'CPU'; +CREATE : 'CREATE'; +CREATE_NEW : 'CREATE_NEW'; +CREATION_DISPOSITION : 'CREATION_DISPOSITION'; +CREDENTIAL : 'CREDENTIAL'; +CREDENTIALS : 'CREDENTIALS'; +CREDIT_QUOTA : 'CREDIT_QUOTA'; +CROSS : 'CROSS'; +CRYPTOGRAPHIC : 'CRYPTOGRAPHIC'; +CSV : 'CSV'; +CURRENT : 'CURRENT'; +CURRENT_DATE : 'CURRENT_DATE'; +CURRENT_TIME : 'CURRENT_TIME'; +CURRENT_TIMESTAMP : 'CURRENT_TIMESTAMP'; +CURSOR : 'CURSOR'; +CURSOR_CLOSE_ON_COMMIT : 'CURSOR_CLOSE_ON_COMMIT'; +CURSOR_DEFAULT : 'CURSOR_DEFAULT'; +CUSTOM : 'CUSTOM'; +CYCLE : 'CYCLE'; +DAILY : 'DAILY'; +DATA : 'DATA'; +DATA_COMPRESSION : 'DATA_COMPRESSION'; +DATA_PURITY : 'DATA_PURITY'; +DATA_RETENTION_TIME_IN_DAYS : 'DATA_RETENTION_TIME_IN_DAYS'; +DATA_SOURCE : 'DATA_SOURCE'; +DATABASE : 'DATABASE'; +DATABASE_MIRRORING : 'DATABASE_MIRRORING'; +DATABASES : 'DATABASES'; +DATASPACE : 'DATASPACE'; +DATE_CORRELATION_OPTIMIZATION : 'DATE_CORRELATION_OPTIMIZATION'; +DATE_FORMAT : 'DATE_FORMAT'; +DATE_INPUT_FORMAT : 'DATE_INPUT_FORMAT'; +DATE_OUTPUT_FORMAT : 'DATE_OUTPUT_FORMAT'; +DAYS : 'DAYS'; +DAYS_TO_EXPIRY : 'DAYS_TO_EXPIRY'; +DB_CHAINING : 'DB_CHAINING'; +DB_FAILOVER : 'DB_FAILOVER'; +DBCC : 'DBCC'; +DBREINDEX : 'DBREINDEX'; +DDL : 'DDL'; +DEALLOCATE : 'DEALLOCATE'; +DECLARE : 'DECLARE'; +DECRYPTION : 'DECRYPTION'; +DEFAULT : 'DEFAULT'; +DEFAULT_DATABASE : 'DEFAULT_DATABASE'; +DEFAULT_DDL_COLLATION_ : 'DEFAULT_DDL_COLLATION'; +DEFAULT_DOUBLE_QUOTE : ["]'DEFAULT' ["]; +DEFAULT_FULLTEXT_LANGUAGE : 'DEFAULT_FULLTEXT_LANGUAGE'; +DEFAULT_LANGUAGE : 'DEFAULT_LANGUAGE'; +DEFAULT_NAMESPACE : 'DEFAULT_NAMESPACE'; +DEFAULT_ROLE : 'DEFAULT_ROLE'; +DEFAULT_SCHEMA : 'DEFAULT_SCHEMA'; +DEFAULT_WAREHOUSE : 'DEFAULT_WAREHOUSE'; +DEFERRABLE : 'DEFERRABLE'; +DEFERRED : 'DEFERRED'; +DEFINE : 'DEFINE'; +DEFINITION : 'DEFINITION'; +DEFLATE : 'DEFLATE'; +DELAY : 'DELAY'; +DELAYED_DURABILITY : 'DELAYED_DURABILITY'; +DELEGATED : 'DELEGATED'; +DELETE : 'DELETE'; +DELETED : 'DELETED'; +DELTA : 'DELTA'; +DENSE_RANK : 'DENSE_RANK'; +DENY : 'DENY'; 
+DEPENDENTS : 'DEPENDENTS'; +DES : 'DES'; +DESC : 'DESC'; +DESCRIBE : 'DESC' 'RIBE'?; +DESCRIPTION : 'DESCRIPTION'; +DESX : 'DESX'; +DETERMINISTIC : 'DETERMINISTIC'; +DHCP : 'DHCP'; +DIAGNOSTICS : 'DIAGNOSTICS'; +DIALOG : 'DIALOG'; +DIFFERENTIAL : 'DIFFERENTIAL'; +DIRECTION : 'DIRECTION'; +DIRECTORY : 'DIRECTORY'; +DIRECTORY_NAME : 'DIRECTORY_NAME'; +DISABLE : 'DISABLE'; +DISABLE_AUTO_CONVERT : 'DISABLE_AUTO_CONVERT'; +DISABLE_BROKER : 'DISABLE_BROKER'; +DISABLE_SNOWFLAKE_DATA : 'DISABLE_SNOWFLAKE_DATA'; +DISABLED : 'DISABLED'; +DISPLAY_NAME : 'DISPLAY_NAME'; +DISTINCT : 'DISTINCT'; +DISTRIBUTED : 'DISTRIBUTED'; +DISTRIBUTION : 'DISTRIBUTION'; +DO : 'DO'; +DOCUMENT : 'DOCUMENT'; +DOLLAR_PARTITION : '$PARTITION'; +DOUBLE_BACK_SLASH : '\\\\'; +DOUBLE_FORWARD_SLASH : '//'; +DOWNSTREAM : 'DOWNSTREAM'; +DROP : 'DROP'; +DROP_EXISTING : 'DROP_EXISTING'; +DROPCLEANBUFFERS : 'DROPCLEANBUFFERS'; +DTC_SUPPORT : 'DTC_SUPPORT'; +DYNAMIC : 'DYNAMIC'; +ECONOMY : 'ECONOMY'; +EDGE : 'EDGE'; +EDITION : 'EDITION'; +ELEMENTS : 'ELEMENTS'; +ELSE : 'ELSE'; +EMAIL : 'EMAIL'; +EMERGENCY : 'EMERGENCY'; +EMPTY : 'EMPTY'; +EMPTY_FIELD_AS_NULL : 'EMPTY_FIELD_AS_NULL'; +ENABLE : 'ENABLE'; +ENABLE_BROKER : 'ENABLE_BROKER'; +ENABLE_FOR_PRIVILEGE : 'ENABLE_FOR_PRIVILEGE'; +ENABLE_INTERNAL_STAGES_PRIVATELINK : 'ENABLE_INTERNAL_STAGES_PRIVATELINK'; +ENABLE_OCTAL : 'ENABLE_OCTAL'; +ENABLE_QUERY_ACCELERATION : 'ENABLE_QUERY_ACCELERATION'; +ENABLE_UNLOAD_PHYSICAL_TYPE_OPTIMIZATION : 'ENABLE_UNLOAD_PHYSICAL_TYPE_OPTIMIZATION'; +ENABLED : 'ENABLED'; +ENCODING : 'ENCODING'; +ENCRYPTED : 'ENCRYPTED'; +ENCRYPTED_VALUE : 'ENCRYPTED_VALUE'; +ENCRYPTION : 'ENCRYPTION'; +ENCRYPTION_TYPE : 'ENCRYPTION_TYPE'; +END : 'END'; +END_TIMESTAMP : 'END_TIMESTAMP'; +ENDPOINT : 'ENDPOINT'; +ENDPOINT_URL : 'ENDPOINT_URL'; +ENFORCE_LENGTH : 'ENFORCE_LENGTH'; +ENFORCE_SESSION_POLICY : 'ENFORCE_SESSION_POLICY'; +ENFORCED : 'ENFORCED'; +ENTERPRISE : 'ENTERPRISE'; +EQUALITY : 'EQUALITY'; +ERROR : 'ERROR'; +ERROR_BROKER_CONVERSATIONS : 'ERROR_BROKER_CONVERSATIONS'; +ERROR_INTEGRATION : 'ERROR_INTEGRATION'; +ERROR_ON_COLUMN_COUNT_MISMATCH : 'ERROR_ON_COLUMN_COUNT_MISMATCH'; +ERROR_ON_NONDETERMINISTIC_MERGE : 'ERROR_ON_NONDETERMINISTIC_MERGE'; +ERROR_ON_NONDETERMINISTIC_UPDATE : 'ERROR_ON_NONDETERMINISTIC_UPDATE'; +ESCAPE : 'ESCAPE'; +ESCAPE_UNENCLOSED_FIELD : 'ESCAPE_UNENCLOSED_FIELD'; +ESTIMATEONLY : 'ESTIMATEONLY'; +EVENT : 'EVENT'; +EVENT_RETENTION_MODE : 'EVENT_RETENTION_MODE'; +EXCEPT : 'EXCEPT'; +EXCEPTION : 'EXCEPTION'; +EXCHANGE : 'EXCHANGE'; +EXCLUSIVE : 'EXCLUSIVE'; +EXECUTABLE : 'EXECUTABLE'; +EXECUTABLE_FILE : 'EXECUTABLE_FILE'; +EXECUTE : 'EXEC' 'UTE'?; +EXECUTION : 'EXECUTION'; +EXISTS : 'EXISTS'; +EXPIREDATE : 'EXPIREDATE'; +EXPIRY_DATE : 'EXPIRY_DATE'; +EXPLAIN : 'EXPLAIN'; +EXPLICIT : 'EXPLICIT'; +EXTENDED_LOGICAL_CHECKS : 'EXTENDED_LOGICAL_CHECKS'; +EXTENSION : 'EXTENSION'; +EXTERNAL : 'EXTERNAL'; +EXTERNAL_ACCESS : 'EXTERNAL_ACCESS'; +EXTRACT : 'EXTRACT'; +FAIL_OPERATION : 'FAIL_OPERATION'; +FAILOVER : 'FAILOVER'; +FAILOVER_MODE : 'FAILOVER_MODE'; +FAILURE : 'FAILURE'; +FAILURE_CONDITION_LEVEL : 'FAILURE_CONDITION_LEVEL'; +FAILURECONDITIONLEVEL : 'FAILURECONDITIONLEVEL'; +FALSE : 'FALSE'; +FAN_IN : 'FAN_IN'; +FAST_FORWARD : 'FAST_FORWARD'; +FETCH : 'FETCH'; +FIELD_DELIMITER : 'FIELD_DELIMITER'; +FIELD_OPTIONALLY_ENCLOSED_BY : 'FIELD_OPTIONALLY_ENCLOSED_BY'; +FILE : 'FILE'; +FILE_EXTENSION : 'FILE_EXTENSION'; +FILE_FORMAT : 'FILE_FORMAT'; +FILE_SNAPSHOT : 'FILE_SNAPSHOT'; +FILEGROUP : 'FILEGROUP'; +FILEGROWTH : 'FILEGROWTH'; +FILENAME : 
'FILENAME'; +FILEPATH : 'FILEPATH'; +FILES : 'FILES'; +FILESTREAM : 'FILESTREAM'; +FILESTREAM_ON : 'FILESTREAM_ON'; +FILETABLE : 'FILETABLE'; +FILLFACTOR : 'FILLFACTOR'; +FILTER : 'FILTER'; +FIRST : 'FIRST'; +FIRST_NAME : 'FIRST_NAME'; +FLATTEN : 'FLATTEN'; +FLOOR : 'FLOOR'; +FMTONLY : 'FMTONLY'; +FOLLOWING : 'FOLLOWING'; +FOR : 'FOR'; +FORCE : 'FORCE'; +FORCE_FAILOVER_ALLOW_DATA_LOSS : 'FORCE_FAILOVER_ALLOW_DATA_LOSS'; +FORCE_SERVICE_ALLOW_DATA_LOSS : 'FORCE_SERVICE_ALLOW_DATA_LOSS'; +FORCEPLAN : 'FORCEPLAN'; +FORCESCAN : 'FORCESCAN'; +FORCESEEK : 'FORCESEEK'; +FOREIGN : 'FOREIGN'; +FORMAT : 'FORMAT'; +FORMAT_NAME : 'FORMAT_NAME'; +FORMATS : 'FORMATS'; +FORWARD_ONLY : 'FORWARD_ONLY'; +FREE : 'FREE'; +FREETEXT : 'FREETEXT'; +FREETEXTTABLE : 'FREETEXTTABLE'; +FREQUENCY : 'FREQUENCY'; +FROM : 'FROM'; +FULL : 'FULL'; +FULLSCAN : 'FULLSCAN'; +FULLTEXT : 'FULLTEXT'; +FUNCTION : 'FUNCTION'; +FUNCTIONS : 'FUNCTIONS'; +FUTURE : 'FUTURE'; +GB : 'GB'; +GCP_PUBSUB : 'GCP_PUBSUB'; +GCP_PUBSUB_SUBSCRIPTION_NAME : 'GCP_PUBSUB_SUBSCRIPTION_NAME'; +GCP_PUBSUB_TOPIC_NAME : 'GCP_PUBSUB_TOPIC_NAME'; +GENERATED : 'GENERATED'; +GEO : 'GEO'; +GEOGRAPHY_OUTPUT_FORMAT : 'GEOGRAPHY_OUTPUT_FORMAT'; +GEOMETRY_OUTPUT_FORMAT : 'GEOMETRY_OUTPUT_FORMAT'; +GET : 'GET'; +GETDATE : 'GETDATE'; +GETROOT : 'GETROOT'; +GLOBAL : 'GLOBAL'; +GO : 'GO'; +GOOGLE_AUDIENCE : 'GOOGLE_AUDIENCE'; +GOTO : 'GOTO'; +GOVERNOR : 'GOVERNOR'; +GRANT : 'GRANT'; +GRANTS : 'GRANTS'; +GROUP : 'GROUP'; +GROUP_MAX_REQUESTS : 'GROUP_MAX_REQUESTS'; +GROUPING : 'GROUPING'; +GROUPS : 'GROUPS'; +GZIP : 'GZIP'; +HADR : 'HADR'; +HANDLER : 'HANDLER'; +HASH : 'HASH'; +HASHED : 'HASHED'; +HAVING : 'HAVING'; +HEADER : 'HEADER'; +HEADERS : 'HEADERS'; +HEALTH_CHECK_TIMEOUT : 'HEALTH_CHECK_TIMEOUT'; +HEALTHCHECKTIMEOUT : 'HEALTHCHECKTIMEOUT'; +HEAP : 'HEAP'; +HIDDEN_KEYWORD : 'HIDDEN'; +HIERARCHYID : 'HIERARCHYID'; +HIGH : 'HIGH'; +HISTORY : 'HISTORY'; +HOLDLOCK : 'HOLDLOCK'; +HONOR_BROKER_PRIORITY : 'HONOR_BROKER_PRIORITY'; +HOURS : 'HOURS'; +IDENTIFIER : 'IDENTIFIER'; +IDENTITY : 'IDENTITY'; +IDENTITY_INSERT : 'IDENTITY_INSERT'; +IDENTITY_VALUE : 'IDENTITY_VALUE'; +IF : 'IF'; +IFF : 'IFF'; +IGNORE : 'IGNORE'; +IGNORE_CONSTRAINTS : 'IGNORE_CONSTRAINTS'; +IGNORE_DUP_KEY : 'IGNORE_DUP_KEY'; +IGNORE_REPLICATED_TABLE_CACHE : 'IGNORE_REPLICATED_TABLE_CACHE'; +IGNORE_TRIGGERS : 'IGNORE_TRIGGERS'; +IGNORE_UTF8_ERRORS : 'IGNORE_UTF8_ERRORS'; +IIF : 'IIF'; +ILIKE : 'ILIKE'; +IMMEDIATE : 'IMMEDIATE'; +IMMEDIATELY : 'IMMEDIATELY'; +IMMUTABLE : 'IMMUTABLE'; +IMPERSONATE : 'IMPERSONATE'; +IMPLICIT : 'IMPLICIT'; +IMPLICIT_TRANSACTIONS : 'IMPLICIT_TRANSACTIONS'; +IMPORT : 'IMPORT'; +IMPORTANCE : 'IMPORTANCE'; +IMPORTED : 'IMPORTED'; +IMPORTS : 'IMPORTS'; +IN : 'IN'; +INCLUDE : 'INCLUDE'; +INCLUDE_NULL_VALUES : 'INCLUDE_NULL_VALUES'; +INCREMENT : 'INCREMENT'; +INCREMENTAL : 'INCREMENTAL'; +INDEX : 'INDEX'; +INFINITE : 'INFINITE'; +INFORMATION : 'INFORMATION'; +INIT : 'INIT'; +INITIAL_REPLICATION_SIZE_LIMIT_IN_TB : 'INITIAL_REPLICATION_SIZE_LIMIT_IN_TB'; +INITIALLY : 'INITIALLY'; +INITIALLY_SUSPENDED : 'INITIALLY_SUSPENDED'; +INITIATOR : 'INITIATOR'; +INNER : 'INNER'; +INPUT : 'INPUT'; +INSENSITIVE : 'INSENSITIVE'; +INSERT : 'INSERT'; +INSERT_ONLY : 'INSERT_ONLY'; +INSERTED : 'INSERTED'; +INSTEAD : 'INSTEAD'; +INTEGRATION : 'INTEGRATION'; +INTEGRATIONS : 'INTEGRATIONS'; +INTERSECT : 'INTERSECT'; +INTERVAL : 'INTERVAL'; +INTO : 'INTO'; +IO : 'IO'; +IP : 'IP'; +IS : 'IS'; +ISOLATION : 'ISOLATION'; +JDBC_TREAT_DECIMAL_AS_INT : 'JDBC_TREAT_DECIMAL_AS_INT'; 
+JDBC_TREAT_TIMESTAMP_NTZ_AS_UTC : 'JDBC_TREAT_TIMESTAMP_NTZ_AS_UTC'; +JDBC_USE_SESSION_TIMEZONE : 'JDBC_USE_SESSION_TIMEZONE'; +JOB : 'JOB'; +JOIN : 'JOIN'; +JS_TREAT_INTEGER_AS_BIGINT : 'JS_TREAT_INTEGER_AS_BIGINT'; +JSON : 'JSON'; +JSON_ARRAY : 'JSON_ARRAY'; +JSON_INDENT : 'JSON_INDENT'; +JSON_OBJECT : 'JSON_OBJECT'; +KB : 'KB'; +KEEPDEFAULTS : 'KEEPDEFAULTS'; +KEEPIDENTITY : 'KEEPIDENTITY'; +KERBEROS : 'KERBEROS'; +KEY : 'KEY'; +KEY_PATH : 'KEY_PATH'; +KEY_SOURCE : 'KEY_SOURCE'; +KEY_STORE_PROVIDER_NAME : 'KEY_STORE_PROVIDER_NAME'; +KEYS : 'KEYS'; +KEYSET : 'KEYSET'; +KILL : 'KILL'; +KMS_KEY_ID : 'KMS_KEY_ID'; +KWSKIP : 'SKIP'; +LANGUAGE : 'LANGUAGE'; +LARGE : 'LARGE'; +LAST : 'LAST'; +LAST_NAME : 'LAST_NAME'; +LAST_QUERY_ID : 'LAST_QUERY_ID'; +LATERAL : 'LATERAL'; +LEAD : 'LEAD'; +LEFT : 'LEFT'; +LENGTH : 'LENGTH'; +LET : 'LET'; +LEVEL : 'LEVEL'; +LIBRARY : 'LIBRARY'; +LIFETIME : 'LIFETIME'; +LIKE : 'LIKE'; +LIMIT : 'LIMIT'; +LINEAR : 'LINEAR'; +LINKED : 'LINKED'; +LINUX : 'LINUX'; +LIST : 'LIST'; +LISTAGG : 'LISTAGG'; +LISTENER : 'LISTENER'; +LISTENER_IP : 'LISTENER_IP'; +LISTENER_PORT : 'LISTENER_PORT'; +LISTENER_URL : 'LISTENER_URL'; +LISTING : 'LISTING'; +LOB_COMPACTION : 'LOB_COMPACTION'; +LOCAL : 'LOCAL'; +LOCAL_SERVICE_NAME : 'LOCAL_SERVICE_NAME'; +LOCALTIME : 'LOCALTIME'; +LOCALTIMESTAMP : 'LOCALTIMESTAMP'; +LOCATION : 'LOCATION'; +LOCK : 'LOCK'; +LOCK_ESCALATION : 'LOCK_ESCALATION'; +LOCK_TIMEOUT : 'LOCK_TIMEOUT'; +LOCKS : 'LOCKS'; +LOGIN : 'LOGIN'; +LOGIN_NAME : 'LOGIN_NAME'; +LOOKER : 'LOOKER'; +LOOP : 'LOOP'; +LOW : 'LOW'; +LZO : 'LZO'; +MANAGE : 'MANAGE'; +MANAGED : 'MANAGED'; +MANUAL : 'MANUAL'; +MARK : 'MARK'; +MASK : 'MASK'; +MASKED : 'MASKED'; +MASKING : 'MASKING'; +MASTER : 'MASTER'; +MASTER_KEY : 'MASTER_KEY'; +MATCH : 'MATCH'; +MATCH_BY_COLUMN_NAME : 'MATCH_BY_COLUMN_NAME'; +MATCH_RECOGNIZE : 'MATCH_RECOGNIZE'; +MATCHED : 'MATCHED'; +MATCHES : 'MATCHES'; +MATERIALIZED : 'MATERIALIZED'; +MAX : 'MAX'; +MAX_BATCH_ROWS : 'MAX_BATCH_ROWS'; +MAX_CLUSTER_COUNT : 'MAX_CLUSTER_COUNT'; +MAX_CONCURRENCY_LEVEL : 'MAX_CONCURRENCY_LEVEL'; +MAX_CPU_PERCENT : 'MAX_CPU_PERCENT'; +MAX_DATA_EXTENSION_TIME_IN_DAYS : 'MAX_DATA_EXTENSION_TIME_IN_DAYS'; +MAX_DISPATCH_LATENCY : 'MAX_DISPATCH_LATENCY'; +MAX_DOP : 'MAX_DOP'; +MAX_DURATION : 'MAX_DURATION'; +MAX_EVENT_SIZE : 'MAX_EVENT_SIZE'; +MAX_FILES : 'MAX_FILES'; +MAX_IOPS_PER_VOLUME : 'MAX_IOPS_PER_VOLUME'; +MAX_MEMORY : 'MAX_MEMORY'; +MAX_MEMORY_PERCENT : 'MAX_MEMORY_PERCENT'; +MAX_OUTSTANDING_IO_PER_VOLUME : 'MAX_OUTSTANDING_IO_PER_VOLUME'; +MAX_PROCESSES : 'MAX_PROCESSES'; +MAX_QUEUE_READERS : 'MAX_QUEUE_READERS'; +MAX_ROLLOVER_FILES : 'MAX_ROLLOVER_FILES'; +MAX_SIZE : 'MAX_SIZE'; +MAXSIZE : 'MAXSIZE'; +MAXTRANSFER : 'MAXTRANSFER'; +MAXVALUE : 'MAXVALUE'; +MB : 'MB'; +MEASURES : 'MEASURES'; +MEDIADESCRIPTION : 'MEDIADESCRIPTION'; +MEDIANAME : 'MEDIANAME'; +MEDIUM : 'MEDIUM'; +MEMBER : 'MEMBER'; +MEMOIZABLE : 'MEMOIZABLE'; +MEMORY_OPTIMIZED_DATA : 'MEMORY_OPTIMIZED_DATA'; +MEMORY_PARTITION_MODE : 'MEMORY_PARTITION_MODE'; +MERGE : 'MERGE'; +MESSAGE : 'MESSAGE'; +MESSAGE_FORWARD_SIZE : 'MESSAGE_FORWARD_SIZE'; +MESSAGE_FORWARDING : 'MESSAGE_FORWARDING'; +MIDDLE_NAME : 'MIDDLE_NAME'; +MIN_CLUSTER_COUNT : 'MIN_CLUSTER_COUNT'; +MIN_CPU_PERCENT : 'MIN_CPU_PERCENT'; +MIN_DATA_RETENTION_TIME_IN_DAYS : 'MIN_DATA_RETENTION_TIME_IN_DAYS'; +MIN_IOPS_PER_VOLUME : 'MIN_IOPS_PER_VOLUME'; +MIN_MEMORY_PERCENT : 'MIN_MEMORY_PERCENT'; +MINS_TO_BYPASS_MFA : 'MINS_TO_BYPASS_MFA'; +MINS_TO_UNLOCK : 'MINS_TO_UNLOCK'; +MINUS_ : 'MINUS'; +MINUTES : 'MINUTES'; 
+MINVALUE : 'MINVALUE'; +MIRROR : 'MIRROR'; +MIRROR_ADDRESS : 'MIRROR_ADDRESS'; +MIXED_PAGE_ALLOCATION : 'MIXED_PAGE_ALLOCATION'; +MODE : 'MODE'; +MODIFIED_AFTER : 'MODIFIED_AFTER'; +MODIFY : 'MODIFY'; +MONITOR : 'MONITOR'; +MONITORS : 'MONITORS'; +MONTHLY : 'MONTHLY'; +MOVE : 'MOVE'; +MULTI_STATEMENT_COUNT : 'MULTI_STATEMENT_COUNT'; +MULTI_USER : 'MULTI_USER'; +MUST_CHANGE : 'MUST_CHANGE'; +MUST_CHANGE_PASSWORD : 'MUST_CHANGE_PASSWORD'; +NAME : 'NAME'; +NATURAL : 'NATURAL'; +NEGOTIATE : 'NEGOTIATE'; +NESTED_TRIGGERS : 'NESTED_TRIGGERS'; +NETWORK : 'NETWORK'; +NETWORK_POLICY : 'NETWORK_POLICY'; +NEVER : 'NEVER'; +NEW_ACCOUNT : 'NEW_ACCOUNT'; +NEW_BROKER : 'NEW_BROKER'; +NEW_PASSWORD : 'NEW_PASSWORD'; +NEWNAME : 'NEWNAME'; +NEXT : 'NEXT'; +NEXTVAL : 'NEXTVAL'; +NO : 'NO'; +NO_CHECKSUM : 'NO_CHECKSUM'; +NO_COMPRESSION : 'NO_COMPRESSION'; +NO_EVENT_LOSS : 'NO_EVENT_LOSS'; +NO_INFOMSGS : 'NO_INFOMSGS'; +NO_QUERYSTORE : 'NO_QUERYSTORE'; +NO_STATISTICS : 'NO_STATISTICS'; +NO_TRUNCATE : 'NO_TRUNCATE'; +NO_WAIT : 'NO_WAIT'; +NOCHECK : 'NOCHECK'; +NOCOUNT : 'NOCOUNT'; +NODE : 'NODE'; +NODES : 'NODES'; +NOEXEC : 'NOEXEC'; +NOEXPAND : 'NOEXPAND'; +NOFORMAT : 'NOFORMAT'; +NOHOLDLOCK : 'NOHOLDLOCK'; +NOINDEX : 'NOINDEX'; +NOINIT : 'NOINIT'; +NOLOCK : 'NOLOCK'; +NON_TRANSACTED_ACCESS : 'NON_TRANSACTED_ACCESS'; +NONCLUSTERED : 'NONCLUSTERED'; +NONE : 'NONE'; +NOORDER : 'NOORDER'; +NORECOMPUTE : 'NORECOMPUTE'; +NORECOVERY : 'NORECOVERY'; +NORELY : 'NORELY'; +NOREWIND : 'NOREWIND'; +NOSKIP : 'NOSKIP'; +NOT : 'NOT'; +NOTIFICATION : 'NOTIFICATION'; +NOTIFICATION_INTEGRATION : 'NOTIFICATION_INTEGRATION'; +NOTIFICATION_PROVIDER : 'NOTIFICATION_PROVIDER'; +NOTIFICATIONS : 'NOTIFICATIONS'; +NOTIFY : 'NOTIFY'; +NOTIFY_USERS : 'NOTIFY_USERS'; +NOUNLOAD : 'NOUNLOAD'; +NOVALIDATE : 'NOVALIDATE'; +NTILE : 'NTILE'; +NTLM : 'NTLM'; +NULL : 'NULL'; +NULL_IF : 'NULL_IF'; +NULLS : 'NULLS'; +NUMANODE : 'NUMANODE'; +NUMERIC_ROUNDABORT : 'NUMERIC_ROUNDABORT'; +OAUTH : 'OAUTH'; +OAUTH_ALLOW_NON_TLS_REDIRECT_URI : 'OAUTH_ALLOW_NON_TLS_REDIRECT_URI'; +OAUTH_CLIENT : 'OAUTH_CLIENT'; +OAUTH_CLIENT_RSA_PUBLIC_KEY : 'OAUTH_CLIENT_RSA_PUBLIC_KEY'; +OAUTH_CLIENT_RSA_PUBLIC_KEY_2 : 'OAUTH_CLIENT_RSA_PUBLIC_KEY_2'; +OAUTH_ENFORCE_PKCE : 'OAUTH_ENFORCE_PKCE'; +OAUTH_ISSUE_REFRESH_TOKENS : 'OAUTH_ISSUE_REFRESH_TOKENS'; +OAUTH_REDIRECT_URI : 'OAUTH_REDIRECT_URI'; +OAUTH_REFRESH_TOKEN_VALIDITY : 'OAUTH_REFRESH_TOKEN_VALIDITY'; +OAUTH_USE_SECONDARY_ROLES : 'OAUTH_USE_SECONDARY_ROLES'; +OBJECT : 'OBJECT'; +OBJECT_TYPES : 'OBJECT_TYPES'; +OBJECTS : 'OBJECTS'; +OF : 'OF'; +OFF : 'OFF'; +OFFLINE : 'OFFLINE'; +OFFSET : 'OFFSET'; +OKTA : 'OKTA'; +OLD : 'OLD'; +OLD_ACCOUNT : 'OLD_ACCOUNT'; +OLD_PASSWORD : 'OLD_PASSWORD'; +OMIT : 'OMIT'; +ON : 'ON'; +ON_ERROR : 'ON_ERROR'; +ON_FAILURE : 'ON_FAILURE'; +ONE : 'ONE'; +ONLINE : 'ONLINE'; +ONLY : 'ONLY'; +OPEN : 'OPEN'; +OPEN_EXISTING : 'OPEN_EXISTING'; +OPENDATASOURCE : 'OPENDATASOURCE'; +OPENJSON : 'OPENJSON'; +OPENQUERY : 'OPENQUERY'; +OPENROWSET : 'OPENROWSET'; +OPENXML : 'OPENXML'; +OPERATE : 'OPERATE'; +OPERATIONS : 'OPERATIONS'; +OPTIMISTIC : 'OPTIMISTIC'; +OPTIMIZATION : 'OPTIMIZATION'; +OPTION : 'OPTION'; +OR : 'OR'; +ORC : 'ORC'; +ORDER : 'ORDER'; +ORGADMIN : 'ORGADMIN'; +ORGANIZATION : 'ORGANIZATION'; +OUT : 'OUT'; +OUTBOUND : 'OUTBOUND'; +OUTER : 'OUTER'; +OUTPUT : 'OUTPUT'; +OVER : 'OVER'; +OVERRIDE : 'OVERRIDE'; +OVERWRITE : 'OVERWRITE'; +OWNER : 'OWNER'; +OWNERSHIP : 'OWNERSHIP'; +PACKAGES : 'PACKAGES'; +PAD_INDEX : 'PAD_INDEX'; +PAGE : 'PAGE'; +PAGE_VERIFY : 'PAGE_VERIFY'; +PAGECOUNT : 
'PAGECOUNT'; +PAGLOCK : 'PAGLOCK'; +PARALLEL : 'PARALLEL'; +PARAM_NODE : 'PARAM_NODE'; +PARAMETERIZATION : 'PARAMETERIZATION'; +PARAMETERS : 'PARAMETERS'; +PARQUET : 'PARQUET'; +PARSE : 'PARSE'; +PARSEONLY : 'PARSEONLY'; +PARTIAL : 'PARTIAL'; +PARTITION : 'PARTITION'; +PARTITION_TYPE : 'PARTITION_TYPE'; +PARTITIONS : 'PARTITIONS'; +PARTNER : 'PARTNER'; +PASSWORD : 'PASSWORD'; +PAST : 'PAST'; +PATH : 'PATH'; +PATTERN : 'PATTERN'; +PAUSE : 'PAUSE'; +PDW_SHOWSPACEUSED : 'PDW_SHOWSPACEUSED'; +PER : 'PER'; +PER_CPU : 'PER_CPU'; +PER_DB : 'PER_DB'; +PER_NODE : 'PER_NODE'; +PERCENT : 'PERCENT'; +PERIODIC_DATA_REKEYING : 'PERIODIC_DATA_REKEYING'; +PERMISSION_SET : 'PERMISSION_SET'; +PERSIST_SAMPLE_PERCENT : 'PERSIST_SAMPLE_PERCENT'; +PERSISTED : 'PERSISTED'; +PHYSICAL_ONLY : 'PHYSICAL_ONLY'; +PING_FEDERATE : 'PING_FEDERATE'; +PIPE : 'PIPE'; +PIPE_EXECUTION_PAUSED : 'PIPE_EXECUTION_PAUSED'; +PIPES : 'PIPES'; +PIVOT : 'PIVOT'; +PLATFORM : 'PLATFORM'; +POISON_MESSAGE_HANDLING : 'POISON_MESSAGE_HANDLING'; +POLICIES : 'POLICIES'; +POLICY : 'POLICY'; +POOL : 'POOL'; +PORT : 'PORT'; +PRE_AUTHORIZED_ROLES_LIST : 'PRE_AUTHORIZED_ROLES_LIST'; +PRECEDING : 'PRECEDING'; +PREDICATE : 'PREDICATE'; +PREFIX : 'PREFIX'; +PRESERVE_SPACE : 'PRESERVE_SPACE'; +PREVENT_UNLOAD_TO_INLINE_URL : 'PREVENT_UNLOAD_TO_INLINE_URL'; +PREVENT_UNLOAD_TO_INTERNAL_STAGES : 'PREVENT_UNLOAD_TO_INTERNAL_STAGES'; +PRIMARY : 'PRIMARY'; +PRIMARY_ROLE : 'PRIMARY_ROLE'; +PRINT : 'PRINT'; +PRIOR : 'PRIOR'; +PRIORITY : 'PRIORITY'; +PRIORITY_LEVEL : 'PRIORITY_LEVEL'; +PRIVATE : 'PRIVATE'; +PRIVATE_KEY : 'PRIVATE_KEY'; +PRIVILEGES : 'PRIVILEGES'; +PROCCACHE : 'PROCCACHE'; +PROCEDURE : 'PROC' 'EDURE'?; +PROCEDURE_NAME : 'PROCEDURE_NAME'; +PROCEDURES : 'PROCEDURES'; +PROCESS : 'PROCESS'; +PROFILE : 'PROFILE'; +PROPERTY : 'PROPERTY'; +PROVIDER : 'PROVIDER'; +PROVIDER_KEY_NAME : 'PROVIDER_KEY_NAME'; +PUBLIC : 'PUBLIC'; +PURGE : 'PURGE'; +PUT : 'PUT'; +QUALIFY : 'QUALIFY'; +QUERIES : 'QUERIES'; +QUERY : 'QUERY'; +QUERY_ACCELERATION_MAX_SCALE_FACTOR : 'QUERY_ACCELERATION_MAX_SCALE_FACTOR'; +QUERY_STORE : 'QUERY_STORE'; +QUERY_TAG : 'QUERY_TAG'; +QUEUE : 'QUEUE'; +QUEUE_DELAY : 'QUEUE_DELAY'; +QUOTED_IDENTIFIER : 'QUOTED_IDENTIFIER'; +QUOTED_IDENTIFIERS_IGNORE_CASE : 'QUOTED_IDENTIFIERS_IGNORE_CASE'; +RAISERROR : 'RAISERROR'; +RANDOMIZED : 'RANDOMIZED'; +RANGE : 'RANGE'; +RANK : 'RANK'; +RAW : 'RAW'; +RAW_DEFLATE : 'RAW_DEFLATE'; +RC2 : 'RC2'; +RC4 : 'RC4'; +RC4_128 : 'RC4_128'; +READ : 'READ'; +READ_COMMITTED_SNAPSHOT : 'READ_COMMITTED_SNAPSHOT'; +READ_ONLY : 'READ_ONLY'; +READ_ONLY_ROUTING_LIST : 'READ_ONLY_ROUTING_LIST'; +READ_WRITE : 'READ_WRITE'; +READ_WRITE_FILEGROUPS : 'READ_WRITE_FILEGROUPS'; +READCOMMITTED : 'READCOMMITTED'; +READCOMMITTEDLOCK : 'READCOMMITTEDLOCK'; +READER : 'READER'; +READONLY : 'READONLY'; +READPAST : 'READPAST'; +READUNCOMMITTED : 'READUNCOMMITTED'; +READWRITE : 'READWRITE'; +REBUILD : 'REBUILD'; +RECEIVE : 'RECEIVE'; +RECLUSTER : 'RECLUSTER'; +RECONFIGURE : 'RECONFIGURE'; +RECORD_DELIMITER : 'RECORD_DELIMITER'; +RECOVERY : 'RECOVERY'; +RECURSIVE : 'RECURSIVE'; +RECURSIVE_TRIGGERS : 'RECURSIVE_TRIGGERS'; +REFERENCE_USAGE : 'REFERENCE_USAGE'; +REFERENCES : 'REFERENCES'; +REFRESH : 'REFRESH'; +REFRESH_ON_CREATE : 'REFRESH_ON_CREATE'; +REGENERATE : 'REGENERATE'; +REGION : 'REGION'; +REGION_GROUP : 'REGION_GROUP'; +REGIONS : 'REGIONS'; +RELATED_CONVERSATION : 'RELATED_CONVERSATION'; +RELATED_CONVERSATION_GROUP : 'RELATED_CONVERSATION_GROUP'; +RELATIVE : 'RELATIVE'; +RELY : 'RELY'; +REMOTE : 'REMOTE'; +REMOTE_PROC_TRANSACTIONS 
: 'REMOTE_PROC_TRANSACTIONS'; +REMOTE_SERVICE_NAME : 'REMOTE_SERVICE_NAME'; +REMOVE : 'REMOVE'; +RENAME : 'RENAME'; +REORGANIZE : 'REORGANIZE'; +REPAIR_ALLOW_DATA_LOSS : 'REPAIR_ALLOW_DATA_LOSS'; +REPAIR_FAST : 'REPAIR_FAST'; +REPAIR_REBUILD : 'REPAIR_REBUILD'; +REPEATABLE : 'REPEATABLE'; +REPEATABLEREAD : 'REPEATABLEREAD'; +REPLACE : 'REPLACE'; +REPLACE_INVALID_CHARACTERS : 'REPLACE_INVALID_CHARACTERS'; +REPLICA : 'REPLICA'; +REPLICATE : 'REPLICATE'; +REPLICATION : 'REPLICATION'; +REPLICATION_SCHEDULE : 'REPLICATION_SCHEDULE'; +REQUEST_MAX_CPU_TIME_SEC : 'REQUEST_MAX_CPU_TIME_SEC'; +REQUEST_MAX_MEMORY_GRANT_PERCENT : 'REQUEST_MAX_MEMORY_GRANT_PERCENT'; +REQUEST_MEMORY_GRANT_TIMEOUT_SEC : 'REQUEST_MEMORY_GRANT_TIMEOUT_SEC'; +REQUEST_TRANSLATOR : 'REQUEST_TRANSLATOR'; +REQUIRED : 'REQUIRED'; +RESAMPLE : 'RESAMPLE'; +RESERVE_DISK_SPACE : 'RESERVE_DISK_SPACE'; +RESET : 'RESET'; +RESOURCE : 'RESOURCE'; +RESOURCE_MANAGER_LOCATION : 'RESOURCE_MANAGER_LOCATION'; +RESOURCE_MONITOR : 'RESOURCE_MONITOR'; +RESOURCES : 'RESOURCES'; +RESPECT : 'RESPECT'; +RESPONSE_TRANSLATOR : 'RESPONSE_TRANSLATOR'; +RESTART : 'RESTART'; +RESTRICT : 'RESTRICT'; +RESTRICTED_USER : 'RESTRICTED_USER'; +RESTRICTIONS : 'RESTRICTIONS'; +RESULT : 'RESULT'; +RESULTSET : 'RESULTSET'; +RESUMABLE : 'RESUMABLE'; +RESUME : 'RESUME'; +RETAINDAYS : 'RETAINDAYS'; +RETENTION : 'RETENTION'; +RETURN : 'RETURN'; +RETURN_ALL_ERRORS : 'RETURN_ALL_ERRORS'; +RETURN_ERRORS : 'RETURN_ERRORS'; +RETURN_FAILED_ONLY : 'RETURN_FAILED_ONLY'; +RETURN_ROWS : 'RETURN_ROWS'; +RETURNS : 'RETURNS'; +REVERT : 'REVERT'; +REVOKE : 'REVOKE'; +REWIND : 'REWIND'; +RIGHT : 'RIGHT'; +RLIKE : 'RLIKE'; +ROLE : 'ROLE'; +ROLES : 'ROLES'; +ROLLBACK : 'ROLLBACK'; +ROOT : 'ROOT'; +ROUND_ROBIN : 'ROUND_ROBIN'; +ROUTE : 'ROUTE'; +ROW : 'ROW'; +ROWCOUNT : 'ROWCOUNT'; +ROWGUID : 'ROWGUID'; +ROWGUIDCOL : 'ROWGUIDCOL'; +ROWLOCK : 'ROWLOCK'; +ROWS : 'ROWS'; +ROWS_PER_RESULTSET : 'ROWS_PER_RESULTSET'; +RSA_512 : 'RSA_512'; +RSA_1024 : 'RSA_1024'; +RSA_2048 : 'RSA_2048'; +RSA_3072 : 'RSA_3072'; +RSA_4096 : 'RSA_4096'; +RSA_PUBLIC_KEY : 'RSA_PUBLIC_KEY'; +RSA_PUBLIC_KEY_2 : 'RSA_PUBLIC_KEY_2'; +RULE : 'RULE'; +RUN_AS_ROLE : 'RUN_AS_ROLE'; +RUNTIME_VERSION : 'RUNTIME_VERSION'; +SAFE : 'SAFE'; +SAFETY : 'SAFETY'; +SAML2 : 'SAML2'; +SAML2_ENABLE_SP_INITIATED : 'SAML2_ENABLE_SP_INITIATED'; +SAML2_FORCE_AUTHN : 'SAML2_FORCE_AUTHN'; +SAML2_ISSUER : 'SAML2_ISSUER'; +SAML2_POST_LOGOUT_REDIRECT_URL : 'SAML2_POST_LOGOUT_REDIRECT_URL'; +SAML2_PROVIDER : 'SAML2_PROVIDER'; +SAML2_REQUESTED_NAMEID_FORMAT : 'SAML2_REQUESTED_NAMEID_FORMAT'; +SAML2_SIGN_REQUEST : 'SAML2_SIGN_REQUEST'; +SAML2_SNOWFLAKE_ACS_URL : 'SAML2_SNOWFLAKE_ACS_URL'; +SAML2_SNOWFLAKE_ISSUER_URL : 'SAML2_SNOWFLAKE_ISSUER_URL'; +SAML2_SNOWFLAKE_X509_CERT : 'SAML2_SNOWFLAKE_X509_CERT'; +SAML2_SP_INITIATED_LOGIN_PAGE_LABEL : 'SAML2_SP_INITIATED_LOGIN_PAGE_LABEL'; +SAML2_SSO_URL : 'SAML2_SSO_URL'; +SAML2_X509_CERT : 'SAML2_X509_CERT'; +SAML_IDENTITY_PROVIDER : 'SAML_IDENTITY_PROVIDER'; +SAMPLE : 'SAMPLE'; +SAVE : 'SAVE'; +SAVE_OLD_URL : 'SAVE_OLD_URL'; +SCALING_POLICY : 'SCALING_POLICY'; +SCHEDULE : 'SCHEDULE'; +SCHEDULER : 'SCHEDULER'; +SCHEMA : 'SCHEMA'; +SCHEMABINDING : 'SCHEMABINDING'; +SCHEMAS : 'SCHEMAS'; +SCHEME : 'SCHEME'; +SCIM : 'SCIM'; +SCIM_CLIENT : 'SCIM_CLIENT'; +SCOPED : 'SCOPED'; +SCRIPT : 'SCRIPT'; +SCROLL : 'SCROLL'; +SCROLL_LOCKS : 'SCROLL_LOCKS'; +SEARCH : 'SEARCH'; +SECONDARY : 'SECONDARY'; +SECONDARY_ONLY : 'SECONDARY_ONLY'; +SECONDARY_ROLE : 'SECONDARY_ROLE'; +SECONDS : 'SECONDS'; +SECRET : 'SECRET'; 
+SECURABLES : 'SECURABLES'; +SECURE : 'SECURE'; +SECURITY : 'SECURITY'; +SECURITY_LOG : 'SECURITY_LOG'; +SECURITYADMIN : 'SECURITYADMIN'; +SEED : 'SEED'; +SEEDING_MODE : 'SEEDING_MODE'; +SELECT : 'SELECT'; +SELF : 'SELF'; +SEMANTICKEYPHRASETABLE : 'SEMANTICKEYPHRASETABLE'; +SEMANTICSIMILARITYDETAILSTABLE : 'SEMANTICSIMILARITYDETAILSTABLE'; +SEMANTICSIMILARITYTABLE : 'SEMANTICSIMILARITYTABLE'; +SEMI_SENSITIVE : 'SEMI_SENSITIVE'; +SEND : 'SEND'; +SENT : 'SENT'; +SEQUENCE : 'SEQUENCE'; +SEQUENCE_NUMBER : 'SEQUENCE_NUMBER'; +SEQUENCES : 'SEQUENCES'; +SERIALIZABLE : 'SERIALIZABLE'; +SERVER : 'SERVER'; +SERVICE : 'SERVICE'; +SERVICE_BROKER : 'SERVICE_BROKER'; +SERVICE_NAME : 'SERVICE_NAME'; +SERVICEBROKER : 'SERVICEBROKER'; +SESSION : 'SESSION'; +SESSION_IDLE_TIMEOUT_MINS : 'SESSION_IDLE_TIMEOUT_MINS'; +SESSION_POLICY : 'SESSION_POLICY'; +SESSION_TIMEOUT : 'SESSION_TIMEOUT'; +SESSION_UI_IDLE_TIMEOUT_MINS : 'SESSION_UI_IDLE_TIMEOUT_MINS'; +SESSION_USER : 'SESSION_USER'; +SET : 'SET'; +SETS : 'SETS'; +SETTINGS : 'SETTINGS'; +SETUSER : 'SETUSER'; +SHARE : 'SHARE'; +SHARE_RESTRICTIONS : 'SHARE_RESTRICTIONS'; +SHARED : 'SHARED'; +SHARES : 'SHARES'; +SHOW : 'SHOW'; +SHOW_INITIAL_ROWS : 'SHOW_INITIAL_ROWS'; +SHOWCONTIG : 'SHOWCONTIG'; +SHOWPLAN : 'SHOWPLAN'; +SHOWPLAN_ALL : 'SHOWPLAN_ALL'; +SHOWPLAN_TEXT : 'SHOWPLAN_TEXT'; +SHOWPLAN_XML : 'SHOWPLAN_XML'; +SHRINKLOG : 'SHRINKLOG'; +SHUTDOWN : 'SHUTDOWN'; +SID : 'SID'; +SIGNATURE : 'SIGNATURE'; +SIMPLE : 'SIMPLE'; +SIMULATED_DATA_SHARING_CONSUMER : 'SIMULATED_DATA_SHARING_CONSUMER'; +SINGLE_USER : 'SINGLE_USER'; +SIZE : 'SIZE'; +SIZE_LIMIT : 'SIZE_LIMIT'; +SKIP_BLANK_LINES : 'SKIP_BLANK_LINES'; +SKIP_BYTE_ORDER_MARK : 'SKIP_BYTE_ORDER_MARK'; +SKIP_FILE : 'SKIP_FILE'; +SKIP_HEADER : 'SKIP_HEADER'; +SMALL : 'SMALL'; +SNAPPY : 'SNAPPY'; +SNAPPY_COMPRESSION : 'SNAPPY_COMPRESSION'; +SNAPSHOT : 'SNAPSHOT'; +SNOWFLAKE_FULL : 'SNOWFLAKE_FULL'; +SNOWFLAKE_SSE : 'SNOWFLAKE_SSE'; +SOFTNUMA : 'SOFTNUMA'; +SOME : 'SOME'; +SORT_IN_TEMPDB : 'SORT_IN_TEMPDB'; +SOURCE : 'SOURCE'; +SOURCE_COMPRESSION : 'SOURCE_COMPRESSION'; +SP_EXECUTESQL : 'SP_EXECUTESQL'; +SPARSE : 'SPARSE'; +SPATIAL_WINDOW_MAX_CELLS : 'SPATIAL_WINDOW_MAX_CELLS'; +SPECIFICATION : 'SPECIFICATION'; +SPLIT : 'SPLIT'; +SQLDUMPERFLAGS : 'SQLDUMPERFLAGS'; +SQLDUMPERPATH : 'SQLDUMPERPATH'; +SQLDUMPERTIMEOUT : 'SQLDUMPERTIMEOUT'; +SSO_LOGIN_PAGE : 'SSO_LOGIN_PAGE'; +STAGE : 'STAGE'; +STAGE_COPY_OPTIONS : 'STAGE_COPY_OPTIONS'; +STAGE_FILE_FORMAT : 'STAGE_FILE_FORMAT'; +STAGES : 'STAGES'; +STANDARD : 'STANDARD'; +STANDBY : 'STANDBY'; +START : 'START'; +START_DATE : 'START_DATE'; +START_TIMESTAMP : 'START_TIMESTAMP'; +STARTED : 'STARTED'; +STARTS : 'STARTS'; +STARTUP_STATE : 'STARTUP_STATE'; +STATE : 'STATE'; +STATEMENT : 'STATEMENT'; +STATEMENT_QUEUED_TIMEOUT_IN_SECONDS : 'STATEMENT_QUEUED_TIMEOUT_IN_SECONDS'; +STATEMENT_TIMEOUT_IN_SECONDS : 'STATEMENT_TIMEOUT_IN_SECONDS'; +STATIC : 'STATIC'; +STATISTICS : 'STATISTICS'; +STATISTICS_INCREMENTAL : 'STATISTICS_INCREMENTAL'; +STATISTICS_NORECOMPUTE : 'STATISTICS_NORECOMPUTE'; +STATS : 'STATS'; +STATS_STREAM : 'STATS_STREAM'; +STATUS : 'STATUS'; +STATUSONLY : 'STATUSONLY'; +STOP : 'STOP'; +STOP_ON_ERROR : 'STOP_ON_ERROR'; +STOPLIST : 'STOPLIST'; +STOPPED : 'STOPPED'; +STORAGE : 'STORAGE'; +STORAGE_ALLOWED_LOCATIONS : 'STORAGE_ALLOWED_LOCATIONS'; +STORAGE_AWS_OBJECT_ACL : 'STORAGE_AWS_OBJECT_ACL'; +STORAGE_AWS_ROLE_ARN : 'STORAGE_AWS_ROLE_ARN'; +STORAGE_BLOCKED_LOCATIONS : 'STORAGE_BLOCKED_LOCATIONS'; +STORAGE_INTEGRATION : 'STORAGE_INTEGRATION'; +STORAGE_PROVIDER : 
'STORAGE_PROVIDER'; +STREAM : 'STREAM'; +STREAMS : 'STREAMS'; +STRICT : 'STRICT'; +STRICT_JSON_OUTPUT : 'STRICT_JSON_OUTPUT'; +STRIP_NULL_VALUES : 'STRIP_NULL_VALUES'; +STRIP_OUTER_ARRAY : 'STRIP_OUTER_ARRAY'; +STRIP_OUTER_ELEMENT : 'STRIP_OUTER_ELEMENT'; +SUBJECT : 'SUBJECT'; +SUBSCRIBE : 'SUBSCRIBE'; +SUBSCRIPTION : 'SUBSCRIPTION'; +SUBSTRING : 'SUBSTRING'; +SUPPORTED : 'SUPPORTED'; +SUSPEND : 'SUSPEND'; +SUSPEND_IMMEDIATE : 'SUSPEND_IMMEDIATE'; +SUSPEND_TASK_AFTER_NUM_FAILURES : 'SUSPEND_TASK_AFTER_NUM_FAILURES'; +SUSPENDED : 'SUSPENDED'; +SWAP : 'SWAP'; +SWITCH : 'SWITCH'; +SYMMETRIC : 'SYMMETRIC'; +SYNC_PASSWORD : 'SYNC_PASSWORD'; +SYNCHRONOUS_COMMIT : 'SYNCHRONOUS_COMMIT'; +SYNONYM : 'SYNONYM'; +SYSADMIN : 'SYSADMIN'; +SYSTEM : 'SYSTEM'; +SYSTEM_USER : 'SYSTEM_USER'; +TABLE : 'TABLE'; +TABLE_FORMAT : 'TABLE_FORMAT'; +TABLEAU_DESKTOP : 'TABLEAU_DESKTOP'; +TABLEAU_SERVER : 'TABLEAU_SERVER'; +TABLERESULTS : 'TABLERESULTS'; +TABLES : 'TABLES'; +TABLESAMPLE : 'TABLESAMPLE'; +TABLOCK : 'TABLOCK'; +TABLOCKX : 'TABLOCKX'; +TABULAR : 'TABULAR'; +TAG : 'TAG'; +TAGS : 'TAGS'; +TAKE : 'TAKE'; +TAPE : 'TAPE'; +TARGET : 'TARGET'; +TARGET_LAG : 'TARGET_LAG'; +TARGET_RECOVERY_TIME : 'TARGET_RECOVERY_TIME'; +TASK : 'TASK'; +TASKS : 'TASKS'; +TB : 'TB'; +TCP : 'TCP'; +TEMP : 'TEMP'; +TEMPORARY : 'TEMPORARY'; +TERSE : 'TERSE'; +TEXTIMAGE_ON : 'TEXTIMAGE_ON'; +TEXTSIZE : 'TEXTSIZE'; +THEN : 'THEN'; +THROW : 'THROW'; +TIES : 'TIES'; +TIME_FORMAT : 'TIME_FORMAT'; +TIME_INPUT_FORMAT : 'TIME_INPUT_FORMAT'; +TIME_OUTPUT_FORMAT : 'TIME_OUTPUT_FORMAT'; +TIMEDIFF : 'TIMEDIFF'; +TIMEOUT : 'TIMEOUT'; +TIMER : 'TIMER'; +TIMESTAMP : 'TIMESTAMP'; +TIMESTAMP_DAY_IS_ALWAYS_24H : 'TIMESTAMP_DAY_IS_ALWAYS_24H'; +TIMESTAMP_FORMAT : 'TIMESTAMP_FORMAT'; +TIMESTAMP_INPUT_FORMAT : 'TIMESTAMP_INPUT_FORMAT'; +TIMESTAMP_LTZ_OUTPUT_FORMAT : 'TIMESTAMP_LTZ_OUTPUT_FORMAT'; +TIMESTAMP_NTZ_OUTPUT_FORMAT : 'TIMESTAMP_NTZ_OUTPUT_FORMAT'; +TIMESTAMP_OUTPUT_FORMAT : 'TIMESTAMP_OUTPUT_FORMAT'; +TIMESTAMP_TYPE_MAPPING : 'TIMESTAMP_TYPE_MAPPING'; +TIMESTAMP_TZ_OUTPUT_FORMAT : 'TIMESTAMP_TZ_OUTPUT_FORMAT'; +TIMEZONE : 'TIMEZONE'; +TO : 'TO'; +TOP : 'TOP'; +TORN_PAGE_DETECTION : 'TORN_PAGE_DETECTION'; +TOSTRING : 'TOSTRING'; +TRACE : 'TRACE'; +TRACK_CAUSALITY : 'TRACK_CAUSALITY'; +TRACKING : 'TRACKING'; +TRAN : 'TRAN'; +TRANSACTION : 'TRANSACTION'; +TRANSACTION_ABORT_ON_ERROR : 'TRANSACTION_ABORT_ON_ERROR'; +TRANSACTION_DEFAULT_ISOLATION_LEVEL : 'TRANSACTION_DEFAULT_ISOLATION_LEVEL'; +TRANSACTION_ID : 'TRANSACTION_ID'; +TRANSACTIONS : 'TRANSACTIONS'; +TRANSFER : 'TRANSFER'; +TRANSFORM_NOISE_WORDS : 'TRANSFORM_NOISE_WORDS'; +TRANSIENT : 'TRANSIENT'; +TRIGGER : 'TRIGGER'; +TRIGGERS : 'TRIGGERS'; +TRIM_SPACE : 'TRIM_SPACE'; +TRIPLE_DES : 'TRIPLE_DES'; +TRIPLE_DES_3KEY : 'TRIPLE_DES_3KEY'; +TRUE : 'TRUE'; +TRUNCATE : 'TRUNCATE'; +TRUNCATECOLUMNS : 'TRUNCATECOLUMNS'; +TRUSTWORTHY : 'TRUSTWORTHY'; +TRY : 'TRY'; +TRY_CAST : 'TRY_CAST'; +TSQL : 'TSQL'; +TWO_DIGIT_CENTURY_START : 'TWO_DIGIT_CENTURY_START'; +TWO_DIGIT_YEAR_CUTOFF : 'TWO_DIGIT_YEAR_CUTOFF'; +TYPE : 'TYPE'; +TYPE_WARNING : 'TYPE_WARNING'; +UNBOUNDED : 'UNBOUNDED'; +UNCHECKED : 'UNCHECKED'; +UNCOMMITTED : 'UNCOMMITTED'; +UNDROP : 'UNDROP'; +UNION : 'UNION'; +UNIQUE : 'UNIQUE'; +UNLIMITED : 'UNLIMITED'; +UNLOCK : 'UNLOCK'; +UNMASK : 'UNMASK'; +UNMATCHED : 'UNMATCHED'; +UNPIVOT : 'UNPIVOT'; +UNSAFE : 'UNSAFE'; +UNSET : 'UNSET'; +UNSUPPORTED_DDL_ACTION : 'UNSUPPORTED_DDL_ACTION'; +UOW : 'UOW'; +UPDATE : 'UPDATE'; +UPDLOCK : 'UPDLOCK'; +URL : 'URL'; +USAGE : 'USAGE'; +USE : 'USE'; 
+USE_ANY_ROLE : 'USE_ANY_ROLE'; +USE_CACHED_RESULT : 'USE_CACHED_RESULT'; +USED : 'USED'; +USER : 'USER'; +USER_SPECIFIED : 'USER_SPECIFIED'; +USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE : 'USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE'; +USER_TASK_TIMEOUT_MS : 'USER_TASK_TIMEOUT_MS'; +USERADMIN : 'USERADMIN'; +USERS : 'USERS'; +USING : 'USING'; +VALID_XML : 'VALID_XML'; +VALIDATE : 'VALIDATE'; +VALIDATION : 'VALIDATION'; +VALIDATION_MODE : 'VALIDATION_MODE'; +VALUE : 'VALUE'; +VALUES : 'VALUES'; +VAR : 'VAR'; +VARIABLES : 'VARIABLES'; +VARYING : 'VARYING'; +VERBOSELOGGING : 'VERBOSELOGGING'; +VERIFY_CLONEDB : 'VERIFY_CLONEDB'; +VERSION : 'VERSION'; +VIEW : 'VIEW'; +VIEW_METADATA : 'VIEW_METADATA'; +VIEWS : 'VIEWS'; +VISIBILITY : 'VISIBILITY'; +VOLATILE : 'VOLATILE'; +WAIT : 'WAIT'; +WAIT_AT_LOW_PRIORITY : 'WAIT_AT_LOW_PRIORITY'; +WAITFOR : 'WAITFOR'; +WAREHOUSE : 'WAREHOUSE'; +WAREHOUSE_SIZE : 'WAREHOUSE_SIZE'; +WAREHOUSE_TYPE : 'WAREHOUSE_TYPE'; +WAREHOUSES : 'WAREHOUSES'; +WEEK_OF_YEAR_POLICY : 'WEEK_OF_YEAR_POLICY'; +WEEK_START : 'WEEK_START'; +WEEKLY : 'WEEKLY'; +WELL_FORMED_XML : 'WELL_FORMED_XML'; +WHEN : 'WHEN'; +WHERE : 'WHERE'; +WHILE : 'WHILE'; +WINDOWS : 'WINDOWS'; +WITH : 'WITH'; +WITHIN : 'WITHIN'; +WITHOUT : 'WITHOUT'; +WITHOUT_ARRAY_WRAPPER : 'WITHOUT_ARRAY_WRAPPER'; +WITNESS : 'WITNESS'; +WORK : 'WORK'; +WORKLOAD : 'WORKLOAD'; +WRITE : 'WRITE'; +X4LARGE : 'X4LARGE'; +X5LARGE : 'X5LARGE'; +X6LARGE : 'X6LARGE'; +XACT_ABORT : 'XACT_ABORT'; +XLARGE : 'XLARGE'; +XLOCK : 'XLOCK'; +XML : 'XML'; +XML_COMPRESSION : 'XML_COMPRESSION'; +XMLDATA : 'XMLDATA'; +XMLNAMESPACES : 'XMLNAMESPACES'; +XMLSCHEMA : 'XMLSCHEMA'; +XSINIL : 'XSINIL'; +XSMALL : 'XSMALL'; +XXLARGE : 'XXLARGE'; +XXXLARGE : 'XXXLARGE'; +YEARLY : 'YEARLY'; +ZONE : 'ZONE'; +ZSTD : 'ZSTD'; + +// Common operators +ASSIGN: ':='; + +// Common symbols +DOLLAR_STRING: '$$' ('\\$' | '$' ~'$' | ~'$')*? '$$'; + +// Junky stuff for Snowflake - we will gradually get rid of this + +// TODO: Revisit these tokens +RETURN_N_ROWS : 'RETURN_' [0-9]+ '_ROWS'; +SKIP_FILE_N : 'SKIP_FILE_' [0-9]+; + +// TODO: Replace these long options with genericOption as per TSQL - many others in commonlex also. 
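+// e.g. Snowflake forms such as VALIDATION_MODE = RETURN_10_ROWS and ON_ERROR = SKIP_FILE_3 are lexed as single
+// tokens by the two parameterized rules above.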
+CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS : 'CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS'; +CLIENT_ENCRYPTION_KEY_SIZE : 'CLIENT_ENCRYPTION_KEY_SIZE'; +CLIENT_MEMORY_LIMIT : 'CLIENT_MEMORY_LIMIT'; +CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX : 'CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX'; +CLIENT_METADATA_USE_SESSION_DATABASE : 'CLIENT_METADATA_USE_SESSION_DATABASE'; +CLIENT_PREFETCH_THREADS : 'CLIENT_PREFETCH_THREADS'; +CLIENT_RESULT_CHUNK_SIZE : 'CLIENT_RESULT_CHUNK_SIZE'; +CLIENT_RESULT_COLUMN_CASE_INSENSITIVE : 'CLIENT_RESULT_COLUMN_CASE_INSENSITIVE'; +CLIENT_SESSION_KEEP_ALIVE : 'CLIENT_SESSION_KEEP_ALIVE'; +CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY : 'CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY'; +CLIENT_TIMESTAMP_TYPE_MAPPING : 'CLIENT_TIMESTAMP_TYPE_MAPPING'; +EXTERNAL_OAUTH : 'EXTERNAL_OAUTH'; +EXTERNAL_OAUTH_ADD_PRIVILEGED_ROLES_TO_BLOCKED_LIST: + 'EXTERNAL_OAUTH_ADD_PRIVILEGED_ROLES_TO_BLOCKED_LIST' +; +EXTERNAL_OAUTH_ALLOWED_ROLES_LIST : 'EXTERNAL_OAUTH_ALLOWED_ROLES_LIST'; +EXTERNAL_OAUTH_ANY_ROLE_MODE : 'EXTERNAL_OAUTH_ANY_ROLE_MODE'; +EXTERNAL_OAUTH_AUDIENCE_LIST : 'EXTERNAL_OAUTH_AUDIENCE_LIST'; +EXTERNAL_OAUTH_BLOCKED_ROLES_LIST : 'EXTERNAL_OAUTH_BLOCKED_ROLES_LIST'; +EXTERNAL_OAUTH_ISSUER : 'EXTERNAL_OAUTH_ISSUER'; +EXTERNAL_OAUTH_JWS_KEYS_URL : 'EXTERNAL_OAUTH_JWS_KEYS_URL'; +EXTERNAL_OAUTH_RSA_PUBLIC_KEY : 'EXTERNAL_OAUTH_RSA_PUBLIC_KEY'; +EXTERNAL_OAUTH_RSA_PUBLIC_KEY_2 : 'EXTERNAL_OAUTH_RSA_PUBLIC_KEY_2'; +EXTERNAL_OAUTH_SCOPE_DELIMITER : 'EXTERNAL_OAUTH_SCOPE_DELIMITER'; +EXTERNAL_OAUTH_SNOWFLAKE_USER_MAPPING_ATTRIBUTE: + 'EXTERNAL_OAUTH_SNOWFLAKE_USER_MAPPING_ATTRIBUTE' +; +EXTERNAL_OAUTH_TOKEN_USER_MAPPING_CLAIM : 'EXTERNAL_OAUTH_TOKEN_USER_MAPPING_CLAIM'; +EXTERNAL_OAUTH_TYPE : 'EXTERNAL_OAUTH_TYPE'; +EXTERNAL_STAGE : 'EXTERNAL_STAGE'; +REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_CREATION : 'REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_CREATION'; +REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_OPERATION: + 'REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_OPERATION' +; +// TOOD: Replace usage with genericOption(s) +REQUIRED_SYNCHRONIZED_SECONDARIES_TO_COMMIT: 'REQUIRED_SYNCHRONIZED_SECONDARIES_TO_COMMIT'; + +// Whitespace handling +WS: SPACE+ -> skip; + +// Comments +SQL_COMMENT : '/*' (SQL_COMMENT | .)*? '*/' -> channel(HIDDEN); +LINE_COMMENT : ('--' | '//') ~[\r\n]* -> channel(HIDDEN); + +// Identifiers +// Note that we consume any Jinja template reference that directly follows an ID +// so that the parser does not have to deal with it. Down the line we may have to +// spot template references in IDs and not mess with them if we are doing ID name trnsformations +// as if the ID is a composite of a template reference and a static string, we cannot use them +// to say build a schema for a table etc. +ID : IDFORM JINJA_REF_FORM?; +DOUBLE_QUOTE_ID : '"' ('""' | ~[\r\n"])* '"'; + +// Jinja Template Elements - note that composite elements whereby an identifier +// is constricted at runtime, fall out by eating the following text here. The post +// processor will only replace the template identifier. So: +// {{ ref("something") }}_mysuffix will be tokenized as the template reference +JINJA_REF: JINJA_REF_FORM ID?; + +// This lexer rule is needed so that any unknown character in the lexicon does not +// cause an incomprehensible error message from the lexer. This rule will allow the parser to issue +// something more meaningful and perform error recovery as the lexer CANNOT raise an error - it +// will alwys match at least one character using this catch-all rule. +// +// !IMPORTANT! 
- Always leave this as the last lexer rule, before the mode definitions
+BADCHAR: .;
+
+// -------------------------------------------------------
+// Fragments for use in other lexer rules
+fragment IDFORM : ( [A-Z_] | FullWidthLetter) ( [A-Z_#$@0-9] | FullWidthLetter)*;
+fragment JINJA_REF_FORM : '_!Jinja' [0-9]+ ID?;
+fragment LETTER : [A-Z_];
+fragment HEX_DIGIT : [0-9A-F];
+fragment DEC_DOT_DEC : [0-9]+ '.' [0-9]+ | [0-9]+ '.' | '.' [0-9]+;
+fragment FullWidthLetter options {
+ caseInsensitive = false;
+}:
+ '\u00c0' ..'\u00d6'
+ | '\u00d8' ..'\u00f6'
+ | '\u00f8' ..'\u00ff'
+ | '\u0100' ..'\u1fff'
+ | '\u2c00' ..'\u2fff'
+ | '\u3040' ..'\u318f'
+ | '\u3300' ..'\u337f'
+ | '\u3400' ..'\u3fff'
+ | '\u4e00' ..'\u9fff'
+ | '\ua000' ..'\ud7ff'
+ | '\uf900' ..'\ufaff'
+ | '\uff00' ..'\ufff0'
+; // | '\u20000'..'\u2FA1F'
+fragment HexDigit : [0-9a-f];
+fragment HexString : [A-Z0-9|.] [A-Z0-9+\-|.]*;
+fragment SPACE:
+ [ \t\r\n\u000c\u0085\u00a0\u1680\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200a\u202f\u205f\u3000]+
+; \ No newline at end of file diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/commonparse.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/commonparse.g4 new file mode 100644 index 0000000000..ab9a66cb82 --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/commonparse.g4 @@ -0,0 +1,785 @@
+// =================================================================================
+// Please reformat the grammar file before a change commit. See remorph/core/README.md
+// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md
+
+// $antlr-format alignColons hanging
+// $antlr-format columnLimit 150
+// $antlr-format alignSemicolons hanging
+// $antlr-format alignTrailingComments true
+// =================================================================================
+
+parser grammar commonparse;
+
+// The original TSQL grammar was basically trying to allow just about any reserved word as an identifier and ANTLR will
+// generate a parser that MOSTLY allows this, but it can create parsers where the lookahead is very large and slow.
+//
+// The use of reserved words has always been up in the air as the manuals say that you must escape them [ OUTER ] "INNER"
+// depending on configuration flags that are set. This is a bit of a mess and we only need to handle keywords that are
+// genuinely allowed (accidentally or not) in the various dialects.
+//
+// This list has been reduced by Jim based on TSQL and Snowflake defined reserved words, but probably needs to be checked
+// or adjusted if we see contrary examples within customer queries.
+//
+// In the worst case, we have to tell customers to escape one or two reserved words that the original query author
+// did not spot as violating the rules but the source dialect parser will accept anyway.
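+//
+// For example, SELECT state, target FROM t is accepted by both TSQL and Snowflake even though STATE and TARGET are
+// lexed as keyword tokens; listing them in this rule is what lets the parser still treat them as ordinary identifiers.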
+keyword + : ABORT + | ABORT_AFTER_WAIT + | ABSENT + | ABSOLUTE + | ACCENT_SENSITIVITY + | ACCESS + | ACCOUNTADMIN + | ACTION + | ACTIVATION + | ACTIVE + | ADDRESS + | ADMINISTER + | AES + | AES_128 + | AES_192 + | AES_256 + | AFFINITY + | AFTER + | AGGREGATE + | ALERT + | ALGORITHM + | ALL_CONSTRAINTS + | ALL_ERRORMSGS + | ALL_INDEXES + | ALL_LEVELS + | ALLOW_CONNECTIONS + | ALLOW_ENCRYPTED_VALUE_MODIFICATIONS + | ALLOW_MULTIPLE_EVENT_LOSS + | ALLOW_PAGE_LOCKS + | ALLOW_ROW_LOCKS + | ALLOW_SINGLE_EVENT_LOSS + | ALLOW_SNAPSHOT_ISOLATION + | ALLOWED + | ALWAYS + | ANONYMOUS + | ANSI_DEFAULTS + | ANSI_NULL_DEFAULT + | ANSI_NULL_DFLT_OFF + | ANSI_NULL_DFLT_ON + | ANSI_NULLS + | ANSI_PADDING + | ANSI_WARNINGS + | APPEND + | APPLICATION + | APPLICATION_LOG + | APPLY + | ARITHABORT + | ARITHIGNORE + | ARRAY + | ARRAY_AGG + | ASSEMBLY + | ASYMMETRIC + | ASYNCHRONOUS_COMMIT + | AT_KEYWORD + | AUDIT + | AUDIT_GUID + | AUTHENTICATE + | AUTHENTICATION + | AUTO + | AUTO_CLEANUP + | AUTO_CLOSE + | AUTO_CREATE_STATISTICS + | AUTO_DROP + | AUTO_SHRINK + | AUTO_UPDATE_STATISTICS + | AUTO_UPDATE_STATISTICS_ASYNC + | AUTOGROW_ALL_FILES + | AUTOGROW_SINGLE_FILE + | AUTOMATED_BACKUP_PREFERENCE + | AUTOMATIC + | AVAILABILITY + | AVAILABILITY_MODE + | BACKUP_CLONEDB + | BACKUP_PRIORITY + | BEFORE + | BEGIN_DIALOG + | BINARY + | BINDING + | BLOB_STORAGE + | BLOCK + | BLOCKERS + | BLOCKSIZE + | BODY + | BROKER + | BROKER_INSTANCE + | BUFFER + | BUFFERCOUNT + | BULK_LOGGED + | CACHE + | CALLED + | CALLER + | CAP_CPU_PERCENT + | CAST + | CATALOG + | CATCH + | CERTIFICATE + | CHANGE + | CHANGE_RETENTION + | CHANGE_TRACKING + | CHANGES + | CHANGETABLE + | CHARACTER + | CHECK_EXPIRATION + | CHECK_POLICY + | CHECKALLOC + | CHECKCATALOG + | CHECKCONSTRAINTS + | CHECKDB + | CHECKFILEGROUP + | CHECKSUM + | CHECKTABLE + | CLASSIFIER_FUNCTION + | CLEANTABLE + | CLEANUP + | CLONEDATABASE + | CLUSTER + | COLLATE + | COLLECTION + | COLUMN_ENCRYPTION_KEY + | COLUMN_MASTER_KEY + | COLUMNS + | COLUMNSTORE + | COLUMNSTORE_ARCHIVE + | COMMENT + | COMMITTED + | COMPATIBILITY_LEVEL + | COMPRESS_ALL_ROW_GROUPS + | COMPRESSION + | COMPRESSION_DELAY + | CONCAT + | CONCAT_NULL_YIELDS_NULL + | CONDITION + | CONFIGURATION + | CONNECT + | CONNECTION + | CONTAINS + | CONTAINMENT + | CONTENT + | CONTEXT + | CONTINUE_AFTER_ERROR + | CONTRACT + | CONTRACT_NAME + | CONTROL + | CONVERSATION + | COOKIE + | COPY_ONLY + | COPY_OPTIONS_ + | COUNTER + | CPU + | CREATE_NEW + | CREATION_DISPOSITION + | CREDENTIAL + | CRYPTOGRAPHIC + | CURSOR_CLOSE_ON_COMMIT + | CURSOR_DEFAULT + | CYCLE + | DATA + | DATA_COMPRESSION + | DATA_PURITY + | DATA_SOURCE + | DATABASE_MIRRORING + | DATASPACE + | DATE_CORRELATION_OPTIMIZATION + | DATE_FORMAT + | DAYS + | DB_CHAINING + | DB_FAILOVER + | DBREINDEX + | DDL + | DECRYPTION + | DEFAULT + | DEFAULT_DATABASE + | DEFAULT_DOUBLE_QUOTE + | DEFAULT_FULLTEXT_LANGUAGE + | DEFAULT_LANGUAGE + | DEFAULT_SCHEMA + | DEFINITION + | DELAY + | DELAYED_DURABILITY + | DELETED + | DELTA + | DENSE_RANK + | DEPENDENTS + | DES + | DESCRIPTION + | DESX + | DETERMINISTIC + | DHCP + | DIAGNOSTICS + | DIALOG + | DIFFERENTIAL + | DIRECTION + | DIRECTORY_NAME + | DISABLE + | DISABLE_AUTO_CONVERT + | DISABLE_BROKER + | DISABLED + | DISTRIBUTION + | DOCUMENT + | DOWNSTREAM + | DROP_EXISTING + | DROPCLEANBUFFERS + | DTC_SUPPORT + | DUMMY + | DYNAMIC + | EDITION + | ELEMENTS + | EMAIL + | EMERGENCY + | EMPTY + | ENABLE + | ENABLE_BROKER + | ENABLED + | ENCRYPTED + | ENCRYPTED_VALUE + | ENCRYPTION + | ENCRYPTION_TYPE + | END + | ENDPOINT + | 
ENDPOINT_URL + | ERROR + | ERROR_BROKER_CONVERSATIONS + | ESTIMATEONLY + | EVENT + | EVENT_RETENTION_MODE + | EXCHANGE + | EXCLUSIVE + | EXECUTABLE + | EXECUTABLE_FILE + | EXPIREDATE + | EXPIRY_DATE + | EXPLICIT + | EXTENDED_LOGICAL_CHECKS + | EXTENSION + | EXTERNAL_ACCESS + | FAIL_OPERATION + | FAILOVER + | FAILOVER_MODE + | FAILURE + | FAILURE_CONDITION_LEVEL + | FAILURECONDITIONLEVEL + | FAN_IN + | FAST_FORWARD + | FILE_FORMAT + | FILE_SNAPSHOT + | FILEGROUP + | FILEGROWTH + | FILENAME + | FILEPATH + | FILESTREAM + | FILESTREAM_ON + | FILTER + | FIRST + | FIRST_NAME + | FLATTEN + | FLOOR + | FMTONLY + | FOLLOWING + | FORCE + | FORCE_FAILOVER_ALLOW_DATA_LOSS + | FORCE_SERVICE_ALLOW_DATA_LOSS + | FORCEPLAN + | FORCESCAN + | FORCESEEK + | FORMAT + | FORWARD_ONLY + | FREE + | FREQUENCY + | FULLSCAN + | FULLTEXT + | GB + | GENERATED + | GET + | GETROOT + | GLOBAL + | GO + | GOVERNOR + | GROUP_MAX_REQUESTS + | GROUPING + | HADR + | HASH + | HASHED + | HEALTH_CHECK_TIMEOUT + | HEALTHCHECKTIMEOUT + | HEAP + | HIDDEN_KEYWORD + | HIERARCHYID + | HIGH + | HONOR_BROKER_PRIORITY + | HOURS + | IDENTIFIER + | IDENTITY_VALUE + | IGNORE_CONSTRAINTS + | IGNORE_DUP_KEY + | IGNORE_REPLICATED_TABLE_CACHE + | IGNORE_TRIGGERS + | IIF + | IMMEDIATE + | IMPERSONATE + | IMPLICIT_TRANSACTIONS + | IMPORTANCE + | INCLUDE + | INCLUDE_NULL_VALUES + | INCREMENT + | INCREMENTAL + | INDEX + | INFINITE + | INIT + | INITIATOR + | INPUT + | INSENSITIVE + | INSERTED + | INSTEAD + | INTERVAL + | IO + | IP + | ISOLATION + | JOB + | JSON + | JSON_ARRAY + | JSON_OBJECT + | KB + | KEEPDEFAULTS + | KEEPIDENTITY + | KERBEROS + | KEY + | KEY_PATH + | KEY_SOURCE + | KEY_STORE_PROVIDER_NAME + | KEYS + | KEYSET + | KWSKIP + | LANGUAGE + | LAST + | LAST_NAME + | LAST_QUERY_ID + | LEAD + | LENGTH + | LEVEL + | LIBRARY + | LIFETIME + | LINKED + | LINUX + | LIST + | LISTAGG + | LISTENER + | LISTENER_IP + | LISTENER_PORT + | LISTENER_URL + | LOB_COMPACTION + | LOCAL + | LOCAL_SERVICE_NAME + | LOCATION + | LOCK + | LOCK_ESCALATION + | LOGIN + | LOOP + | LOW + | MANUAL + | MARK + | MASK + | MASKED + | MASTER + | MATCHED + | MATCHES + | MATERIALIZED + | MAX + | MAX_CONCURRENCY_LEVEL + | MAX_CPU_PERCENT + | MAX_DISPATCH_LATENCY + | MAX_DOP + | MAX_DURATION + | MAX_EVENT_SIZE + | MAX_FILES + | MAX_IOPS_PER_VOLUME + | MAX_MEMORY + | MAX_MEMORY_PERCENT + | MAX_OUTSTANDING_IO_PER_VOLUME + | MAX_PROCESSES + | MAX_QUEUE_READERS + | MAX_ROLLOVER_FILES + | MAX_SIZE + | MAXSIZE + | MAXTRANSFER + | MAXVALUE + | MB + | MEDIADESCRIPTION + | MEDIANAME + | MEDIUM + | MEMBER + | MEMORY_OPTIMIZED_DATA + | MEMORY_PARTITION_MODE + | MESSAGE + | MESSAGE_FORWARD_SIZE + | MESSAGE_FORWARDING + | MIN_CPU_PERCENT + | MIN_IOPS_PER_VOLUME + | MIN_MEMORY_PERCENT + | MINUTES + | MINVALUE + | MIRROR + | MIRROR_ADDRESS + | MIXED_PAGE_ALLOCATION + | MODE + | MODIFY + | MOVE + | MULTI_USER + | MUST_CHANGE + | NAME + | NESTED_TRIGGERS + | NETWORK + | NEW_ACCOUNT + | NEW_BROKER + | NEW_PASSWORD + | NEWNAME + | NEXT + | NEXTVAL + | NO + | NO_CHECKSUM + | NO_COMPRESSION + | NO_EVENT_LOSS + | NO_INFOMSGS + | NO_QUERYSTORE + | NO_STATISTICS + | NO_TRUNCATE + | NOCOUNT + | NODES + | NOEXEC + | NOEXPAND + | NOFORMAT + | NOINDEX + | NOINIT + | NOLOCK + | NON_TRANSACTED_ACCESS + | NONE + | NOORDER + | NORECOMPUTE + | NORECOVERY + | NOREWIND + | NOSKIP + | NOTIFICATION + | NOTIFICATIONS + | NOUNLOAD + | NTILE + | NTLM + | NUMANODE + | NUMERIC_ROUNDABORT + | OBJECT + | OFFLINE + | OFFSET + | OLD_ACCOUNT + | OLD_PASSWORD + | ON_FAILURE + | ONLINE + | ONLY + | OPEN_EXISTING + | OPENJSON + 
| OPERATIONS + | OPTIMISTIC + | ORDER + | ORGADMIN + | OUT + | OUTBOUND + | OUTPUT + | OVERRIDE + | OWNER + | OWNERSHIP + | PAD_INDEX + | PAGE + | PAGE_VERIFY + | PAGECOUNT + | PAGLOCK + | PARAM_NODE + | PARAMETERIZATION + | PARSEONLY + | PARTIAL + | PARTITION + | PARTITIONS + | PARTNER + | PASSWORD + | PATH + | PATTERN + | PAUSE + | PDW_SHOWSPACEUSED + | PER_CPU + | PER_DB + | PER_NODE + | PERMISSION_SET + | PERSIST_SAMPLE_PERCENT + | PERSISTED + | PHYSICAL_ONLY + | PLATFORM + | POISON_MESSAGE_HANDLING + | POLICY + | POOL + | PORT + | PRECEDING + | PREDICATE + | PRIMARY_ROLE + | PRIOR + | PRIORITY + | PRIORITY_LEVEL + | PRIVATE + | PRIVATE_KEY + | PRIVILEGES + | PROCCACHE + | PROCEDURE_NAME + | PROCESS + | PROFILE + | PROPERTY + | PROVIDER + | PROVIDER_KEY_NAME + | PUBLIC + | QUERY + | QUEUE + | QUEUE_DELAY + | QUOTED_IDENTIFIER + | RANDOMIZED + | RANGE + | RANK + | RC2 + | RC4 + | RC4_128 + | READ_COMMITTED_SNAPSHOT + | READ_ONLY + | READ_ONLY_ROUTING_LIST + | READ_WRITE + | READ_WRITE_FILEGROUPS + | READCOMMITTED + | READCOMMITTEDLOCK + | READONLY + | READPAST + | READUNCOMMITTED + | READWRITE + | REBUILD + | RECEIVE + | RECOVERY + | RECURSIVE + | RECURSIVE_TRIGGERS + | REGENERATE + | REGION + | RELATED_CONVERSATION + | RELATED_CONVERSATION_GROUP + | RELATIVE + | REMOTE + | REMOTE_PROC_TRANSACTIONS + | REMOTE_SERVICE_NAME + | REMOVE + | REORGANIZE + | REPAIR_ALLOW_DATA_LOSS + | REPAIR_FAST + | REPAIR_REBUILD + | REPEATABLE + | REPEATABLEREAD + | REPLACE + | REPLICA + | REPLICATE + | REQUIRED + | RESAMPLE + | RESERVE_DISK_SPACE + | RESET + | RESOURCE + | RESOURCE_MANAGER_LOCATION + | RESOURCES + | RESPECT + | RESTART + | RESTRICT + | RESTRICTED_USER + | RESULT + | RESUMABLE + | RESUME + | RETAINDAYS + | RETENTION + | RETURNS + | REWIND + | RLIKE + | ROLE + | ROOT + | ROUND_ROBIN + | ROUTE + | ROW + | ROWGUID + | ROWLOCK + | ROWS + | RSA_512 + | RSA_1024 + | RSA_2048 + | RSA_3072 + | RSA_4096 + | SAFE + | SAFETY + | SAMPLE + | SCHEDULER + | SCHEMABINDING + | SCHEME + | SCOPED + | SCRIPT + | SCROLL + | SCROLL_LOCKS + | SEARCH + | SECONDARY + | SECONDARY_ONLY + | SECONDARY_ROLE + | SECONDS + | SECRET + | SECURABLES + | SECURITY + | SECURITY_LOG + | SECURITYADMIN + | SEEDING_MODE + | SELF + | SEMI_SENSITIVE + | SEND + | SENT + | SEQUENCE + | SEQUENCE_NUMBER + | SERIALIZABLE + | SERVER + | SERVICE + | SERVICE_BROKER + | SERVICE_NAME + | SERVICEBROKER + | SESSION + | SESSION_TIMEOUT + | SETTINGS + | SHARE + | SHARED + | SHARES + | SHOWCONTIG + | SHOWPLAN + | SHOWPLAN_ALL + | SHOWPLAN_TEXT + | SHOWPLAN_XML + | SHRINKLOG + | SID + | SIGNATURE + | SINGLE_USER + | SIZE + | SNAPSHOT + | SOFTNUMA + | SORT_IN_TEMPDB + | SOURCE + | SP_EXECUTESQL + | SPARSE + | SPATIAL_WINDOW_MAX_CELLS + | SPECIFICATION + | SPLIT + | SQLDUMPERFLAGS + | SQLDUMPERPATH + | SQLDUMPERTIMEOUT + | STAGE + | STANDBY + | START + | START_DATE + | STARTED + | STARTUP_STATE + | STATE + | STATIC + | STATISTICS_INCREMENTAL + | STATISTICS_NORECOMPUTE + | STATS + | STATS_STREAM + | STATUS + | STATUSONLY + | STOP + | STOP_ON_ERROR + | STOPLIST + | STOPPED + | SUBJECT + | SUBSCRIBE + | SUBSCRIPTION + | SUBSTRING + | SUPPORTED + | SUSPEND + | SWITCH + | SYMMETRIC + | SYNCHRONOUS_COMMIT + | SYNONYM + | SYSADMIN + | SYSTEM + | TABLE + | TABLERESULTS + | TABLOCK + | TABLOCKX + | TAG + | TAGS + | TAKE + | TAPE + | TARGET + | TARGET_LAG + | TARGET_RECOVERY_TIME + | TB + | TCP + | TEMP + | TEXTIMAGE_ON + | THROW + | TIES + | TIMEOUT + | TIMER + | TIMESTAMP + | TIMEZONE + | TORN_PAGE_DETECTION + | TOSTRING + | TRACE + | TRACK_CAUSALITY + | 
TRACKING + | TRANSACTION_ID + | TRANSFER + | TRANSFORM_NOISE_WORDS + | TRIPLE_DES + | TRIPLE_DES_3KEY + | TRUSTWORTHY + | TRY + | TRY_CAST + | TSQL + | TWO_DIGIT_YEAR_CUTOFF + | TYPE + | TYPE_WARNING + | UNBOUNDED + | UNCHECKED + | UNCOMMITTED + | UNLIMITED + | UNLOCK + | UNMASK + | UNSAFE + | UOW + | UPDLOCK + | URL + | USED + | USERADMIN + | VALID_XML + | VALIDATION + | VALUE + | VAR + | VERBOSELOGGING + | VERIFY_CLONEDB + | VERSION + | VIEW_METADATA + | VISIBILITY + | WAIT + | WAIT_AT_LOW_PRIORITY + | WAREHOUSE + | WAREHOUSE_TYPE + | WELL_FORMED_XML + | WINDOWS + | WITHOUT + | WITHOUT_ARRAY_WRAPPER + | WITNESS + | WORK + | WORKLOAD + | XACT_ABORT + | XLOCK + | XML + | XML_COMPRESSION + | XMLDATA + | XMLNAMESPACES + | XMLSCHEMA + | XSINIL + | ZONE + ; \ No newline at end of file diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/jinja.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/jinja.g4 new file mode 100644 index 0000000000..e2d4b20bac --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/jinja.g4 @@ -0,0 +1,18 @@ +// ================================================================================= +// Please reformat the grammr file before a change commit. See remorph/core/README.md +// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md + +// $antlr-format alignColons hanging +// $antlr-format columnLimit 150 +// $antlr-format alignSemicolons hanging +// $antlr-format alignTrailingComments true +// ================================================================================= + +parser grammar jinja; + +// Jinja template elements can occur anywhere in the text and so this rule can be used strategically to allow +// what otherwise would be a syntax error to be parsed as a Jinja template. As it is essentially one +// token, ANTLR alts will predict easilly. For isntance on lists of expressions, we need to allow +// JINJA templates without separating COMMAs etc. +jinjaTemplate: JINJA_REF+ + ; \ No newline at end of file diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/procedure.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/procedure.g4 new file mode 100644 index 0000000000..2a96bc30c9 --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/lib/procedure.g4 @@ -0,0 +1,121 @@ +/* +Universal grammar for SQL stored procedure declarations + */ + + + + +parser grammar procedure; + +// Snowflake specific +// TODO: Reconcile with TSQL and SQL/PSM + +alterProcedure: + ALTER PROCEDURE (IF EXISTS)? id LPAREN dataTypeList? RPAREN RENAME TO id + | ALTER PROCEDURE (IF EXISTS)? id LPAREN dataTypeList? RPAREN SET ( + COMMENT EQ string + ) + | ALTER PROCEDURE (IF EXISTS)? id LPAREN dataTypeList? RPAREN UNSET COMMENT + | ALTER PROCEDURE (IF EXISTS)? id LPAREN dataTypeList? RPAREN EXECUTE AS ( + CALLER + | OWNER + ) +; + +createProcedure: + CREATE (OR REPLACE)? PROCEDURE dotIdentifier LPAREN ( + procArgDecl (COMMA procArgDecl)* + )? RPAREN RETURNS (dataType | table) (NOT? NULL)? LANGUAGE id ( + CALLED ON NULL INPUT + | RETURNS NULL ON NULL INPUT + | STRICT + )? (VOLATILE | IMMUTABLE)? // Note: VOLATILE and IMMUTABLE are deprecated. + (COMMENT EQ string)? executeAs? AS procedureDefinition + | CREATE (OR REPLACE)? SECURE? PROCEDURE dotIdentifier LPAREN ( + procArgDecl (COMMA procArgDecl)* + )? RPAREN RETURNS dataType (NOT? NULL)? LANGUAGE id ( + CALLED ON NULL INPUT + | RETURNS NULL ON NULL INPUT + | STRICT + )? 
(VOLATILE | IMMUTABLE)? // Note: VOLATILE and IMMUTABLE are deprecated. + (COMMENT EQ string)? executeAs? AS procedureDefinition + | CREATE (OR REPLACE)? SECURE? PROCEDURE dotIdentifier LPAREN ( + procArgDecl (COMMA procArgDecl)* + )? RPAREN RETURNS (dataType (NOT? NULL)? | table) LANGUAGE id RUNTIME_VERSION EQ string ( + IMPORTS EQ LPAREN stringList RPAREN + )? PACKAGES EQ LPAREN stringList RPAREN HANDLER EQ string + // ( CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT )? + // ( VOLATILE | IMMUTABLE )? // Note: VOLATILE and IMMUTABLE are deprecated. + (COMMENT EQ string)? executeAs? AS procedureDefinition +; + +procArgDecl: id dataType (DEFAULT expr)?; + +dropProcedure: + DROP PROCEDURE (IF EXISTS)? dotIdentifier ( + COMMA dotIdentifier + )* (LPAREN ( dataType (COMMA dataType)*)? RPAREN)? SEMI? +; + +procedureDefinition: + DOLLAR_STRING + | declareCommand? BEGIN procStatement* END SEMI? +; + +assign: + LET? id (dataType | RESULTSET)? (ASSIGN | DEFAULT) expr SEMI # assignVariable + | LET? id CURSOR FOR (selectStatement | id) SEMI # assignCursor +; + +// TSQL + +createOrAlterProcedure: ( + (CREATE (OR (ALTER | REPLACE))?) + | ALTER + ) PROCEDURE dotIdentifier (SEMI INT)? ( + LPAREN? procedureParam (COMMA procedureParam)* RPAREN? + )? (WITH procedureOption (COMMA procedureOption)*)? ( + FOR REPLICATION + )? AS (EXTERNAL NAME dotIdentifier | sqlClauses*) +; + +procedureParamDefaultValue: + NULL + | DEFAULT + | constant + | LOCAL_ID +; + +procedureParam: + LOCAL_ID AS? (id DOT)? dataType VARYING? ( + EQ procedureParamDefaultValue + )? (OUT | OUTPUT | READONLY)? +; + +procedureOption: executeAs | genericOption; + +procStatement: + declareCommand + | assign + | returnStatement + | sqlClauses +; + +// ----------------------------------------------------------- +// SQL/PSM is for Spark, GoogleSQL, mySQL, Teradata, DB2 +// The SQL/PSM standard is extended here to cover TSQL, Snowflake and other dialects in the future + +returnStatement: RETURN expr SEMI; + +// see https://docs.snowflake.com/en/sql-reference/snowflake-scripting/declare +declareCommand: DECLARE declareElement+; + +declareElement: + id dataType SEMI # declareSimple + | id dataType (DEFAULT | EQ) expr SEMI # declareWithDefault + | id CURSOR FOR expr SEMI # declareCursorElement + | id RESULTSET ((ASSIGN | DEFAULT) expr)? SEMI # declareResultSet + | id EXCEPTION LPAREN INT COMMA string RPAREN SEMI # declareException +; + +// TODO: Complete definition of SQL/PSM rules \ No newline at end of file diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/preprocessor/DBTPreprocessorLexer.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/preprocessor/DBTPreprocessorLexer.g4 new file mode 100644 index 0000000000..d692a5bca4 --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/preprocessor/DBTPreprocessorLexer.g4 @@ -0,0 +1,240 @@ +// ================================================================================= +// Please reformat the grammr file before a change commit. 
See remorph/core/README.md +// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md + +// $antlr-format alignColons hanging +// $antlr-format columnLimit 150 +// $antlr-format alignSemicolons hanging +// $antlr-format alignTrailingComments true +// ================================================================================= + +lexer grammar DBTPreprocessorLexer; + +tokens { + STRING +} + +options { + caseInsensitive = true; +} + +@members { + /** + * Defines the configuration for the preprocessor, such as Jinja templating delimiters and + * any DBT parameters that are relevant to us. + */ + public class Config { + private String exprStart; + private String exprEnd; + private String statStart; + private String statEnd; + private String commentStart; + private String commentEnd; + private String lineStatStart; + + // Standard defaults for Jinja templating + public Config() { + this("{{", "}}", "{%", "%}", "{#", "#}", "#"); + } + + public Config(String exprStart, String exprEnd, String statStart, String statEnd, String commentStart, String commentEnd, String lineStatStart) { + this.exprStart = exprStart; + this.exprEnd = exprEnd; + this.statStart = statStart; + this.statEnd = statEnd; + this.commentStart = commentStart; + this.commentEnd = commentEnd; + this.lineStatStart = lineStatStart; + } + + // Getters + public String exprStart() { + return exprStart; + } + + public String exprEnd() { + return exprEnd; + } + + public String statStart() { + return statStart; + } + + public String statEnd() { + return statEnd; + } + + public String commentStart() { + return commentStart; + } + + public String commentEnd() { + return commentEnd; + } + + public String lineStatStart() { + return lineStatStart; + } + } + + public Config config = new Config(); + + /** + * Our template lexer rules only consume a single character, even when the sequence is longer than + * one character. So we we need to advance the input past the matched sequence. 
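+ * For example, the STATEMENT rule below matches only the '{' of a '{%' opener; this method then advances the
+ * input past the remaining '%' and any trailing '-' (as in '{%-').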
+ */
+ private void scanPast(String str) {
+ int index = _input.index();
+
+ // If there was a preceding hyphen such as -%}, we need to move past that as well
+ if (_input.LA(-1) == '-') {
+ index++;
+ }
+ _input.seek(index + str.length() - 1);
+
+ // If there is a trailing hyphen such as {%- then we need to scan past that as well
+ if (_input.LA(1) == '-') {
+ _input.consume();
+ }
+ }
+
+ /**
+ * Called when a single character is matched to see if it and the next sequence of characters
+ * match the current configured characters that start a Jinja statement template
+ */
+ private boolean matchAndConsume(String str) {
+ // Check if the first character matches otherwise we would accidentally accept anything for
+ // a single character marker such as LineStatStart
+ if (str.charAt(0) != _input.LA(-1)) {
+ return false;
+ }
+ for (int i = 1; i < str.length(); i++) {
+ if (str.charAt(i) != _input.LA(1)) {
+ return false;
+ }
+ // Move to next character
+ _input.consume();
+ }
+
+ // All characters matched, return true
+ return true;
+ }
+
+ private boolean isStatement() {
+ if( matchAndConsume(config.statStart())) {
+ // There may be a trailing hyphen that is part of the statement start
+ if (_input.LA(1) == '-') {
+ _input.consume();
+ }
+ return true;
+ }
+ return false;
+ }
+
+ private boolean isExpression() {
+ if (matchAndConsume(config.exprStart())) {
+ // There may be a trailing hyphen that is part of the expression start
+ if (_input.LA(1) == '-') {
+ _input.consume();
+ }
+ return true;
+ }
+ return false;
+ }
+
+ private boolean isComment() {
+ return matchAndConsume(config.commentStart());
+ }
+
+ // Note that this is not quite correct yet as we must check that this is the
+ // first non-whitespace character on the line as well.
+ private boolean isLineStat() {
+ return matchAndConsume(config.lineStatStart());
+ }
+
+ private boolean isStatementEnd() {
+ // There may be a preceding hyphen that is part of the statement end
+ int index = _input.index();
+ if (_input.LA(-1) == '-') {
+ _input.consume();
+ }
+ if (matchAndConsume(config.statEnd())) {
+ return true;
+ }
+ // Return to the start of the statement, any hyphen was just that and
+ // not part of the token
+ _input.seek(index);
+ return false;
+ }
+
+ private boolean isExpresionEnd() {
+ // There may be a preceding hyphen that is part of the expression end
+ int index = _input.index();
+ if (_input.LA(-1) == '-') {
+ _input.consume();
+ }
+ if (matchAndConsume(config.exprEnd())) {
+ return true;
+ }
+ // Return to the start of the expression, any hyphen was just that and
+ // not part of the token
+ _input.seek(index);
+ return false;
+ }
+
+ private boolean isCommentEnd() {
+ return matchAndConsume(config.commentEnd());
+ }
+}
+
+STATEMENT: . { scanPast(config.statStart); } { isStatement() }? -> pushMode(statementMode )
+ ;
+EXPRESSION: . { scanPast(config.exprStart); } { isExpression() }? -> pushMode(expressionMode )
+ ;
+COMMENT: . { scanPast(config.commentStart); } { isComment() }? -> pushMode(commentMode )
+ ;
+LINESTAT: . { scanPast(config.lineStatStart); } { isLineStat() }? -> pushMode(lineStatMode )
+ ;
+
+// We track whitespace as it can influence how a template following or not following
+// whitespace is replaced in the processed output. Similarly, we want to know if whitespace
+// followed the template reference or not
+WS: [ \t]+
+ ;
+
+C: .
+ ;
+
+mode statementMode;
+
+STATEMENT_STRING: '\'' ('\\' . | ~["\n])* '\'' -> type(STRING)
+ ;
+STATEMENT_END: . { scanPast(config.statEnd); } { isStatementEnd() }?
-> popMode + ; +STATMENT_BIT: . -> type(C) + ; + +mode expressionMode; + +EXPRESSION_STRING: '\'' ('\\' . | ~['\n])* '\'' -> type(STRING) + ; +EXPRESSION_END: . { scanPast(config.exprEnd); } { isExpresionEnd() }? -> popMode + ; +EXPRESSION_BIT: . -> type(C) + ; + +mode commentMode; +COMMENT_STRING: '\'' ('\\' . | ~['\n])* '\'' -> type(STRING) + ; +COMMENT_END: . { scanPast(config.commentEnd); } { isCommentEnd() }? -> popMode + ; +COMMENT_BIT: . -> type(C) + ; + +mode lineStatMode; +LINESTAT_STRING: '\'' ('\\' . | ~['\n])* '\'' -> type(STRING) + ; +LINESTAT_END: '\r'? '\n' -> popMode + ; +LINESTAT_BIT: . -> type(C) + ; \ No newline at end of file diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/.gitignore b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/.gitignore new file mode 100644 index 0000000000..4f62b849d5 --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/.gitignore @@ -0,0 +1 @@ +gen diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeLexer.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeLexer.g4 new file mode 100644 index 0000000000..ff94fa69fa --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeLexer.g4 @@ -0,0 +1,171 @@ +/* +Snowflake Database grammar. +The MIT License (MIT). + +Copyright (c) 2022, Michał Lorek. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +// ================================================================================= +// Please reformat the grammr file before a change commit. 
See remorph/core/README.md +// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md + +// $antlr-format alignTrailingComments true +// $antlr-format columnLimit 150 +// $antlr-format maxEmptyLinesToKeep 1 +// $antlr-format reflowComments false +// $antlr-format useTab false +// $antlr-format allowShortRulesOnASingleLine true +// $antlr-format allowShortBlocksOnASingleLine true +// $antlr-format minEmptyLines 0 +// $antlr-format alignSemicolons ownLine +// $antlr-format alignColons trailing +// $antlr-format singleLineOverrulesHangingColon true +// $antlr-format alignLexerCommands true +// $antlr-format alignLabels true +// $antlr-format alignTrailers true +// ================================================================================= +lexer grammar SnowflakeLexer; + +import commonlex; + +tokens { + STRING_CONTENT +} + +options { + caseInsensitive = true; +} + +@members { + private static int TSQL_DIALECT = 1; + private static int SNOWFLAKE_DIALECT = 2; + private static int dialect = SNOWFLAKE_DIALECT; +} + +INT : [0-9]+; +FLOAT : DEC_DOT_DEC; +REAL : (INT | DEC_DOT_DEC) 'E' [+-]? [0-9]+; + +BANG : '!'; +ARROW : '->'; +ASSOC : '=>'; + +NE : '!='; +LTGT : '<>'; +EQ : '='; +GT : '>'; +GE : '>='; +LT : '<'; +LE : '<='; + +PIPE_PIPE : '||'; +DOT : '.'; +AT : '@'; +DOLLAR : '$'; +LPAREN : '('; +RPAREN : ')'; +LSB : '['; +RSB : ']'; +LCB : '{'; +RCB : '}'; +COMMA : ','; +SEMI : ';'; +COLON : ':'; +COLON_COLON : '::'; +STAR : '*'; +DIVIDE : '/'; +MODULE : '%'; +PLUS : '+'; +MINUS : '-'; +TILDA : '~'; +AMP : '&'; + +// A question mark can be used as a placeholder for a prepared statement that will use binding. +PARAM: '?'; + +SQLCOMMAND: + '!' SPACE? ( + 'abort' + | 'connect' + | 'define' + | 'edit' + | 'exit' + | 'help' + | 'options' + | 'pause' + | 'print' + | 'queries' + | 'quit' + | 'rehash' + | 'result' + | 'set' + | 'source' + | 'spool' + | 'system' + | 'variables' + ) ~[\r\n]* +; + +// Parameters +LOCAL_ID: DOLLAR ID; + +STRING_START: '\'' -> pushMode(stringMode); + +// ================================================================================================ +// LEXICAL MODES +// +// Lexical modes are used to allow the lexer to return different token types than the main lexer +// and are triggered by a main lexer rule matching a specific token. The mode is ended by matching +// a specific lexical sequence in the input stream. Note that this is a lexical trigger only and is +// not influenced by the parser state as the paresr does NOT direct the lexer: +// +// 1) The lexer runs against the entire input sequence and returns tokens to the parser. +// 2) THEN the parser uses the tokens to build the parse tree - it cannot therefore influence the +// lexer in any way. + +// In string mode we are separating out normal string literals from defined variable +// references, so that they can be translated from Snowflakey syntax to Databricks SQL syntax. +// This mode is trigered when we hit a single quote in the lexer and ends when we hit the +// terminating single quote minus the usual escape character processing. +mode stringMode; + +// An element that is a variable reference can be &{variable} or just &variable. They are +// separated out in case there is any difference needed in translation/generation. A single +// & is placed in a string by using &&. + +// We exit the stringMode when we see the terminating single quote. 
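+// For example, the literal 'Hello &name' is lexed as STRING_START, STRING_CONTENT (Hello ), VAR_SIMPLE (&name) and
+// STRING_END, while && inside a literal is lexed as STRING_AMPAMP and stands for a single literal ampersand.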
+//
+STRING_END: '\'' -> popMode;
+
+STRING_AMP : '&' -> type(STRING_CONTENT);
+STRING_AMPAMP : '&&';
+
+// Note that snowflake also allows $var, and :var
+// if they are allowed in literal strings then they can be added here.
+//
+VAR_SIMPLE : '&' [A-Z_] [A-Z0-9_]*;
+VAR_COMPLEX : '&{' [A-Z_] [A-Z0-9_]* '}';
+
+// TODO: Do we also need \xHH hex and \999 octal escapes?
+STRING_UNICODE : '\\' 'u' HexDigit HexDigit HexDigit HexDigit;
+STRING_ESCAPE : '\\' .;
+STRING_SQUOTE : '\'\'';
+
+// Anything that is not a variable reference is just a normal piece of text.
+STRING_BIT: ~['\\&]+ -> type(STRING_CONTENT); \ No newline at end of file diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParser.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParser.g4 new file mode 100644 index 0000000000..038ee586ef --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParser.g4 @@ -0,0 +1,3407 @@
+/*
+Snowflake Database grammar.
+The MIT License (MIT).
+
+Copyright (c) 2022, Michał Lorek.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+*/
+
+// =================================================================================
+// Please reformat the grammar file before a change commit. See remorph/core/README.md
+// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md
+// $antlr-format alignColons hanging
+// $antlr-format columnLimit 150
+// $antlr-format alignSemicolons hanging
+// $antlr-format alignTrailingComments true
+// =================================================================================
+parser grammar SnowflakeParser;
+
+import procedure, commonparse, jinja;
+
+options {
+ tokenVocab = SnowflakeLexer;
+}
+
+// ============== Dialect compatibility rules ==============
+// The following rules provide substitutes for grammar rules referenced in the procedure.g4 grammar, that
+// we do not have real equivalents for in this grammar.
+// Over time, as we homogenize more and more of the dialect grammars, these rules will be removed.
+
+// TODO: We will move genericOption into a parsercommon.g4 and reference it in all dialects for common option processing
+// For now, this rule is not used within the Snowflake rules, but it will be
+genericOption: ID EQ (string | INT | trueFalse | jsonLiteral)
+ ;
+
+// ======================================================
+
+snowflakeFile: SEMI* batch?
EOF + ; + +batch: (sqlClauses SEMI*)+ + ; + +sqlClauses + : ddlCommand + | dmlCommand + | showCommand + | useCommand + | describeCommand + | otherCommand + | snowSqlCommand + ; + +ddlCommand: alterCommand | createCommand | dropCommand | undropCommand + ; + +dmlCommand + : queryStatement + | insertStatement + | insertMultiTableStatement + | updateStatement + | deleteStatement + | mergeStatement + ; + +insertStatement + : INSERT OVERWRITE? INTO dotIdentifier (LPAREN ids += id (COMMA ids += id)* RPAREN)? ( + valuesTableBody + | queryStatement + ) + ; + +insertMultiTableStatement + : INSERT OVERWRITE? ALL intoClause2 + | INSERT OVERWRITE? (FIRST | ALL) (WHEN searchCondition THEN intoClause2+)+ (ELSE intoClause2)? subquery + ; + +intoClause2: INTO dotIdentifier (LPAREN columnList RPAREN)? valuesList? + ; + +valuesList: VALUES LPAREN valueItem (COMMA valueItem)* RPAREN + ; + +valueItem: columnName | DEFAULT | NULL + ; + +mergeStatement: MERGE INTO tableRef USING tableSource ON searchCondition mergeCond + ; + +mergeCond: (mergeCondMatch | mergeCondNotMatch)+ + ; + +mergeCondMatch: (WHEN MATCHED (AND searchCondition)? THEN mergeUpdateDelete) + ; + +mergeCondNotMatch: WHEN NOT MATCHED (AND searchCondition)? THEN mergeInsert + ; + +mergeUpdateDelete: UPDATE SET setColumnValue (COMMA setColumnValue)* | DELETE + ; + +mergeInsert: INSERT ((LPAREN columnList RPAREN)? VALUES LPAREN exprList RPAREN)? + ; + +updateStatement + : UPDATE tableRef SET setColumnValue (COMMA setColumnValue)* (FROM tableSources)? ( + WHERE searchCondition + )? + ; + +setColumnValue: columnName EQ expr + ; + +tableRef: dotIdentifier asAlias? + ; + +tableOrQuery: tableRef | LPAREN subquery RPAREN asAlias? + ; + +tablesOrQueries: tableOrQuery (COMMA tableOrQuery)* + ; + +deleteStatement: DELETE FROM tableRef (USING tablesOrQueries)? (WHERE searchCondition)? + ; + +otherCommand + : copyIntoTable + | copyIntoLocation + | comment + | commit + | executeImmediate + | executeTask + | explain + | getDml + | grantOwnership + | grantToRole + | grantToShare + | grantRole + | list + | put + | remove + | revokeFromRole + | revokeFromShare + | revokeRole + | rollback + | set + | truncateMaterializedView + | truncateTable + | unset + | call + | beginTxn + | declareCommand + | let + ; + +snowSqlCommand: SQLCOMMAND + ; + +beginTxn: BEGIN (WORK | TRANSACTION)? (NAME id)? | START TRANSACTION ( NAME id)? + ; + +copyIntoTable + : COPY INTO dotIdentifier FROM (tableStage | userStage | namedStage | externalLocation) files? pattern? fileFormat? copyOptions* ( + VALIDATION_MODE EQ (RETURN_N_ROWS | RETURN_ERRORS | RETURN_ALL_ERRORS) + )? + // + /* Data load with transformation */ + | COPY INTO dotIdentifier (LPAREN columnList RPAREN)? FROM LPAREN SELECT selectList FROM ( + tableStage + | userStage + | namedStage + ) RPAREN files? pattern? fileFormat? copyOptions* + ; + +externalLocation + : string + //(for Amazon S3) + //'s3://[/]' + // ( ( STORAGE_INTEGRATION EQ id_ )? + // | ( CREDENTIALS EQ LPAREN ( AWS_KEY_ID EQ string AWS_SECRET_KEY EQ string ( AWS_TOKEN EQ string )? ) RPAREN )? + // )? + // [ ENCRYPTION = ( [ TYPE = 'AWS_CSE' ] [ MASTER_KEY = '' ] | + // [ TYPE = 'AWS_SSE_S3' ] | + // [ TYPE = 'AWS_SSE_KMS' [ KMS_KEY_ID = '' ] ] | + // [ TYPE = 'NONE' ] ) ] + // + // (for Google Cloud Storage) + //'gcs://[/]' + // ( STORAGE_INTEGRATION EQ id_ )? + //[ ENCRYPTION = ( [ TYPE = 'GCS_SSE_KMS' ] [ KMS_KEY_ID = '' ] | [ TYPE = 'NONE' ] ) ] + // + // (for Microsoft Azure) + //'azure://.blob.core.windows.net/[/]' + // ( ( STORAGE_INTEGRATION EQ id_ )? 
+ // | ( CREDENTIALS EQ LPAREN ( AZURE_SAS_TOKEN EQ string ) RPAREN ) + // )? + //[ ENCRYPTION = ( [ TYPE = { 'AZURE_CSE' | 'NONE' } ] [ MASTER_KEY = '' ] ) ] + ; + +files: FILES EQ LPAREN string (COMMA string)* RPAREN + ; + +fileFormat: FILE_FORMAT EQ LPAREN (formatName | formatType) RPAREN + ; + +formatName: FORMAT_NAME EQ string + ; + +formatType: TYPE EQ typeFileformat formatTypeOptions* + ; + +let + // variable and resultset are covered under the same visitor since expr is common + : LET id (dataType | RESULTSET)? (ASSIGN | DEFAULT) expr SEMI # letVariableAssignment + | LET id CURSOR FOR (selectStatement | id) SEMI # letCursor + ; + +stageFileFormat + : STAGE_FILE_FORMAT EQ LPAREN FORMAT_NAME EQ string + | TYPE EQ typeFileformat formatTypeOptions+ RPAREN + ; + +copyIntoLocation + : COPY INTO (tableStage | userStage | namedStage | externalLocation) FROM ( + dotIdentifier + | LPAREN queryStatement RPAREN + ) partitionBy? fileFormat? copyOptions? (VALIDATION_MODE EQ RETURN_ROWS)? HEADER? + ; + +comment + : COMMENT (IF EXISTS)? ON objectTypeName dotIdentifier functionSignature? IS string + | COMMENT (IF EXISTS)? ON COLUMN dotIdentifier IS string + ; + +functionSignature: LPAREN dataTypeList? RPAREN + ; + +commit: COMMIT WORK? + ; + +executeImmediate + : EXECUTE IMMEDIATE (string | id | LOCAL_ID) (USING LPAREN id (COMMA id)* RPAREN)? + ; + +executeTask: EXECUTE TASK dotIdentifier + ; + +explain: EXPLAIN (USING (TABULAR | JSON | id))? sqlClauses + ; + +parallel: PARALLEL EQ INT + ; + +getDml: GET (namedStage | userStage | tableStage) string parallel? pattern? + ; + +grantOwnership + : GRANT OWNERSHIP ( + ON ( + objectTypeName dotIdentifier + | ALL objectTypePlural IN ( DATABASE id | SCHEMA schemaName) + ) + | ON FUTURE objectTypePlural IN ( DATABASE id | SCHEMA schemaName) + ) TO ROLE id (( REVOKE | COPY) CURRENT GRANTS)? + ; + +grantToRole + : GRANT ( + ( globalPrivileges | ALL PRIVILEGES?) ON ACCOUNT + | (accountObjectPrivileges | ALL PRIVILEGES?) ON ( + USER + | RESOURCE MONITOR + | WAREHOUSE + | DATABASE + | INTEGRATION + ) dotIdentifier + | (schemaPrivileges | ALL PRIVILEGES?) ON (SCHEMA schemaName | ALL SCHEMAS IN DATABASE id) + | ( schemaPrivileges | ALL PRIVILEGES?) ON FUTURE SCHEMAS IN DATABASE id + | (schemaObjectPrivileges | ALL PRIVILEGES?) ON ( + objectType dotIdentifier + | ALL objectTypePlural IN ( DATABASE id | SCHEMA schemaName) + ) + | (schemaObjectPrivileges | ALL PRIVILEGES?) ON FUTURE objectTypePlural IN ( + DATABASE id + | SCHEMA schemaName + ) + ) TO ROLE? id (WITH GRANT OPTION)? 
+ ; + +globalPrivileges: globalPrivilege (COMMA globalPrivilege)* + ; + +globalPrivilege + : CREATE ( + ACCOUNT + | DATA EXCHANGE LISTING + | DATABASE + | INTEGRATION + | NETWORK POLICY + | ROLE + | SHARE + | USER + | WAREHOUSE + ) + | ( + APPLY MASKING POLICY + | APPLY ROW ACCESS POLICY + | APPLY SESSION POLICY + | APPLY TAG + | ATTACH POLICY + ) + | ( + EXECUTE TASK + | IMPORT SHARE + | MANAGE GRANTS + | MONITOR ( EXECUTION | USAGE) + | OVERRIDE SHARE RESTRICTIONS + ) + ; + +accountObjectPrivileges: accountObjectPrivilege (COMMA accountObjectPrivilege)* + ; + +accountObjectPrivilege + : MONITOR + | MODIFY + | USAGE + | OPERATE + | CREATE SCHEMA + | IMPORTED PRIVILEGES + | USE_ANY_ROLE + ; + +schemaPrivileges: schemaPrivilege (COMMA schemaPrivilege)* + ; + +schemaPrivilege + : MODIFY + | MONITOR + | USAGE + | CREATE ( + TABLE + | EXTERNAL TABLE + | VIEW + | MATERIALIZED VIEW + | MASKING POLICY + | ROW ACCESS POLICY + | SESSION POLICY + | TAG + | SEQUENCE + | FUNCTION + | PROCEDURE + | FILE FORMAT + | STAGE + | PIPE + | STREAM + | TASK + ) + | ADD SEARCH OPTIMIZATION + ; + +schemaObjectPrivileges: schemaObjectPrivilege (COMMA schemaObjectPrivilege)* + ; + +schemaObjectPrivilege + : SELECT + | INSERT + | UPDATE + | DELETE + | TRUNCATE + | REFERENCES + | USAGE + | READ (COMMA WRITE)? + | MONITOR + | OPERATE + | APPLY + ; + +grantToShare + : GRANT objectPrivilege ON ( + DATABASE id + | SCHEMA id + | FUNCTION id + | ( TABLE dotIdentifier | ALL TABLES IN SCHEMA schemaName) + | VIEW id + ) TO SHARE id + ; + +objectPrivilege: USAGE | SELECT | REFERENCE_USAGE + ; + +grantRole: GRANT ROLE roleName TO (ROLE roleName | USER id) + ; + +roleName: systemDefinedRole | id + ; + +systemDefinedRole: ORGADMIN | ACCOUNTADMIN | SECURITYADMIN | USERADMIN | SYSADMIN | PUBLIC + ; + +list: LIST (userStage | tableStage | namedStage) pattern? + ; + +// @~[/] +userStage: AT TILDA stagePath? + ; + +// @[.]%[/] +tableStage: AT schemaName? MODULE id stagePath? + ; + +// @[.][/] +namedStage: AT dotIdentifier stagePath? + ; + +stagePath: DIVIDE (ID (DIVIDE ID)* DIVIDE?)? + ; + +put + : PUT string (tableStage | userStage | namedStage) (PARALLEL EQ INT)? ( + AUTO_COMPRESS EQ trueFalse + )? ( + SOURCE_COMPRESSION EQ ( + AUTO_DETECT + | GZIP + | BZ2 + | BROTLI + | ZSTD + | DEFLATE + | RAW_DEFLATE + | NONE + ) + )? (OVERWRITE EQ trueFalse)? + ; + +remove: REMOVE (tableStage | userStage | namedStage) pattern? + ; + +revokeFromRole + : REVOKE (GRANT OPTION FOR)? ( + ( globalPrivilege | ALL PRIVILEGES?) ON ACCOUNT + | (accountObjectPrivileges | ALL PRIVILEGES?) ON ( + RESOURCE MONITOR + | WAREHOUSE + | DATABASE + | INTEGRATION + ) dotIdentifier + | (schemaPrivileges | ALL PRIVILEGES?) ON (SCHEMA schemaName | ALL SCHEMAS IN DATABASE id) + | (schemaPrivileges | ALL PRIVILEGES?) ON (FUTURE SCHEMAS IN DATABASE ) + | (schemaObjectPrivileges | ALL PRIVILEGES?) ON ( + objectType dotIdentifier + | ALL objectTypePlural IN SCHEMA schemaName + ) + | (schemaObjectPrivileges | ALL PRIVILEGES?) ON FUTURE objectTypePlural IN ( + DATABASE id + | SCHEMA schemaName + ) + ) FROM ROLE? id cascadeRestrict? + ; + +revokeFromShare + : REVOKE objectPrivilege ON ( + DATABASE id + | SCHEMA schemaName + | ( TABLE dotIdentifier | ALL TABLES IN SCHEMA schemaName) + | ( VIEW dotIdentifier | ALL VIEWS IN SCHEMA schemaName) + ) FROM SHARE id + ; + +revokeRole: REVOKE ROLE roleName FROM (ROLE roleName | USER id) + ; + +rollback: ROLLBACK WORK? 
+ ; + +set: SET id EQ expr | SET LPAREN id (COMMA id)* RPAREN EQ LPAREN expr (COMMA expr)* RPAREN + ; + +truncateMaterializedView: TRUNCATE MATERIALIZED VIEW dotIdentifier + ; + +truncateTable: TRUNCATE TABLE? (IF EXISTS)? dotIdentifier + ; + +unset: UNSET id | UNSET LPAREN id (COMMA id)* RPAREN + ; + +// alter commands +alterCommand + : alterAccount + | alterAlert + | alterApiIntegration + | alterConnection + | alterDatabase + | alterDynamicTable + //| alterEventTable // uses ALTER TABLE stmt + | alterExternalTable + | alterFailoverGroup + | alterFileFormat + | alterFunction + | alterMaskingPolicy + | alterMaterializedView + | alterNetworkPolicy + | alterNotificationIntegration + | alterPipe + | alterProcedure + | alterReplicationGroup + | alterResourceMonitor + | alterRole + | alterRowAccessPolicy + | alterSchema + | alterSecurityIntegrationExternalOauth + | alterSecurityIntegrationSnowflakeOauth + | alterSecurityIntegrationSaml2 + | alterSecurityIntegrationScim + | alterSequence + | alterSession + | alterSessionPolicy + | alterShare + | alterStage + | alterStorageIntegration + | alterStream + | alterTable + | alterTableAlterColumn + | alterTag + | alterTask + | alterUser + | alterView + | alterWarehouse + ; + +accountParams + : ALLOW_ID_TOKEN EQ trueFalse + | CLIENT_ENCRYPTION_KEY_SIZE EQ INT + | ENFORCE_SESSION_POLICY EQ trueFalse + | EXTERNAL_OAUTH_ADD_PRIVILEGED_ROLES_TO_BLOCKED_LIST EQ trueFalse + | INITIAL_REPLICATION_SIZE_LIMIT_IN_TB EQ INT + | NETWORK_POLICY EQ string + | PERIODIC_DATA_REKEYING EQ trueFalse + | PREVENT_UNLOAD_TO_INLINE_URL EQ trueFalse + | PREVENT_UNLOAD_TO_INTERNAL_STAGES EQ trueFalse + | REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_CREATION EQ trueFalse + | REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_OPERATION EQ trueFalse + | SAML_IDENTITY_PROVIDER EQ jsonLiteral + | SESSION_POLICY EQ string + | SSO_LOGIN_PAGE EQ trueFalse + ; + +objectParams + : DATA_RETENTION_TIME_IN_DAYS EQ INT + | MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT + | defaultDdlCollation + | MAX_CONCURRENCY_LEVEL EQ INT + | NETWORK_POLICY EQ string + | PIPE_EXECUTION_PAUSED EQ trueFalse + | SESSION_POLICY EQ string + | STATEMENT_QUEUED_TIMEOUT_IN_SECONDS EQ INT + | STATEMENT_TIMEOUT_IN_SECONDS EQ INT + ; + +defaultDdlCollation: DEFAULT_DDL_COLLATION_ EQ string + ; + +objectProperties + : PASSWORD EQ string + | LOGIN_NAME EQ string + | DISPLAY_NAME EQ string + | FIRST_NAME EQ string + | MIDDLE_NAME EQ string + | LAST_NAME EQ string + | EMAIL EQ string + | MUST_CHANGE_PASSWORD EQ trueFalse + | DISABLED EQ trueFalse + | DAYS_TO_EXPIRY EQ INT + | MINS_TO_UNLOCK EQ INT + | DEFAULT_WAREHOUSE EQ string + | DEFAULT_NAMESPACE EQ string + | DEFAULT_ROLE EQ string + //| DEFAULT_SECONDARY_ROLES EQ LPAREN 'ALL' RPAREN + | MINS_TO_BYPASS_MFA EQ INT + | RSA_PUBLIC_KEY EQ string + | RSA_PUBLIC_KEY_2 EQ string + | (COMMENT EQ string) + ; + +sessionParams + : ABORT_DETACHED_QUERY EQ trueFalse + | AUTOCOMMIT EQ trueFalse + | BINARY_INPUT_FORMAT EQ string + | BINARY_OUTPUT_FORMAT EQ string + | DATE_INPUT_FORMAT EQ string + | DATE_OUTPUT_FORMAT EQ string + | ERROR_ON_NONDETERMINISTIC_MERGE EQ trueFalse + | ERROR_ON_NONDETERMINISTIC_UPDATE EQ trueFalse + | JSON_INDENT EQ INT + | LOCK_TIMEOUT EQ INT + | QUERY_TAG EQ string + | ROWS_PER_RESULTSET EQ INT + | SIMULATED_DATA_SHARING_CONSUMER EQ string + | STATEMENT_TIMEOUT_IN_SECONDS EQ INT + | STRICT_JSON_OUTPUT EQ trueFalse + | TIMESTAMP_DAY_IS_ALWAYS_24H EQ trueFalse + | TIMESTAMP_INPUT_FORMAT EQ string + | TIMESTAMP_LTZ_OUTPUT_FORMAT EQ string + | TIMESTAMP_NTZ_OUTPUT_FORMAT EQ string + 
| TIMESTAMP_OUTPUT_FORMAT EQ string + | TIMESTAMP_TYPE_MAPPING EQ string + | TIMESTAMP_TZ_OUTPUT_FORMAT EQ string + | TIMEZONE EQ string + | TIME_INPUT_FORMAT EQ string + | TIME_OUTPUT_FORMAT EQ string + | TRANSACTION_DEFAULT_ISOLATION_LEVEL EQ string + | TWO_DIGIT_CENTURY_START EQ INT + | UNSUPPORTED_DDL_ACTION EQ string + | USE_CACHED_RESULT EQ trueFalse + | WEEK_OF_YEAR_POLICY EQ INT + | WEEK_START EQ INT + ; + +alterAccount: ALTER ACCOUNT alterAccountOpts + ; + +enabledTrueFalse: ENABLED EQ trueFalse + ; + +alterAlert + : ALTER ALERT (IF EXISTS)? id ( + resumeSuspend + | SET alertSetClause+ + | UNSET alertUnsetClause+ + | MODIFY CONDITION EXISTS LPAREN alertCondition RPAREN + | MODIFY ACTION alertAction + ) + ; + +resumeSuspend: RESUME | SUSPEND + ; + +alertSetClause: WAREHOUSE EQ id | SCHEDULE EQ string | (COMMENT EQ string) + ; + +alertUnsetClause: WAREHOUSE | SCHEDULE | COMMENT + ; + +alterApiIntegration + : ALTER API? INTEGRATION (IF EXISTS)? id SET (API_AWS_ROLE_ARN EQ string)? ( + AZURE_AD_APPLICATION_ID EQ string + )? (API_KEY EQ string)? enabledTrueFalse? (API_ALLOWED_PREFIXES EQ LPAREN string RPAREN)? ( + API_BLOCKED_PREFIXES EQ LPAREN string RPAREN + )? (COMMENT EQ string)? + | ALTER API? INTEGRATION id setTags + | ALTER API? INTEGRATION id unsetTags + | ALTER API? INTEGRATION (IF EXISTS)? id UNSET apiIntegrationProperty ( + COMMA apiIntegrationProperty + )* + ; + +apiIntegrationProperty: API_KEY | ENABLED | API_BLOCKED_PREFIXES | COMMENT + ; + +alterConnection: ALTER CONNECTION alterConnectionOpts + ; + +alterDatabase + : ALTER DATABASE (IF EXISTS)? id RENAME TO id + | ALTER DATABASE (IF EXISTS)? id SWAP WITH id + | ALTER DATABASE (IF EXISTS)? id SET (DATA_RETENTION_TIME_IN_DAYS EQ INT)? ( + MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT + )? defaultDdlCollation? (COMMENT EQ string)? + | ALTER DATABASE id setTags + | ALTER DATABASE id unsetTags + | ALTER DATABASE (IF EXISTS)? id UNSET databaseProperty (COMMA databaseProperty)* + | ALTER DATABASE id ENABLE REPLICATION TO ACCOUNTS accountIdList (IGNORE EDITION CHECK)? + | ALTER DATABASE id DISABLE REPLICATION ( TO ACCOUNTS accountIdList)? + | ALTER DATABASE id REFRESH + // Database Failover + | ALTER DATABASE id ENABLE FAILOVER TO ACCOUNTS accountIdList + | ALTER DATABASE id DISABLE FAILOVER ( TO ACCOUNTS accountIdList)? + | ALTER DATABASE id PRIMARY + ; + +databaseProperty + : DATA_RETENTION_TIME_IN_DAYS + | MAX_DATA_EXTENSION_TIME_IN_DAYS + | DEFAULT_DDL_COLLATION_ + | COMMENT + ; + +accountIdList: id (COMMA id)* + ; + +alterDynamicTable: ALTER DYNAMIC TABLE id (resumeSuspend | REFRESH | SET WAREHOUSE EQ id) + ; + +alterExternalTable + : ALTER EXTERNAL TABLE (IF EXISTS)? dotIdentifier REFRESH string? + | ALTER EXTERNAL TABLE (IF EXISTS)? dotIdentifier ADD FILES LPAREN stringList RPAREN + | ALTER EXTERNAL TABLE (IF EXISTS)? dotIdentifier REMOVE FILES LPAREN stringList RPAREN + | ALTER EXTERNAL TABLE (IF EXISTS)? dotIdentifier SET (AUTO_REFRESH EQ trueFalse)? tagDeclList? + | ALTER EXTERNAL TABLE (IF EXISTS)? dotIdentifier unsetTags + //Partitions added and removed manually + | ALTER EXTERNAL TABLE dotIdentifier (IF EXISTS)? ADD PARTITION LPAREN columnName EQ string ( + COMMA columnName EQ string + )* RPAREN LOCATION string + | ALTER EXTERNAL TABLE dotIdentifier (IF EXISTS)? 
DROP PARTITION LOCATION string + ; + +ignoreEditionCheck: IGNORE EDITION CHECK + ; + +replicationSchedule: REPLICATION_SCHEDULE EQ string + ; + +dbNameList: id (COMMA id)* + ; + +shareNameList: id (COMMA id)* + ; + +fullAcctList: fullAcct (COMMA fullAcct)* + ; + +alterFailoverGroup + //Source Account + : ALTER FAILOVER GROUP (IF EXISTS)? id RENAME TO id + | ALTER FAILOVER GROUP (IF EXISTS)? id SET (OBJECT_TYPES EQ objectTypeList)? replicationSchedule? + | ALTER FAILOVER GROUP (IF EXISTS)? id SET OBJECT_TYPES EQ objectTypeList + // ALLOWED_INTEGRATION_TYPES EQ [ , ... ] ] + replicationSchedule? + | ALTER FAILOVER GROUP (IF EXISTS)? id ADD dbNameList TO ALLOWED_DATABASES + | ALTER FAILOVER GROUP (IF EXISTS)? id MOVE DATABASES dbNameList TO FAILOVER GROUP id + | ALTER FAILOVER GROUP (IF EXISTS)? id REMOVE dbNameList FROM ALLOWED_DATABASES + | ALTER FAILOVER GROUP (IF EXISTS)? id ADD shareNameList TO ALLOWED_SHARES + | ALTER FAILOVER GROUP (IF EXISTS)? id MOVE SHARES shareNameList TO FAILOVER GROUP id + | ALTER FAILOVER GROUP (IF EXISTS)? id REMOVE shareNameList FROM ALLOWED_SHARES + | ALTER FAILOVER GROUP (IF EXISTS)? id ADD fullAcctList TO ALLOWED_ACCOUNTS ignoreEditionCheck? + | ALTER FAILOVER GROUP (IF EXISTS)? id REMOVE fullAcctList FROM ALLOWED_ACCOUNTS + //Target Account + | ALTER FAILOVER GROUP (IF EXISTS)? id (REFRESH | PRIMARY | SUSPEND | RESUME) + ; + +alterFileFormat + : ALTER FILE FORMAT (IF EXISTS)? id RENAME TO id + | ALTER FILE FORMAT (IF EXISTS)? id SET (formatTypeOptions* (COMMENT EQ string)?) + ; + +alterFunction + : alterFunctionSignature RENAME TO id + | alterFunctionSignature SET (COMMENT EQ string) + | alterFunctionSignature SET SECURE + | alterFunctionSignature UNSET (SECURE | COMMENT) + // External Functions + | alterFunctionSignature SET API_INTEGRATION EQ id + | alterFunctionSignature SET HEADERS EQ LPAREN headerDecl* RPAREN + | alterFunctionSignature SET CONTEXT_HEADERS EQ LPAREN id* RPAREN + | alterFunctionSignature SET MAX_BATCH_ROWS EQ INT + | alterFunctionSignature SET COMPRESSION EQ compressionType + | alterFunctionSignature SET (REQUEST_TRANSLATOR | RESPONSE_TRANSLATOR) EQ id + | alterFunctionSignature UNSET ( + COMMENT + | HEADERS + | CONTEXT_HEADERS + | MAX_BATCH_ROWS + | COMPRESSION + | SECURE + | REQUEST_TRANSLATOR + | RESPONSE_TRANSLATOR + ) + ; + +alterFunctionSignature: ALTER FUNCTION (IF EXISTS)? id LPAREN dataTypeList? RPAREN + ; + +dataTypeList: dataType (COMMA dataType)* + ; + +alterMaskingPolicy + : ALTER MASKING POLICY (IF EXISTS)? id SET BODY ARROW expr + | ALTER MASKING POLICY (IF EXISTS)? id RENAME TO id + | ALTER MASKING POLICY (IF EXISTS)? id SET (COMMENT EQ string) + ; + +alterMaterializedView + : ALTER MATERIALIZED VIEW id ( + RENAME TO id + | CLUSTER BY LPAREN exprList RPAREN + | DROP CLUSTERING KEY + | resumeSuspend RECLUSTER? + | SET ( SECURE? (COMMENT EQ string)?) + | UNSET ( SECURE | COMMENT) + ) + ; + +alterNetworkPolicy: ALTER NETWORK POLICY alterNetworkPolicyOpts + ; + +alterNotificationIntegration + : ALTER NOTIFICATION? INTEGRATION (IF EXISTS)? id SET enabledTrueFalse? cloudProviderParamsAuto ( + COMMENT EQ string + )? + // Push notifications + | ALTER NOTIFICATION? INTEGRATION (IF EXISTS)? id SET enabledTrueFalse? cloudProviderParamsPush ( + COMMENT EQ string + )? + | ALTER NOTIFICATION? INTEGRATION id setTags + | ALTER NOTIFICATION? INTEGRATION id unsetTags + | ALTER NOTIFICATION? INTEGRATION (IF EXISTS) id UNSET (ENABLED | COMMENT) + ; + +alterPipe + : ALTER PIPE (IF EXISTS)? id SET (objectProperties? (COMMENT EQ string)?) 
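+    // e.g. (illustrative) ALTER PIPE IF EXISTS my_pipe SET COMMENT = 'loads raw events'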
+ | ALTER PIPE id setTags + | ALTER PIPE id unsetTags + | ALTER PIPE (IF EXISTS)? id UNSET PIPE_EXECUTION_PAUSED EQ trueFalse + | ALTER PIPE (IF EXISTS)? id UNSET COMMENT + | ALTER PIPE (IF EXISTS)? id REFRESH (PREFIX EQ string)? (MODIFIED_AFTER EQ string)? + ; + +alterReplicationGroup + //Source Account + : ALTER REPLICATION GROUP (IF EXISTS)? id RENAME TO id + | ALTER REPLICATION GROUP (IF EXISTS)? id SET (OBJECT_TYPES EQ objectTypeList)? ( + REPLICATION_SCHEDULE EQ string + )? + | ALTER REPLICATION GROUP (IF EXISTS)? id SET OBJECT_TYPES EQ objectTypeList ALLOWED_INTEGRATION_TYPES EQ integrationTypeName ( + COMMA integrationTypeName + )* (REPLICATION_SCHEDULE EQ string)? + | ALTER REPLICATION GROUP (IF EXISTS)? id ADD dbNameList TO ALLOWED_DATABASES + | ALTER REPLICATION GROUP (IF EXISTS)? id MOVE DATABASES dbNameList TO REPLICATION GROUP id + | ALTER REPLICATION GROUP (IF EXISTS)? id REMOVE dbNameList FROM ALLOWED_DATABASES + | ALTER REPLICATION GROUP (IF EXISTS)? id ADD shareNameList TO ALLOWED_SHARES + | ALTER REPLICATION GROUP (IF EXISTS)? id MOVE SHARES shareNameList TO REPLICATION GROUP id + | ALTER REPLICATION GROUP (IF EXISTS)? id REMOVE shareNameList FROM ALLOWED_SHARES + | ALTER REPLICATION GROUP (IF EXISTS)? id ADD accountIdList TO ALLOWED_ACCOUNTS ignoreEditionCheck? + | ALTER REPLICATION GROUP (IF EXISTS)? id REMOVE accountIdList FROM ALLOWED_ACCOUNTS + //Target Account + | ALTER REPLICATION GROUP (IF EXISTS)? id REFRESH + | ALTER REPLICATION GROUP (IF EXISTS)? id SUSPEND + | ALTER REPLICATION GROUP (IF EXISTS)? id RESUME + ; + +creditQuota: CREDIT_QUOTA EQ INT + ; + +frequency: FREQUENCY EQ (MONTHLY | DAILY | WEEKLY | YEARLY | NEVER) + ; + +notifyUsers: NOTIFY_USERS EQ LPAREN id (COMMA id)* RPAREN + ; + +triggerDefinition: ON INT PERCENT DO (SUSPEND | SUSPEND_IMMEDIATE | NOTIFY) + ; + +alterResourceMonitor + : ALTER RESOURCE MONITOR (IF EXISTS)? id ( + SET creditQuota? frequency? (START_TIMESTAMP EQ LPAREN string | IMMEDIATELY RPAREN)? ( + END_TIMESTAMP EQ string + )? + )? (notifyUsers ( TRIGGERS triggerDefinition (COMMA triggerDefinition)*)?)? + ; + +alterRole + : ALTER ROLE (IF EXISTS)? id RENAME TO id + | ALTER ROLE (IF EXISTS)? id SET (COMMENT EQ string) + | ALTER ROLE (IF EXISTS)? id UNSET COMMENT + | ALTER ROLE (IF EXISTS)? id setTags + | ALTER ROLE (IF EXISTS)? id unsetTags + ; + +alterRowAccessPolicy + : ALTER ROW ACCESS POLICY (IF EXISTS)? id SET BODY ARROW expr + | ALTER ROW ACCESS POLICY (IF EXISTS)? id RENAME TO id + | ALTER ROW ACCESS POLICY (IF EXISTS)? id SET (COMMENT EQ string) + ; + +alterSchema + : ALTER SCHEMA (IF EXISTS)? schemaName RENAME TO schemaName + | ALTER SCHEMA (IF EXISTS)? schemaName SWAP WITH schemaName + | ALTER SCHEMA (IF EXISTS)? schemaName SET ( + (DATA_RETENTION_TIME_IN_DAYS EQ INT)? (MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT)? defaultDdlCollation? ( + COMMENT EQ string + )? + ) + | ALTER SCHEMA (IF EXISTS)? schemaName setTags + | ALTER SCHEMA (IF EXISTS)? schemaName unsetTags + | ALTER SCHEMA (IF EXISTS)? schemaName UNSET schemaProperty (COMMA schemaProperty)* + | ALTER SCHEMA (IF EXISTS)? schemaName ( ENABLE | DISABLE) MANAGED ACCESS + ; + +schemaProperty + : DATA_RETENTION_TIME_IN_DAYS + | MAX_DATA_EXTENSION_TIME_IN_DAYS + | DEFAULT_DDL_COLLATION_ + | COMMENT + ; + +alterSequence + : ALTER SEQUENCE (IF EXISTS)? dotIdentifier RENAME TO dotIdentifier + | ALTER SEQUENCE (IF EXISTS)? dotIdentifier SET? ( INCREMENT BY? EQ? INT)? + | ALTER SEQUENCE (IF EXISTS)? dotIdentifier SET ( + orderNoorder? 
(COMMENT EQ string) + | orderNoorder + ) + | ALTER SEQUENCE (IF EXISTS)? dotIdentifier UNSET COMMENT + ; + +alterSecurityIntegrationExternalOauth + : ALTER SECURITY? INTEGRATION (IF EXISTS) id SET (TYPE EQ EXTERNAL_OAUTH)? ( + ENABLED EQ trueFalse + )? (EXTERNAL_OAUTH_TYPE EQ ( OKTA | id | PING_FEDERATE | CUSTOM))? ( + EXTERNAL_OAUTH_ISSUER EQ string + )? (EXTERNAL_OAUTH_TOKEN_USER_MAPPING_CLAIM EQ (string | LPAREN stringList RPAREN))? ( + EXTERNAL_OAUTH_SNOWFLAKE_USER_MAPPING_ATTRIBUTE EQ string + )? (EXTERNAL_OAUTH_JWS_KEYS_URL EQ string)? // For OKTA | PING_FEDERATE | CUSTOM + (EXTERNAL_OAUTH_JWS_KEYS_URL EQ (string | LPAREN stringList RPAREN))? // For Azure + (EXTERNAL_OAUTH_RSA_PUBLIC_KEY EQ string)? (EXTERNAL_OAUTH_RSA_PUBLIC_KEY_2 EQ string)? ( + EXTERNAL_OAUTH_BLOCKED_ROLES_LIST EQ LPAREN stringList RPAREN + )? (EXTERNAL_OAUTH_ALLOWED_ROLES_LIST EQ LPAREN stringList RPAREN)? ( + EXTERNAL_OAUTH_AUDIENCE_LIST EQ LPAREN string RPAREN + )? (EXTERNAL_OAUTH_ANY_ROLE_MODE EQ (DISABLE | ENABLE | ENABLE_FOR_PRIVILEGE))? ( + EXTERNAL_OAUTH_ANY_ROLE_MODE EQ string + )? // Only for EXTERNAL_OAUTH_TYPE EQ CUSTOM + | ALTER SECURITY? INTEGRATION (IF EXISTS)? id UNSET securityIntegrationExternalOauthProperty ( + COMMA securityIntegrationExternalOauthProperty + )* + | ALTER SECURITY? INTEGRATION id setTags + | ALTER SECURITY? INTEGRATION id unsetTags + ; + +securityIntegrationExternalOauthProperty + : ENABLED + | NETWORK_POLICY + | OAUTH_CLIENT_RSA_PUBLIC_KEY + | OAUTH_CLIENT_RSA_PUBLIC_KEY_2 + | OAUTH_USE_SECONDARY_ROLES EQ (IMPLICIT | NONE) + | COMMENT + ; + +alterSecurityIntegrationSnowflakeOauth + : ALTER SECURITY? INTEGRATION (IF EXISTS)? id SET (TYPE EQ EXTERNAL_OAUTH)? enabledTrueFalse? ( + EXTERNAL_OAUTH_TYPE EQ ( OKTA | id | PING_FEDERATE | CUSTOM) + )? (EXTERNAL_OAUTH_ISSUER EQ string)? ( + EXTERNAL_OAUTH_TOKEN_USER_MAPPING_CLAIM EQ (string | LPAREN stringList RPAREN) + )? (EXTERNAL_OAUTH_SNOWFLAKE_USER_MAPPING_ATTRIBUTE EQ string)? ( + EXTERNAL_OAUTH_JWS_KEYS_URL EQ string + )? // For OKTA | PING_FEDERATE | CUSTOM + (EXTERNAL_OAUTH_JWS_KEYS_URL EQ ( string | LPAREN stringList RPAREN))? // For Azure + (EXTERNAL_OAUTH_RSA_PUBLIC_KEY EQ string)? (EXTERNAL_OAUTH_RSA_PUBLIC_KEY_2 EQ string)? ( + EXTERNAL_OAUTH_BLOCKED_ROLES_LIST EQ LPAREN stringList RPAREN + )? (EXTERNAL_OAUTH_ALLOWED_ROLES_LIST EQ LPAREN stringList RPAREN)? ( + EXTERNAL_OAUTH_AUDIENCE_LIST EQ LPAREN string RPAREN + )? (EXTERNAL_OAUTH_ANY_ROLE_MODE EQ DISABLE | ENABLE | ENABLE_FOR_PRIVILEGE)? ( + EXTERNAL_OAUTH_SCOPE_DELIMITER EQ string + ) // Only for EXTERNAL_OAUTH_TYPE EQ CUSTOM + | ALTER SECURITY? INTEGRATION (IF EXISTS)? id UNSET securityIntegrationSnowflakeOauthProperty ( + COMMA securityIntegrationSnowflakeOauthProperty + )* + | ALTER SECURITY? INTEGRATION id setTags + | ALTER SECURITY? INTEGRATION id unsetTags + ; + +securityIntegrationSnowflakeOauthProperty: ENABLED | EXTERNAL_OAUTH_AUDIENCE_LIST + ; + +alterSecurityIntegrationSaml2 + : ALTER SECURITY? INTEGRATION (IF EXISTS)? id SET (TYPE EQ SAML2)? enabledTrueFalse? ( + SAML2_ISSUER EQ string + )? (SAML2_SSO_URL EQ string)? (SAML2_PROVIDER EQ string)? (SAML2_X509_CERT EQ string)? ( + SAML2_SP_INITIATED_LOGIN_PAGE_LABEL EQ string + )? (SAML2_ENABLE_SP_INITIATED EQ trueFalse)? (SAML2_SNOWFLAKE_X509_CERT EQ string)? ( + SAML2_SIGN_REQUEST EQ trueFalse + )? (SAML2_REQUESTED_NAMEID_FORMAT EQ string)? (SAML2_POST_LOGOUT_REDIRECT_URL EQ string)? ( + SAML2_FORCE_AUTHN EQ trueFalse + )? (SAML2_SNOWFLAKE_ISSUER_URL EQ string)? (SAML2_SNOWFLAKE_ACS_URL EQ string)? 
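+    // e.g. (illustrative) ALTER SECURITY INTEGRATION IF EXISTS my_saml2 SET SAML2_FORCE_AUTHN = TRUE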
+ | ALTER SECURITY? INTEGRATION (IF EXISTS)? id UNSET ENABLED + | ALTER SECURITY? INTEGRATION id setTags + | ALTER SECURITY? INTEGRATION id unsetTags + ; + +alterSecurityIntegrationScim + : ALTER SECURITY? INTEGRATION (IF EXISTS)? id SET (NETWORK_POLICY EQ string)? ( + SYNC_PASSWORD EQ trueFalse + )? (COMMENT EQ string)? + | ALTER SECURITY? INTEGRATION (IF EXISTS)? id UNSET securityIntegrationScimProperty ( + COMMA securityIntegrationScimProperty + )* + | ALTER SECURITY? INTEGRATION id setTags + | ALTER SECURITY? INTEGRATION id unsetTags + ; + +securityIntegrationScimProperty: NETWORK_POLICY | SYNC_PASSWORD | COMMENT + ; + +alterSession: ALTER SESSION SET sessionParams | ALTER SESSION UNSET id (COMMA id)* + ; + +alterSessionPolicy + : ALTER SESSION POLICY (IF EXISTS)? id (UNSET | SET) (SESSION_IDLE_TIMEOUT_MINS EQ INT)? ( + SESSION_UI_IDLE_TIMEOUT_MINS EQ INT + )? (COMMENT EQ string)? + | ALTER SESSION POLICY (IF EXISTS)? id RENAME TO id + ; + +alterShare + : ALTER SHARE (IF EXISTS)? id (ADD | REMOVE) ACCOUNTS EQ id (COMMA id)* ( + SHARE_RESTRICTIONS EQ trueFalse + )? + | ALTER SHARE (IF EXISTS)? id ADD ACCOUNTS EQ id (COMMA id)* (SHARE_RESTRICTIONS EQ trueFalse)? + | ALTER SHARE (IF EXISTS)? id SET (ACCOUNTS EQ id (COMMA id)*)? (COMMENT EQ string)? + | ALTER SHARE (IF EXISTS)? id setTags + | ALTER SHARE id unsetTags + | ALTER SHARE (IF EXISTS)? id UNSET COMMENT + ; + +alterStorageIntegration + : ALTER STORAGE? INTEGRATION (IF EXISTS)? id SET cloudProviderParams2? enabledTrueFalse? ( + STORAGE_ALLOWED_LOCATIONS EQ LPAREN stringList RPAREN + )? (STORAGE_BLOCKED_LOCATIONS EQ LPAREN stringList RPAREN)? (COMMENT EQ string)? + | ALTER STORAGE? INTEGRATION (IF EXISTS)? id setTags + | ALTER STORAGE? INTEGRATION id unsetTags + | ALTER STORAGE? INTEGRATION (IF EXISTS)? id UNSET ( + ENABLED + | STORAGE_BLOCKED_LOCATIONS + | COMMENT + ) + //[ , ... ] + ; + +alterStream + : ALTER STREAM (IF EXISTS)? id SET tagDeclList? (COMMENT EQ string)? + | ALTER STREAM (IF EXISTS)? id setTags + | ALTER STREAM id unsetTags + | ALTER STREAM (IF EXISTS)? id UNSET COMMENT + ; + +alterTable + : ALTER TABLE (IF EXISTS)? dotIdentifier RENAME TO dotIdentifier + | ALTER TABLE (IF EXISTS)? dotIdentifier SWAP WITH dotIdentifier + | ALTER TABLE (IF EXISTS)? dotIdentifier ( + clusteringAction + | tableColumnAction + | constraintAction + ) + | ALTER TABLE (IF EXISTS)? dotIdentifier extTableColumnAction + | ALTER TABLE (IF EXISTS)? dotIdentifier searchOptimizationAction + | ALTER TABLE (IF EXISTS)? dotIdentifier SET stageFileFormat? ( + STAGE_COPY_OPTIONS EQ LPAREN copyOptions RPAREN + )? (DATA_RETENTION_TIME_IN_DAYS EQ INT)? (MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT)? ( + CHANGE_TRACKING EQ trueFalse + )? defaultDdlCollation? (COMMENT EQ string)? + | ALTER TABLE (IF EXISTS)? dotIdentifier setTags + | ALTER TABLE (IF EXISTS)? dotIdentifier unsetTags + | ALTER TABLE (IF EXISTS)? dotIdentifier UNSET ( + DATA_RETENTION_TIME_IN_DAYS + | MAX_DATA_EXTENSION_TIME_IN_DAYS + | CHANGE_TRACKING + | DEFAULT_DDL_COLLATION_ + | COMMENT + | + ) + //[ , ... ] + | ALTER TABLE (IF EXISTS)? dotIdentifier ADD ROW ACCESS POLICY id ON columnListInParentheses + | ALTER TABLE (IF EXISTS)? dotIdentifier DROP ROW ACCESS POLICY id + | ALTER TABLE (IF EXISTS)? dotIdentifier DROP ROW ACCESS POLICY id COMMA ADD ROW ACCESS POLICY id ON columnListInParentheses + | ALTER TABLE (IF EXISTS)? dotIdentifier DROP ALL ROW ACCESS POLICIES + ; + +clusteringAction + : CLUSTER BY LPAREN exprList RPAREN + | RECLUSTER ( MAX_SIZE EQ INT)? ( WHERE expr)? 
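+    // e.g. (illustrative, within an ALTER TABLE) CLUSTER BY (c1, c2) or RECLUSTER MAX_SIZE = 128 WHERE c1 = 1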
+ | resumeSuspend RECLUSTER + | DROP CLUSTERING KEY + ; + +tableColumnAction + : ADD COLUMN? (IF NOT EXISTS)? fullColDecl (COMMA fullColDecl)* + | RENAME COLUMN columnName TO columnName + | alterModify ( + LPAREN alterColumnClause (COLON alterColumnClause)* RPAREN + | alterColumnClause (COLON alterColumnClause)* + ) + | alterModify COLUMN columnName SET MASKING POLICY id ( + USING LPAREN columnName COMMA columnList RPAREN + )? FORCE? + | alterModify COLUMN columnName UNSET MASKING POLICY + | alterModify columnSetTags (COMMA columnSetTags)* + | alterModify columnUnsetTags (COMMA columnUnsetTags)* + | DROP COLUMN? (IF EXISTS)? columnList + //| DROP DEFAULT + ; + +alterColumnClause + : COLUMN? columnName ( + DROP DEFAULT + | SET DEFAULT dotIdentifier DOT NEXTVAL + | ( SET? NOT NULL | DROP NOT NULL) + | ( (SET DATA)? TYPE)? dataType + | COMMENT string + | UNSET COMMENT + ) + ; + +inlineConstraint + : (CONSTRAINT id)? ( + (UNIQUE | primaryKey) commonConstraintProperties* + | foreignKey REFERENCES dotIdentifier (LPAREN columnName RPAREN)? constraintProperties + ) + ; + +enforcedNotEnforced: NOT? ENFORCED + ; + +deferrableNotDeferrable: NOT? DEFERRABLE + ; + +initiallyDeferredOrImmediate: INITIALLY (DEFERRED | IMMEDIATE) + ; + +//TODO : Some properties are mutualy exclusive ie INITIALLY DEFERRED is not compatible with NOT DEFERRABLE +// also VALIDATE | NOVALIDATE need to be after ENABLE or ENFORCED. Lot of case to handle :) +commonConstraintProperties + : enforcedNotEnforced (VALIDATE | NOVALIDATE)? + | deferrableNotDeferrable + | initiallyDeferredOrImmediate + | ( ENABLE | DISABLE) ( VALIDATE | NOVALIDATE)? + | RELY + | NORELY + ; + +onUpdate: ON UPDATE onAction + ; + +onDelete: ON DELETE onAction + ; + +foreignKeyMatch: MATCH matchType = (FULL | PARTIAL | SIMPLE) + ; + +onAction: CASCADE | SET ( NULL | DEFAULT) | RESTRICT | NO ACTION + ; + +constraintProperties + : commonConstraintProperties* + | foreignKeyMatch + | foreignKeyMatch? ( onUpdate onDelete? | onDelete onUpdate?) + ; + +extTableColumnAction + : ADD COLUMN? columnName dataType AS LPAREN expr RPAREN + | RENAME COLUMN columnName TO columnName + | DROP COLUMN? columnList + ; + +constraintAction + : ADD outOfLineConstraint + | RENAME CONSTRAINT id TO id + | alterModify (CONSTRAINT id | primaryKey | UNIQUE | foreignKey) columnListInParentheses enforcedNotEnforced? ( + VALIDATE + | NOVALIDATE + ) (RELY | NORELY) + | DROP (CONSTRAINT id | primaryKey | UNIQUE | foreignKey) columnListInParentheses? cascadeRestrict? + | DROP PRIMARY KEY + ; + +searchOptimizationAction + : ADD SEARCH OPTIMIZATION (ON searchMethodWithTarget (COMMA searchMethodWithTarget)*)? + | DROP SEARCH OPTIMIZATION (ON searchMethodWithTarget (COMMA searchMethodWithTarget)*)? + ; + +searchMethodWithTarget: (EQUALITY | SUBSTRING | GEO) LPAREN (STAR | expr) RPAREN + ; + +alterTableAlterColumn + : ALTER TABLE dotIdentifier alterModify ( + LPAREN alterColumnDeclList RPAREN + | alterColumnDeclList + ) + | ALTER TABLE dotIdentifier alterModify COLUMN columnName SET MASKING POLICY id ( + USING LPAREN columnName COMMA columnList RPAREN + )? FORCE? + | ALTER TABLE dotIdentifier alterModify COLUMN columnName UNSET MASKING POLICY + | ALTER TABLE dotIdentifier alterModify columnSetTags (COMMA columnSetTags)* + | ALTER TABLE dotIdentifier alterModify columnUnsetTags (COMMA columnUnsetTags)* + ; + +alterColumnDeclList: alterColumnDecl (COMMA alterColumnDecl)* + ; + +alterColumnDecl: COLUMN? 
columnName alterColumnOpts + ; + +alterColumnOpts + : DROP DEFAULT + | SET DEFAULT dotIdentifier DOT NEXTVAL + | ( SET? NOT NULL | DROP NOT NULL) + | ( (SET DATA)? TYPE)? dataType + | (COMMENT EQ string) + | UNSET COMMENT + ; + +columnSetTags: COLUMN? columnName setTags + ; + +columnUnsetTags: COLUMN columnName unsetTags + ; + +alterTag: ALTER TAG (IF EXISTS)? dotIdentifier alterTagOpts + ; + +alterTask + : ALTER TASK (IF EXISTS)? dotIdentifier resumeSuspend + | ALTER TASK (IF EXISTS)? dotIdentifier ( REMOVE | ADD) AFTER stringList + | ALTER TASK (IF EXISTS)? dotIdentifier SET + // TODO : Check and review if element's order binded or not + (WAREHOUSE EQ id)? taskSchedule? taskOverlap? taskTimeout? taskSuspendAfterFailureNumber? ( + COMMENT EQ string + )? sessionParamsList? + | ALTER TASK (IF EXISTS)? dotIdentifier UNSET + // TODO : Check and review if element's order binded or not + WAREHOUSE? SCHEDULE? ALLOW_OVERLAPPING_EXECUTION? USER_TASK_TIMEOUT_MS? SUSPEND_TASK_AFTER_NUM_FAILURES? COMMENT? sessionParameterList? + //[ , ... ] + | ALTER TASK (IF EXISTS)? dotIdentifier setTags + | ALTER TASK (IF EXISTS)? dotIdentifier unsetTags + | ALTER TASK (IF EXISTS)? dotIdentifier MODIFY AS sql + | ALTER TASK (IF EXISTS)? dotIdentifier MODIFY WHEN expr + ; + +alterUser: ALTER USER (IF EXISTS)? id alterUserOpts + ; + +alterView + : ALTER VIEW (IF EXISTS)? dotIdentifier RENAME TO dotIdentifier + | ALTER VIEW (IF EXISTS)? dotIdentifier SET (COMMENT EQ string) + | ALTER VIEW (IF EXISTS)? dotIdentifier UNSET COMMENT + | ALTER VIEW dotIdentifier SET SECURE + | ALTER VIEW dotIdentifier UNSET SECURE + | ALTER VIEW (IF EXISTS)? dotIdentifier setTags + | ALTER VIEW (IF EXISTS)? dotIdentifier unsetTags + | ALTER VIEW (IF EXISTS)? dotIdentifier ADD ROW ACCESS POLICY id ON columnListInParentheses + | ALTER VIEW (IF EXISTS)? dotIdentifier DROP ROW ACCESS POLICY id + | ALTER VIEW (IF EXISTS)? dotIdentifier ADD ROW ACCESS POLICY id ON columnListInParentheses COMMA DROP ROW ACCESS POLICY id + | ALTER VIEW (IF EXISTS)? dotIdentifier DROP ALL ROW ACCESS POLICIES + | ALTER VIEW dotIdentifier alterModify COLUMN? id SET MASKING POLICY id ( + USING LPAREN columnName COMMA columnList RPAREN + )? FORCE? + | ALTER VIEW dotIdentifier alterModify COLUMN? id UNSET MASKING POLICY + | ALTER VIEW dotIdentifier alterModify COLUMN? id setTags + | ALTER VIEW dotIdentifier alterModify COLUMN id unsetTags + ; + +alterModify: ALTER | MODIFY + ; + +alterWarehouse: ALTER WAREHOUSE (IF EXISTS)? alterWarehouseOpts + ; + +alterConnectionOpts + : id ENABLE FAILOVER TO ACCOUNTS id DOT id (COMMA id DOT id)* ignoreEditionCheck? + | id DISABLE FAILOVER ( TO ACCOUNTS id DOT id (COMMA id DOT id))? + | id PRIMARY + | (IF EXISTS)? id SET (COMMENT EQ string) + | (IF EXISTS)? id UNSET COMMENT + ; + +alterUserOpts + : RENAME TO id + | RESET PASSWORD + | ABORT ALL QUERIES + | ADD DELEGATED AUTHORIZATION OF ROLE id TO SECURITY INTEGRATION id + | REMOVE DELEGATED (AUTHORIZATION OF ROLE id | AUTHORIZATIONS) FROM SECURITY INTEGRATION id + | setTags + | unsetTags + // | SET objectProperties? objectParams? sessionParams? + // | UNSET (objectPropertyName | objectid | sessionid) //[ , ... ] + ; + +alterTagOpts + : RENAME TO dotIdentifier + | ( ADD | DROP) tagAllowedValues + | UNSET ALLOWED_VALUES + | SET MASKING POLICY id (COMMA MASKING POLICY id)* + | UNSET MASKING POLICY id (COMMA MASKING POLICY id)* + | SET (COMMENT EQ string) + | UNSET COMMENT + ; + +alterNetworkPolicyOpts + : (IF EXISTS)? id SET (ALLOWED_IP_LIST EQ LPAREN stringList RPAREN)? 
( + BLOCKED_IP_LIST EQ LPAREN stringList RPAREN + )? (COMMENT EQ string)? + | (IF EXISTS)? id UNSET COMMENT + | id RENAME TO id + ; + +alterWarehouseOpts + : idFn? (SUSPEND | RESUME (IF SUSPENDED)?) + | idFn? ABORT ALL QUERIES + | idFn RENAME TO id + // | id SET [ objectProperties ] + | idFn setTags + | idFn unsetTags + | idFn UNSET id (COMMA id)* + | id SET whProperties (COLON whProperties)* + ; + +alterAccountOpts + : SET accountParams? objectParams? sessionParams? + | UNSET id (COMMA id)? + | SET RESOURCE_MONITOR EQ id + | setTags + | unsetTags + | id RENAME TO id ( SAVE_OLD_URL EQ trueFalse)? + | id DROP OLD URL + ; + +setTags: SET tagDeclList + ; + +tagDeclList: TAG dotIdentifier EQ string (COMMA dotIdentifier EQ string)* + ; + +unsetTags: UNSET TAG dotIdentifier (COMMA dotIdentifier)* + ; + +// create commands +createCommand + : createAccount + | createAlert + | createApiIntegration + | createObjectClone + | createConnection + | createDatabase + | createDynamicTable + | createEventTable + | createExternalFunction + | createExternalTable + | createFailoverGroup + | createFileFormat + | createFunction + //| createIntegration + | createManagedAccount + | createMaskingPolicy + | createMaterializedView + | createNetworkPolicy + | createNotificationIntegration + | createPipe + | createProcedure + | createReplicationGroup + | createResourceMonitor + | createRole + | createRowAccessPolicy + | createSchema + | createSecurityIntegrationExternalOauth + | createSecurityIntegrationSnowflakeOauth + | createSecurityIntegrationSaml2 + | createSecurityIntegrationScim + | createSequence + | createSessionPolicy + | createShare + | createStage + | createStorageIntegration + | createStream + | createTable + | createTableAsSelect + | createTableLike + // | create_|AlterTable_…Constraint + | createTag + | createTask + | createUser + | createView + | createWarehouse + ; + +createAccount + : CREATE ACCOUNT id ADMIN_NAME EQ id ADMIN_PASSWORD EQ string (FIRST_NAME EQ id)? ( + LAST_NAME EQ id + )? EMAIL EQ string (MUST_CHANGE_PASSWORD EQ trueFalse)? EDITION EQ ( + STANDARD + | ENTERPRISE + | BUSINESS_CRITICAL + ) (REGION_GROUP EQ id)? (REGION EQ id)? (COMMENT EQ string)? + ; + +createAlert + : CREATE (OR REPLACE)? ALERT (IF NOT EXISTS)? id WAREHOUSE EQ id SCHEDULE EQ string IF LPAREN EXISTS LPAREN alertCondition RPAREN RPAREN THEN + alertAction + ; + +alertCondition: selectStatement | showCommand | call + ; + +alertAction: sqlClauses + ; + +createApiIntegration + : CREATE (OR REPLACE)? API INTEGRATION (IF NOT EXISTS)? id API_PROVIDER EQ (id) API_AWS_ROLE_ARN EQ string ( + API_KEY EQ string + )? API_ALLOWED_PREFIXES EQ LPAREN string RPAREN (API_BLOCKED_PREFIXES EQ LPAREN string RPAREN)? ENABLED EQ trueFalse ( + COMMENT EQ string + )? + | CREATE (OR REPLACE)? API INTEGRATION (IF NOT EXISTS)? id API_PROVIDER EQ id AZURE_TENANT_ID EQ string AZURE_AD_APPLICATION_ID EQ string ( + API_KEY EQ string + )? API_ALLOWED_PREFIXES EQ LPAREN string RPAREN (API_BLOCKED_PREFIXES EQ LPAREN string RPAREN)? ENABLED EQ trueFalse ( + COMMENT EQ string + )? + | CREATE (OR REPLACE) API INTEGRATION (IF NOT EXISTS) id API_PROVIDER EQ id GOOGLE_AUDIENCE EQ string API_ALLOWED_PREFIXES EQ LPAREN string RPAREN + ( + API_BLOCKED_PREFIXES EQ LPAREN string RPAREN + )? ENABLED EQ trueFalse (COMMENT EQ string)? + ; + +createObjectClone + : CREATE (OR REPLACE)? (DATABASE | SCHEMA | TABLE) (IF NOT EXISTS)? id CLONE dotIdentifier ( + atBefore1 LPAREN (TIMESTAMP ASSOC string | OFFSET ASSOC string | STATEMENT ASSOC id) RPAREN + )? 
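+    // e.g. (illustrative) CREATE TABLE t1_dev CLONE db1.sch1.t1 AT (TIMESTAMP => '2024-01-01 00:00:00')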
+ | CREATE (OR REPLACE)? (STAGE | FILE FORMAT | SEQUENCE | STREAM | TASK) (IF NOT EXISTS)? dotIdentifier CLONE dotIdentifier + ; + +createConnection + : CREATE CONNECTION (IF NOT EXISTS)? id ( + (COMMENT EQ string)? + | (AS REPLICA OF id DOT id DOT id (COMMENT EQ string)?) + ) + ; + +createDatabase + : CREATE (OR REPLACE)? TRANSIENT? DATABASE (IF NOT EXISTS)? id cloneAtBefore? ( + DATA_RETENTION_TIME_IN_DAYS EQ INT + )? (MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT)? defaultDdlCollation? withTags? (COMMENT EQ string)? + ; + +cloneAtBefore + : CLONE id ( + atBefore1 LPAREN (TIMESTAMP ASSOC string | OFFSET ASSOC string | STATEMENT ASSOC id) RPAREN + )? + ; + +atBefore1: AT_KEYWORD | BEFORE + ; + +headerDecl: string EQ string + ; + +compressionType: NONE | GZIP | DEFLATE | AUTO + ; + +compression: COMPRESSION EQ compressionType + ; + +createDynamicTable + : CREATE (OR REPLACE)? DYNAMIC TABLE id TARGET_LAG EQ (string | DOWNSTREAM) WAREHOUSE EQ wh = id AS queryStatement + ; + +createEventTable + : CREATE (OR REPLACE)? EVENT TABLE (IF NOT EXISTS)? id clusterBy? ( + DATA_RETENTION_TIME_IN_DAYS EQ INT + )? (MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT)? changeTracking? (DEFAULT_DDL_COLLATION_ EQ string)? copyGrants? withRowAccessPolicy? withTags? ( + WITH? (COMMENT EQ string) + )? + ; + +createExternalFunction + : CREATE (OR REPLACE)? SECURE? EXTERNAL FUNCTION dotIdentifier LPAREN ( + id dataType (COMMA id dataType)* + )? RPAREN RETURNS dataType (NOT? NULL)? ( + ( CALLED ON NULL INPUT) + | ((RETURNS NULL ON NULL INPUT) | STRICT) + )? (VOLATILE | IMMUTABLE)? (COMMENT EQ string)? API_INTEGRATION EQ id ( + HEADERS EQ LPAREN headerDecl (COMMA headerDecl)* RPAREN + )? (CONTEXT_HEADERS EQ LPAREN id (COMMA id)* RPAREN)? (MAX_BATCH_ROWS EQ INT)? compression? ( + REQUEST_TRANSLATOR EQ id + )? (RESPONSE_TRANSLATOR EQ id)? AS string + ; + +createExternalTable + // Partitions computed from expressions + : CREATE (OR REPLACE)? EXTERNAL TABLE (IF NOT EXISTS)? dotIdentifier LPAREN externalTableColumnDeclList RPAREN cloudProviderParams3? partitionBy? + WITH? LOCATION EQ namedStage (REFRESH_ON_CREATE EQ trueFalse)? (AUTO_REFRESH EQ trueFalse)? pattern? fileFormat ( + AWS_SNS_TOPIC EQ string + )? copyGrants? withRowAccessPolicy? withTags? (COMMENT EQ string)? + // Partitions added and removed manually + | CREATE (OR REPLACE)? EXTERNAL TABLE (IF NOT EXISTS)? dotIdentifier LPAREN externalTableColumnDeclList RPAREN cloudProviderParams3? partitionBy? + WITH? LOCATION EQ namedStage PARTITION_TYPE EQ USER_SPECIFIED fileFormat copyGrants? withRowAccessPolicy? withTags? ( + COMMENT EQ string + )? + // Delta Lake + | CREATE (OR REPLACE)? EXTERNAL TABLE (IF NOT EXISTS)? dotIdentifier LPAREN externalTableColumnDeclList RPAREN cloudProviderParams3? partitionBy? + WITH? LOCATION EQ namedStage PARTITION_TYPE EQ USER_SPECIFIED fileFormat ( + TABLE_FORMAT EQ DELTA + )? copyGrants? withRowAccessPolicy? withTags? (COMMENT EQ string)? + ; + +externalTableColumnDecl: columnName dataType AS (expr | id) inlineConstraint? + ; + +externalTableColumnDeclList: externalTableColumnDecl (COMMA externalTableColumnDecl)* + ; + +fullAcct: id DOT id + ; + +integrationTypeName: SECURITY INTEGRATIONS | API INTEGRATIONS + ; + +createFailoverGroup + : CREATE FAILOVER GROUP (IF NOT EXISTS)? id OBJECT_TYPES EQ objectType (COMMA objectType)* ( + ALLOWED_DATABASES EQ id (COMMA id)* + )? (ALLOWED_SHARES EQ id (COMMA id)*)? ( + ALLOWED_INTEGRATION_TYPES EQ integrationTypeName (COMMA integrationTypeName)* + )? 
ALLOWED_ACCOUNTS EQ fullAcct (COMMA fullAcct)* (IGNORE EDITION CHECK)? ( + REPLICATION_SCHEDULE EQ string + )? + // Secondary Replication Group + | CREATE FAILOVER GROUP (IF NOT EXISTS)? id AS REPLICA OF id DOT id DOT id + ; + +typeFileformat: CSV | JSON | AVRO | ORC | PARQUET | XML | string + ; + +createFileFormat + : CREATE (OR REPLACE)? FILE FORMAT (IF NOT EXISTS)? dotIdentifier (TYPE EQ typeFileformat)? formatTypeOptions* ( + COMMENT EQ string + )? + ; + +argDecl: id dataType (DEFAULT expr)? + ; + +colDecl: columnName dataType? virtualColumnDecl? + ; + +virtualColumnDecl: AS LPAREN functionCall RPAREN + ; + +functionDefinition: string + ; + +// TODO: merge these rules to avoid massive lookahead +createFunction + : CREATE (OR REPLACE)? SECURE? FUNCTION (IF NOT EXISTS)? dotIdentifier LPAREN ( + argDecl (COMMA argDecl)* + )? RPAREN RETURNS (dataType | TABLE LPAREN (colDecl (COMMA colDecl)*)? RPAREN) (LANGUAGE id)? ( + CALLED ON NULL INPUT + | RETURNS NULL ON NULL INPUT + | STRICT + )? (VOLATILE | IMMUTABLE)? (PACKAGES EQ LPAREN stringList RPAREN)? ( + RUNTIME_VERSION EQ (string | FLOAT) + )? (IMPORTS EQ LPAREN stringList RPAREN)? (PACKAGES EQ LPAREN stringList RPAREN)? ( + HANDLER EQ string + )? (NOT? NULL)? (COMMENT EQ com = string)? AS functionDefinition + | CREATE (OR REPLACE)? SECURE? FUNCTION dotIdentifier LPAREN (argDecl (COMMA argDecl)*)? RPAREN RETURNS ( + dataType + | TABLE LPAREN (colDecl (COMMA colDecl)*)? RPAREN + ) (NOT? NULL)? (CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT)? ( + VOLATILE + | IMMUTABLE + )? MEMOIZABLE? (COMMENT EQ com = string)? AS functionDefinition + ; + +createManagedAccount + : CREATE MANAGED ACCOUNT id ADMIN_NAME EQ id COMMA ADMIN_PASSWORD EQ string COMMA TYPE EQ READER ( + COMMA (COMMENT EQ string) + )? + ; + +createMaskingPolicy + : CREATE (OR REPLACE)? MASKING POLICY (IF NOT EXISTS)? dotIdentifier AS LPAREN id dataType ( + COMMA id dataType + )? RPAREN RETURNS dataType ARROW expr (COMMENT EQ string)? + ; + +tagDecl: dotIdentifier EQ string + ; + +columnListInParentheses: LPAREN columnList RPAREN + ; + +createMaterializedView + : CREATE (OR REPLACE)? SECURE? MATERIALIZED VIEW (IF NOT EXISTS)? dotIdentifier ( + LPAREN columnListWithComment RPAREN + )? viewCol* withRowAccessPolicy? withTags? copyGrants? (COMMENT EQ string)? clusterBy? AS selectStatement + //NOTA MATERIALIZED VIEW accept only simple select statement at this time + ; + +createNetworkPolicy + : CREATE (OR REPLACE)? NETWORK POLICY id ALLOWED_IP_LIST EQ LPAREN stringList? RPAREN ( + BLOCKED_IP_LIST EQ LPAREN stringList? RPAREN + )? (COMMENT EQ string)? + ; + +cloudProviderParamsAuto + //(for Google Cloud Storage) + : NOTIFICATION_PROVIDER EQ GCP_PUBSUB GCP_PUBSUB_SUBSCRIPTION_NAME EQ string + //(for Microsoft Azure Storage) + | NOTIFICATION_PROVIDER EQ AZURE_EVENT_GRID AZURE_STORAGE_QUEUE_PRIMARY_URI EQ string AZURE_TENANT_ID EQ string + ; + +cloudProviderParamsPush + //(for Amazon SNS) + : NOTIFICATION_PROVIDER EQ AWS_SNS AWS_SNS_TOPIC_ARN EQ string AWS_SNS_ROLE_ARN EQ string + //(for Google Pub/Sub) + | NOTIFICATION_PROVIDER EQ GCP_PUBSUB GCP_PUBSUB_TOPIC_NAME EQ string + //(for Microsoft Azure Event Grid) + | NOTIFICATION_PROVIDER EQ AZURE_EVENT_GRID AZURE_EVENT_GRID_TOPIC_ENDPOINT EQ string AZURE_TENANT_ID EQ string + ; + +createNotificationIntegration + : CREATE (OR REPLACE)? NOTIFICATION INTEGRATION (IF NOT EXISTS)? id ENABLED EQ trueFalse TYPE EQ QUEUE cloudProviderParamsAuto ( + COMMENT EQ string + )? + | CREATE (OR REPLACE)? NOTIFICATION INTEGRATION (IF NOT EXISTS)? 
id ENABLED EQ trueFalse DIRECTION EQ OUTBOUND TYPE EQ QUEUE + cloudProviderParamsPush (COMMENT EQ string)? + ; + +createPipe + : CREATE (OR REPLACE)? PIPE (IF NOT EXISTS)? dotIdentifier (AUTO_INGEST EQ trueFalse)? ( + ERROR_INTEGRATION EQ id + )? (AWS_SNS_TOPIC EQ string)? (INTEGRATION EQ string)? (COMMENT EQ string)? AS copyIntoTable + ; + +executeAs: EXECUTE AS (CALLER | OWNER) + ; + +table: TABLE (LPAREN (colDecl (COMMA colDecl)*)? RPAREN) | (functionCall) + ; + +createReplicationGroup + : CREATE REPLICATION GROUP (IF NOT EXISTS)? id OBJECT_TYPES EQ objectType (COMMA objectType)* ( + ALLOWED_DATABASES EQ id (COMMA id)* + )? (ALLOWED_SHARES EQ id (COMMA id)*)? ( + ALLOWED_INTEGRATION_TYPES EQ integrationTypeName (COMMA integrationTypeName)* + )? ALLOWED_ACCOUNTS EQ fullAcct (COMMA fullAcct)* (IGNORE EDITION CHECK)? ( + REPLICATION_SCHEDULE EQ string + )? + //Secondary Replication Group + | CREATE REPLICATION GROUP (IF NOT EXISTS)? id AS REPLICA OF id DOT id DOT id + ; + +createResourceMonitor + : CREATE (OR REPLACE)? RESOURCE MONITOR id WITH creditQuota? frequency? ( + START_TIMESTAMP EQ ( string | IMMEDIATELY) + )? (END_TIMESTAMP EQ string)? notifyUsers? (TRIGGERS triggerDefinition+)? + ; + +createRole: CREATE (OR REPLACE)? ROLE (IF NOT EXISTS)? id withTags? (COMMENT EQ string)? + ; + +createRowAccessPolicy + : CREATE (OR REPLACE)? ROW ACCESS POLICY (IF NOT EXISTS)? id AS LPAREN argDecl (COMMA argDecl)* RPAREN RETURNS id /* BOOLEAN */ ARROW expr ( + COMMENT EQ string + )? + ; + +createSchema + : CREATE (OR REPLACE)? TRANSIENT? SCHEMA (IF NOT EXISTS)? schemaName cloneAtBefore? ( + WITH MANAGED ACCESS + )? (DATA_RETENTION_TIME_IN_DAYS EQ INT)? (MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT)? defaultDdlCollation? withTags? ( + COMMENT EQ string + )? + ; + +createSecurityIntegrationExternalOauth + : CREATE (OR REPLACE)? SECURITY INTEGRATION (IF NOT EXISTS)? id TYPE EQ EXTERNAL_OAUTH ENABLED EQ trueFalse EXTERNAL_OAUTH_TYPE EQ ( + OKTA + | id + | PING_FEDERATE + | CUSTOM + ) EXTERNAL_OAUTH_ISSUER EQ string EXTERNAL_OAUTH_TOKEN_USER_MAPPING_CLAIM EQ ( + string + | LPAREN stringList RPAREN + ) EXTERNAL_OAUTH_SNOWFLAKE_USER_MAPPING_ATTRIBUTE EQ string ( + EXTERNAL_OAUTH_JWS_KEYS_URL EQ string + )? // For OKTA | PING_FEDERATE | CUSTOM + (EXTERNAL_OAUTH_JWS_KEYS_URL EQ (string | LPAREN stringList RPAREN))? // For Azure + (EXTERNAL_OAUTH_BLOCKED_ROLES_LIST EQ LPAREN stringList RPAREN)? ( + EXTERNAL_OAUTH_ALLOWED_ROLES_LIST EQ LPAREN stringList RPAREN + )? (EXTERNAL_OAUTH_RSA_PUBLIC_KEY EQ string)? (EXTERNAL_OAUTH_RSA_PUBLIC_KEY_2 EQ string)? ( + EXTERNAL_OAUTH_AUDIENCE_LIST EQ LPAREN string RPAREN + )? (EXTERNAL_OAUTH_ANY_ROLE_MODE EQ (DISABLE | ENABLE | ENABLE_FOR_PRIVILEGE))? ( + EXTERNAL_OAUTH_SCOPE_DELIMITER EQ string + )? // Only for EXTERNAL_OAUTH_TYPE EQ CUSTOM + ; + +implicitNone: IMPLICIT | NONE + ; + +createSecurityIntegrationSnowflakeOauth + : CREATE (OR REPLACE)? SECURITY INTEGRATION (IF NOT EXISTS)? id TYPE EQ OAUTH OAUTH_CLIENT EQ partnerApplication OAUTH_REDIRECT_URI EQ string + //Required when OAUTH_CLIENTEQLOOKER + enabledTrueFalse? (OAUTH_ISSUE_REFRESH_TOKENS EQ trueFalse)? ( + OAUTH_REFRESH_TOKEN_VALIDITY EQ INT + )? (OAUTH_USE_SECONDARY_ROLES EQ implicitNone)? ( + BLOCKED_ROLES_LIST EQ LPAREN stringList RPAREN + )? (COMMENT EQ string)? + // Snowflake OAuth for custom clients + | CREATE (OR REPLACE)? SECURITY INTEGRATION (IF NOT EXISTS)? id TYPE EQ OAUTH OAUTH_CLIENT EQ CUSTOM + //OAUTH_CLIENT_TYPE EQ 'CONFIDENTIAL' | 'PUBLIC' + OAUTH_REDIRECT_URI EQ string enabledTrueFalse? 
(OAUTH_ALLOW_NON_TLS_REDIRECT_URI EQ trueFalse)? ( + OAUTH_ENFORCE_PKCE EQ trueFalse + )? (OAUTH_USE_SECONDARY_ROLES EQ implicitNone)? ( + PRE_AUTHORIZED_ROLES_LIST EQ LPAREN stringList RPAREN + )? (BLOCKED_ROLES_LIST EQ LPAREN stringList RPAREN)? (OAUTH_ISSUE_REFRESH_TOKENS EQ trueFalse)? ( + OAUTH_REFRESH_TOKEN_VALIDITY EQ INT + )? networkPolicy? (OAUTH_CLIENT_RSA_PUBLIC_KEY EQ string)? ( + OAUTH_CLIENT_RSA_PUBLIC_KEY_2 EQ string + )? (COMMENT EQ string)? + ; + +createSecurityIntegrationSaml2 + : CREATE (OR REPLACE)? SECURITY INTEGRATION (IF NOT EXISTS)? id TYPE EQ SAML2 enabledTrueFalse SAML2_ISSUER EQ string SAML2_SSO_URL EQ string + SAML2_PROVIDER EQ string SAML2_X509_CERT EQ string ( + SAML2_SP_INITIATED_LOGIN_PAGE_LABEL EQ string + )? (SAML2_ENABLE_SP_INITIATED EQ trueFalse)? (SAML2_SNOWFLAKE_X509_CERT EQ string)? ( + SAML2_SIGN_REQUEST EQ trueFalse + )? (SAML2_REQUESTED_NAMEID_FORMAT EQ string)? (SAML2_POST_LOGOUT_REDIRECT_URL EQ string)? ( + SAML2_FORCE_AUTHN EQ trueFalse + )? (SAML2_SNOWFLAKE_ISSUER_URL EQ string)? (SAML2_SNOWFLAKE_ACS_URL EQ string)? + ; + +createSecurityIntegrationScim + : CREATE (OR REPLACE)? SECURITY INTEGRATION (IF NOT EXISTS)? id TYPE EQ SCIM SCIM_CLIENT EQ string RUN_AS_ROLE EQ string networkPolicy? ( + SYNC_PASSWORD EQ trueFalse + )? (COMMENT EQ string)? + ; + +networkPolicy: NETWORK_POLICY EQ string + ; + +partnerApplication: TABLEAU_DESKTOP | TABLEAU_SERVER | LOOKER + ; + +startWith: START WITH? EQ? INT + ; + +incrementBy: INCREMENT BY? EQ? INT + ; + +createSequence + : CREATE (OR REPLACE)? SEQUENCE (IF NOT EXISTS)? dotIdentifier WITH? startWith? incrementBy? orderNoorder? ( + COMMENT EQ string + )? + ; + +createSessionPolicy + : CREATE (OR REPLACE)? SESSION POLICY (IF NOT EXISTS)? id (SESSION_IDLE_TIMEOUT_MINS EQ INT)? ( + SESSION_UI_IDLE_TIMEOUT_MINS EQ INT + )? (COMMENT EQ string)? + ; + +createShare: CREATE (OR REPLACE)? SHARE id (COMMENT EQ string)?
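+    // e.g. (illustrative) CREATE OR REPLACE SHARE sales_share COMMENT = 'outbound share for the sales schema'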
+ ; + +formatTypeOptions + //-- If TYPE EQ CSV + : COMPRESSION EQ (AUTO | GZIP | BZ2 | BROTLI | ZSTD | DEFLATE | RAW_DEFLATE | NONE | string) + | RECORD_DELIMITER EQ ( string | NONE) + | FIELD_DELIMITER EQ ( string | NONE) + | FILE_EXTENSION EQ string + | SKIP_HEADER EQ INT + | SKIP_BLANK_LINES EQ trueFalse + | DATE_FORMAT EQ (string | AUTO) + | TIME_FORMAT EQ (string | AUTO) + | TIMESTAMP_FORMAT EQ (string | AUTO) + | BINARY_FORMAT EQ id + | ESCAPE EQ (string | NONE) + | ESCAPE_UNENCLOSED_FIELD EQ (string | NONE) + | TRIM_SPACE EQ trueFalse + | FIELD_OPTIONALLY_ENCLOSED_BY EQ (string | NONE) + | NULL_IF EQ LPAREN stringList RPAREN + | ERROR_ON_COLUMN_COUNT_MISMATCH EQ trueFalse + | REPLACE_INVALID_CHARACTERS EQ trueFalse + | EMPTY_FIELD_AS_NULL EQ trueFalse + | SKIP_BYTE_ORDER_MARK EQ trueFalse + | ENCODING EQ (string | id) //by the way other encoding keyword are valid ie WINDOWS1252 + //-- If TYPE EQ JSON + //| COMPRESSION EQ (AUTO | GZIP | BZ2 | BROTLI | ZSTD | DEFLATE | RAW_DEFLATE | NONE) + // | DATE_FORMAT EQ string | AUTO + // | TIME_FORMAT EQ string | AUTO + // | TIMESTAMP_FORMAT EQ string | AUTO + // | BINARY_FORMAT id + // | TRIM_SPACE EQ trueFalse + // | NULL_IF EQ LR_BRACKET stringList RR_BRACKET + // | FILE_EXTENSION EQ string + | ENABLE_OCTAL EQ trueFalse + | ALLOW_DUPLICATE EQ trueFalse + | STRIP_OUTER_ARRAY EQ trueFalse + | STRIP_NULL_VALUES EQ trueFalse + // | REPLACE_INVALID_CHARACTERS EQ trueFalse + | IGNORE_UTF8_ERRORS EQ trueFalse + // | SKIP_BYTE_ORDER_MARK EQ trueFalse + //-- If TYPE EQ AVRO + // | COMPRESSION EQ AUTO | GZIP | BROTLI | ZSTD | DEFLATE | RAW_DEFLATE | NONE + // | TRIM_SPACE EQ trueFalse + // | NULL_IF EQ LR_BRACKET stringList RR_BRACKET + //-- If TYPE EQ ORC + // | TRIM_SPACE EQ trueFalse + // | NULL_IF EQ LR_BRACKET stringList RR_BRACKET + //-- If TYPE EQ PARQUET + | COMPRESSION EQ AUTO + | LZO + | SNAPPY + | NONE + | SNAPPY_COMPRESSION EQ trueFalse + | BINARY_AS_TEXT EQ trueFalse + // | TRIM_SPACE EQ trueFalse + // | NULL_IF EQ LR_BRACKET stringList RR_BRACKET + //-- If TYPE EQ XML + | COMPRESSION EQ AUTO + | GZIP + | BZ2 + | BROTLI + | ZSTD + | DEFLATE + | RAW_DEFLATE + | NONE + // | IGNORE_UTF8_ERRORS EQ trueFalse + | PRESERVE_SPACE EQ trueFalse + | STRIP_OUTER_ELEMENT EQ trueFalse + | DISABLE_SNOWFLAKE_DATA EQ trueFalse + | DISABLE_AUTO_CONVERT EQ trueFalse + // | SKIP_BYTE_ORDER_MARK EQ trueFalse + ; + +copyOptions + : ON_ERROR EQ (CONTINUE | SKIP_FILE | SKIP_FILE_N | SKIP_FILE_N ABORT_STATEMENT) + | SIZE_LIMIT EQ INT + | PURGE EQ trueFalse + | RETURN_FAILED_ONLY EQ trueFalse + | MATCH_BY_COLUMN_NAME EQ CASE_SENSITIVE + | CASE_INSENSITIVE + | NONE + | ENFORCE_LENGTH EQ trueFalse + | TRUNCATECOLUMNS EQ trueFalse + | FORCE EQ trueFalse + ; + +stageEncryptionOptsInternal: ENCRYPTION EQ LPAREN TYPE EQ (SNOWFLAKE_FULL | SNOWFLAKE_SSE) RPAREN + ; + +storageIntegrationEqId: STORAGE_INTEGRATION EQ id + ; + +storageCredentials: CREDENTIALS EQ parenStringOptions + ; + +storageEncryption: ENCRYPTION EQ parenStringOptions + ; + +parenStringOptions: LPAREN stringOption* RPAREN + ; + +stringOption: id EQ string + ; + +externalStageParams: URL EQ string storageIntegrationEqId? storageCredentials? storageEncryption? + ; + +trueFalse: TRUE | FALSE + ; + +enable: ENABLE EQ trueFalse + ; + +refreshOnCreate: REFRESH_ON_CREATE EQ trueFalse + ; + +autoRefresh: AUTO_REFRESH EQ trueFalse + ; + +notificationIntegration: NOTIFICATION_INTEGRATION EQ string + ; + +directoryTableInternalParams + : DIRECTORY EQ LPAREN ( + enable refreshOnCreate? 
+ | REFRESH_ON_CREATE EQ FALSE + | refreshOnCreate enable + ) RPAREN + ; + +directoryTableExternalParams + // (for Amazon S3) + : DIRECTORY EQ LPAREN enable refreshOnCreate? autoRefresh? RPAREN + // (for Google Cloud Storage) + | DIRECTORY EQ LPAREN enable autoRefresh? refreshOnCreate? notificationIntegration? RPAREN + // (for Microsoft Azure) + | DIRECTORY EQ LPAREN enable refreshOnCreate? autoRefresh? notificationIntegration? RPAREN + ; + +/* =========== Stage DDL section =========== */ +createStage + : CREATE (OR REPLACE)? temporary? STAGE (IF NOT EXISTS)? dotIdentifierOrIdent stageEncryptionOptsInternal? directoryTableInternalParams? ( + FILE_FORMAT EQ LPAREN (FORMAT_NAME EQ string | TYPE EQ typeFileformat formatTypeOptions*) RPAREN + )? (COPY_OPTIONS_ EQ LPAREN copyOptions RPAREN)? withTags? (COMMENT EQ string)? + | CREATE (OR REPLACE)? temporary? STAGE (IF NOT EXISTS)? dotIdentifierOrIdent externalStageParams directoryTableExternalParams? ( + FILE_FORMAT EQ LPAREN (FORMAT_NAME EQ string | TYPE EQ typeFileformat formatTypeOptions*) RPAREN + )? (COPY_OPTIONS_ EQ LPAREN copyOptions RPAREN)? withTags? (COMMENT EQ string)? + ; + +alterStage + : ALTER STAGE (IF EXISTS)? dotIdentifierOrIdent RENAME TO dotIdentifierOrIdent + | ALTER STAGE (IF EXISTS)? dotIdentifierOrIdent setTags + | ALTER STAGE (IF EXISTS)? dotIdentifierOrIdent unsetTags + | ALTER STAGE (IF EXISTS)? dotIdentifierOrIdent SET externalStageParams? fileFormat? ( + COPY_OPTIONS_ EQ LPAREN copyOptions RPAREN + )? (COMMENT EQ string)? + ; + +dropStage: DROP STAGE (IF EXISTS)? dotIdentifierOrIdent + ; + +describeStage: DESCRIBE STAGE dotIdentifierOrIdent + ; + +showStages: SHOW STAGES likePattern? inObj? + ; + +/* =========== End of stage DDL section =========== */ + +cloudProviderParams + : STORAGE_PROVIDER EQ string ( + AZURE_TENANT_ID EQ (string | ID) + | STORAGE_AWS_ROLE_ARN EQ string (STORAGE_AWS_OBJECT_ACL EQ string)? + )? + ; + +cloudProviderParams2 + //(for Amazon S3) + : STORAGE_AWS_ROLE_ARN EQ string (STORAGE_AWS_OBJECT_ACL EQ string)? + //(for Microsoft Azure) + | AZURE_TENANT_ID EQ string + ; + +cloudProviderParams3: INTEGRATION EQ string + ; + +createStorageIntegration + : CREATE (OR REPLACE)? STORAGE INTEGRATION (IF NOT EXISTS)? id TYPE EQ EXTERNAL_STAGE cloudProviderParams ENABLED EQ trueFalse + STORAGE_ALLOWED_LOCATIONS EQ LPAREN stringList RPAREN ( + STORAGE_BLOCKED_LOCATIONS EQ LPAREN stringList RPAREN + )? (COMMENT EQ string)? + ; + +copyGrants: COPY GRANTS + ; + +appendOnly: APPEND_ONLY EQ trueFalse + ; + +insertOnly: INSERT_ONLY EQ TRUE + ; + +showInitialRows: SHOW_INITIAL_ROWS EQ trueFalse + ; + +streamTime + : atBefore1 LPAREN ( + TIMESTAMP ASSOC string + | OFFSET ASSOC string + | STATEMENT ASSOC id + | STREAM ASSOC string + ) RPAREN + ; + +createStream + //-- table + : CREATE (OR REPLACE)? STREAM (IF NOT EXISTS)? dotIdentifier copyGrants? ON TABLE dotIdentifier streamTime? appendOnly? showInitialRows? ( + COMMENT EQ string + )? + //-- External table + | CREATE (OR REPLACE)? STREAM (IF NOT EXISTS)? dotIdentifier copyGrants? ON EXTERNAL TABLE dotIdentifier streamTime? insertOnly? ( + COMMENT EQ string + )? + //-- Directory table + | CREATE (OR REPLACE)? STREAM (IF NOT EXISTS)? dotIdentifier copyGrants? ON STAGE dotIdentifier ( + COMMENT EQ string + )? + //-- View + | CREATE (OR REPLACE)? STREAM (IF NOT EXISTS)? dotIdentifier copyGrants? ON VIEW dotIdentifier streamTime? appendOnly? showInitialRows? ( + COMMENT EQ string + )? + ; + +temporary: TEMP | TEMPORARY + ; + +tableType: (( LOCAL | GLOBAL)? 
temporary | VOLATILE) | TRANSIENT + ; + +withTags: WITH? TAG LPAREN tagDecl (COMMA tagDecl)* RPAREN + ; + +withRowAccessPolicy: WITH? ROW ACCESS POLICY id ON LPAREN columnName (COMMA columnName)* RPAREN + ; + +clusterBy: CLUSTER BY LINEAR? exprListInParentheses + ; + +changeTracking: CHANGE_TRACKING EQ trueFalse + ; + +withMaskingPolicy: WITH? MASKING POLICY id (USING columnListInParentheses)? + ; + +collate: COLLATE string + ; + +orderNoorder: ORDER | NOORDER + ; + +defaultValue + : DEFAULT expr + | (AUTOINCREMENT | IDENTITY) ( + LPAREN INT COMMA INT RPAREN + | startWith + | incrementBy + | startWith incrementBy + )? orderNoorder? + ; + +foreignKey: FOREIGN KEY + ; + +primaryKey: PRIMARY KEY + ; + +outOfLineConstraint + : (CONSTRAINT id)? ( + (UNIQUE | primaryKey) columnListInParentheses commonConstraintProperties* + | foreignKey columnListInParentheses REFERENCES dotIdentifier columnListInParentheses constraintProperties + ) + ; + +// TODO: Fix fullColDecl as defaultvalue being NULL and nullable are ambiguous - works with current visitor thoguh for now +fullColDecl + : colDecl (collate | inlineConstraint | (NOT? NULL) | (defaultValue | NULL))* withMaskingPolicy? withTags? ( + COMMENT string + )? + ; + +columnDeclItem: fullColDecl | outOfLineConstraint + ; + +columnDeclItemList: columnDeclItem (COMMA columnDeclItem)* + ; + +createTable + : CREATE (OR REPLACE)? tableType? TABLE ( + (IF NOT EXISTS)? dotIdentifier + | dotIdentifier (IF NOT EXISTS)? + ) (((COMMENT EQ string)? createTableClause) | (createTableClause (COMMENT EQ string)?)) + ; + +columnDeclItemListParen: LPAREN columnDeclItemList RPAREN + ; + +createTableClause + : ( + columnDeclItemListParen clusterBy? + | clusterBy? (COMMENT EQ string)? columnDeclItemListParen + ) stageFileFormat? (STAGE_COPY_OPTIONS EQ LPAREN copyOptions RPAREN)? ( + DATA_RETENTION_TIME_IN_DAYS EQ INT + )? (MAX_DATA_EXTENSION_TIME_IN_DAYS EQ INT)? changeTracking? defaultDdlCollation? copyGrants? ( + COMMENT EQ string + )? withRowAccessPolicy? withTags? + ; + +createTableAsSelect + : CREATE (OR REPLACE)? tableType? TABLE ( + (IF NOT EXISTS)? dotIdentifier + | dotIdentifier (IF NOT EXISTS)? + ) (LPAREN columnDeclItemList RPAREN)? clusterBy? copyGrants? withRowAccessPolicy? withTags? ( + COMMENT EQ string + )? AS LPAREN? queryStatement RPAREN? + ; + +createTableLike + : CREATE (OR REPLACE)? TRANSIENT? TABLE (IF NOT EXISTS)? dotIdentifier LIKE dotIdentifier clusterBy? copyGrants? + ; + +createTag + : CREATE (OR REPLACE)? TAG (IF NOT EXISTS)? dotIdentifier tagAllowedValues? (COMMENT EQ string)? 
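+    // e.g. (illustrative) CREATE TAG IF NOT EXISTS db1.sch1.cost_center ALLOWED_VALUES 'finance', 'engineering' COMMENT = 'chargeback tag'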
+ ; + +tagAllowedValues: ALLOWED_VALUES stringList + ; + +sessionParameter + : ABORT_DETACHED_QUERY + | ALLOW_CLIENT_MFA_CACHING + | ALLOW_ID_TOKEN + | AUTOCOMMIT + | AUTOCOMMIT_API_SUPPORTED + | BINARY_INPUT_FORMAT + | BINARY_OUTPUT_FORMAT + | CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS + | CLIENT_ENCRYPTION_KEY_SIZE + | CLIENT_MEMORY_LIMIT + | CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX + | CLIENT_METADATA_USE_SESSION_DATABASE + | CLIENT_PREFETCH_THREADS + | CLIENT_RESULT_CHUNK_SIZE + | CLIENT_RESULT_COLUMN_CASE_INSENSITIVE + | CLIENT_SESSION_KEEP_ALIVE + | CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY + | CLIENT_TIMESTAMP_TYPE_MAPPING + | DATA_RETENTION_TIME_IN_DAYS + | DATE_INPUT_FORMAT + | DATE_OUTPUT_FORMAT + | DEFAULT_DDL_COLLATION_ + | ENABLE_INTERNAL_STAGES_PRIVATELINK + | ENABLE_UNLOAD_PHYSICAL_TYPE_OPTIMIZATION + | ENFORCE_SESSION_POLICY + | ERROR_ON_NONDETERMINISTIC_MERGE + | ERROR_ON_NONDETERMINISTIC_UPDATE + | EXTERNAL_OAUTH_ADD_PRIVILEGED_ROLES_TO_BLOCKED_LIST + | GEOGRAPHY_OUTPUT_FORMAT + | GEOMETRY_OUTPUT_FORMAT + | INITIAL_REPLICATION_SIZE_LIMIT_IN_TB + | JDBC_TREAT_DECIMAL_AS_INT + | JDBC_TREAT_TIMESTAMP_NTZ_AS_UTC + | JDBC_USE_SESSION_TIMEZONE + | JSON_INDENT + | JS_TREAT_INTEGER_AS_BIGINT + | LOCK_TIMEOUT + | MAX_CONCURRENCY_LEVEL + | MAX_DATA_EXTENSION_TIME_IN_DAYS + | MULTI_STATEMENT_COUNT + | MIN_DATA_RETENTION_TIME_IN_DAYS + | NETWORK_POLICY + | SHARE_RESTRICTIONS + | PERIODIC_DATA_REKEYING + | PIPE_EXECUTION_PAUSED + | PREVENT_UNLOAD_TO_INLINE_URL + | PREVENT_UNLOAD_TO_INTERNAL_STAGES + | QUERY_TAG + | QUOTED_IDENTIFIERS_IGNORE_CASE + | REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_CREATION + | REQUIRE_STORAGE_INTEGRATION_FOR_STAGE_OPERATION + | ROWS_PER_RESULTSET + | SAML_IDENTITY_PROVIDER + | SIMULATED_DATA_SHARING_CONSUMER + | SSO_LOGIN_PAGE + | STATEMENT_QUEUED_TIMEOUT_IN_SECONDS + | STATEMENT_TIMEOUT_IN_SECONDS + | STRICT_JSON_OUTPUT + | SUSPEND_TASK_AFTER_NUM_FAILURES + | TIMESTAMP_DAY_IS_ALWAYS_24H + | TIMESTAMP_INPUT_FORMAT + | TIMESTAMP_LTZ_OUTPUT_FORMAT + | TIMESTAMP_NTZ_OUTPUT_FORMAT + | TIMESTAMP_OUTPUT_FORMAT + | TIMESTAMP_TYPE_MAPPING + | TIMESTAMP_TZ_OUTPUT_FORMAT + | TIMEZONE + | TIME_INPUT_FORMAT + | TIME_OUTPUT_FORMAT + | TRANSACTION_ABORT_ON_ERROR + | TRANSACTION_DEFAULT_ISOLATION_LEVEL + | TWO_DIGIT_CENTURY_START + | UNSUPPORTED_DDL_ACTION + | USE_CACHED_RESULT + | USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE + | USER_TASK_TIMEOUT_MS + | WEEK_OF_YEAR_POLICY + | WEEK_START + ; + +sessionParameterList: sessionParameter (COMMA sessionParameter)* + ; + +sessionParamsList: sessionParams (COMMA sessionParams)* + ; + +createTask + : CREATE (OR REPLACE)? TASK (IF NOT EXISTS)? dotIdentifier taskParameters* (COMMENT EQ string)? copyGrants? ( + AFTER dotIdentifier (COMMA dotIdentifier)* + )? (WHEN searchCondition)? AS sql + ; + +taskParameters + : taskCompute + | taskSchedule + | taskOverlap + | sessionParamsList + | taskTimeout + | taskSuspendAfterFailureNumber + | taskErrorIntegration + ; + +taskCompute + : WAREHOUSE EQ id + | USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE EQ ( + whCommonSize + | string + ) //Snowflake allow quoted warehouse size but must be without quote. 
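+    // e.g. (illustrative) WAREHOUSE = transform_wh, or USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = XSMALL for serverless tasks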
+ ; + +taskSchedule: SCHEDULE EQ string + ; + +taskTimeout: USER_TASK_TIMEOUT_MS EQ INT + ; + +taskSuspendAfterFailureNumber: SUSPEND_TASK_AFTER_NUM_FAILURES EQ INT + ; + +taskErrorIntegration: ERROR_INTEGRATION EQ id + ; + +taskOverlap: ALLOW_OVERLAPPING_EXECUTION EQ trueFalse + ; + +sql: EXECUTE IMMEDIATE DOLLAR_STRING | sqlClauses | call + ; + +// Snowfllake allows calls to special internal stored procedures, named x.y!entrypoint +call: CALL dotIdentifier (BANG id)? LPAREN exprList? RPAREN + ; + +createUser + : CREATE (OR REPLACE)? USER (IF NOT EXISTS)? id objectProperties? objectParams? sessionParams? + ; + +viewCol: columnName withMaskingPolicy withTags + ; + +createView + : CREATE (OR REPLACE)? SECURE? RECURSIVE? VIEW (IF NOT EXISTS)? dotIdentifier ( + LPAREN columnListWithComment RPAREN + )? viewCol* withRowAccessPolicy? withTags? copyGrants? (COMMENT EQ string)? AS queryStatement + ; + +createWarehouse + : CREATE (OR REPLACE)? WAREHOUSE (IF NOT EXISTS)? idFn (WITH? whProperties+)? whParams* + ; + +whCommonSize: XSMALL | SMALL | MEDIUM | LARGE | XLARGE | XXLARGE + ; + +whExtraSize: XXXLARGE | X4LARGE | X5LARGE | X6LARGE + ; + +whProperties + : WAREHOUSE_SIZE EQ (whCommonSize | whExtraSize | LOCAL_ID) + | WAREHOUSE_TYPE EQ (STANDARD | string) + | MAX_CLUSTER_COUNT EQ INT + | MIN_CLUSTER_COUNT EQ INT + | SCALING_POLICY EQ (STANDARD | ECONOMY) + | AUTO_SUSPEND (EQ INT | NULL) + | AUTO_RESUME EQ trueFalse + | INITIALLY_SUSPENDED EQ trueFalse + | RESOURCE_MONITOR EQ id + | (COMMENT EQ string) + | ENABLE_QUERY_ACCELERATION EQ trueFalse + | QUERY_ACCELERATION_MAX_SCALE_FACTOR EQ INT + | MAX_CONCURRENCY_LEVEL EQ INT + ; + +whParams + : MAX_CONCURRENCY_LEVEL EQ INT + | STATEMENT_QUEUED_TIMEOUT_IN_SECONDS EQ INT + | STATEMENT_TIMEOUT_IN_SECONDS EQ INT withTags? + ; + +objectTypeName + : ROLE + | USER + | WAREHOUSE + | INTEGRATION + | NETWORK POLICY + | SESSION POLICY + | DATABASE + | SCHEMA + | TABLE + | VIEW + | STAGE + | FILE FORMAT + | STREAM + | TASK + | MASKING POLICY + | ROW ACCESS POLICY + | TAG + | PIPE + | FUNCTION + | PROCEDURE + | SEQUENCE + ; + +objectTypePlural + : ROLES + | USERS + | WAREHOUSES + | INTEGRATIONS + | DATABASES + | SCHEMAS + | TABLES + | VIEWS + | STAGES + | STREAMS + | TASKS + | ALERTS + ; + +// drop commands +dropCommand + : dropObject + | dropAlert + | dropConnection + | dropDatabase + | dropDynamicTable + //| dropEventTable //uses DROP TABLE stmt + | dropExternalTable + | dropFailoverGroup + | dropFileFormat + | dropFunction + | dropIntegration + | dropManagedAccount + | dropMaskingPolicy + | dropMaterializedView + | dropNetworkPolicy + | dropPipe + | dropProcedure + | dropReplicationGroup + | dropResourceMonitor + | dropRole + | dropRowAccessPolicy + | dropSchema + | dropSequence + | dropSessionPolicy + | dropShare + | dropStage + | dropStream + | dropTable + | dropTag + | dropTask + | dropUser + | dropView + | dropWarehouse + ; + +dropObject: DROP objectType (IF EXISTS)? id cascadeRestrict? + ; + +dropAlert: DROP ALERT id + ; + +dropConnection: DROP CONNECTION (IF EXISTS)? id + ; + +dropDatabase: DROP DATABASE (IF EXISTS)? id cascadeRestrict? + ; + +dropDynamicTable: DROP DYNAMIC TABLE id + ; + +dropExternalTable: DROP EXTERNAL TABLE (IF EXISTS)? dotIdentifier cascadeRestrict? + ; + +dropFailoverGroup: DROP FAILOVER GROUP (IF EXISTS)? id + ; + +dropFileFormat: DROP FILE FORMAT (IF EXISTS)? id + ; + +dropFunction: DROP FUNCTION (IF EXISTS)? dotIdentifier argTypes + ; + +dropIntegration: DROP (API | NOTIFICATION | SECURITY | STORAGE)? INTEGRATION (IF EXISTS)? 
id + ; + +dropManagedAccount: DROP MANAGED ACCOUNT id + ; + +dropMaskingPolicy: DROP MASKING POLICY id + ; + +dropMaterializedView: DROP MATERIALIZED VIEW (IF EXISTS)? dotIdentifier + ; + +dropNetworkPolicy: DROP NETWORK POLICY (IF EXISTS)? id + ; + +dropPipe: DROP PIPE (IF EXISTS)? dotIdentifier + ; + +dropReplicationGroup: DROP REPLICATION GROUP (IF EXISTS)? id + ; + +dropResourceMonitor: DROP RESOURCE MONITOR id + ; + +dropRole: DROP ROLE (IF EXISTS)? id + ; + +dropRowAccessPolicy: DROP ROW ACCESS POLICY (IF EXISTS)? id + ; + +dropSchema: DROP SCHEMA (IF EXISTS)? schemaName cascadeRestrict? + ; + +dropSequence: DROP SEQUENCE (IF EXISTS)? dotIdentifier cascadeRestrict? + ; + +dropSessionPolicy: DROP SESSION POLICY (IF EXISTS)? id + ; + +dropShare: DROP SHARE id + ; + +dropStream: DROP STREAM (IF EXISTS)? dotIdentifier + ; + +dropTable: DROP TABLE (IF EXISTS)? dotIdentifier cascadeRestrict? + ; + +dropTag: DROP TAG (IF EXISTS)? dotIdentifier + ; + +dropTask: DROP TASK (IF EXISTS)? dotIdentifier + ; + +dropUser: DROP USER (IF EXISTS)? id + ; + +dropView: DROP VIEW (IF EXISTS)? dotIdentifier + ; + +dropWarehouse: DROP WAREHOUSE (IF EXISTS)? idFn + ; + +cascadeRestrict: CASCADE | RESTRICT + ; + +argTypes: LPAREN dataTypeList? RPAREN + ; + +// undrop commands +undropCommand + : undropDatabase + | undropSchema + | undropTable + | undropTag //: undropObject + ; + +undropDatabase: UNDROP DATABASE id + ; + +undropSchema: UNDROP SCHEMA schemaName + ; + +undropTable: UNDROP TABLE dotIdentifier + ; + +undropTag: UNDROP TAG dotIdentifier + ; + +// use commands +useCommand: useDatabase | useRole | useSchema | useSecondaryRoles | useWarehouse + ; + +useDatabase: USE DATABASE id + ; + +useRole: USE ROLE id + ; + +useSchema: USE SCHEMA? (id DOT)? id + ; + +useSecondaryRoles: USE SECONDARY ROLES (ALL | NONE) + ; + +useWarehouse: USE WAREHOUSE idFn + ; + +// describe command +describeCommand + : describeAlert + | describeDatabase + | describeDynamicTable + | describeEventTable + | describeExternalTable + | describeFileFormat + | describeFunction + | describeIntegration + | describeMaskingPolicy + | describeMaterializedView + | describeNetworkPolicy + | describePipe + | describeProcedure + | describeResult + | describeRowAccessPolicy + | describeSchema + | describeSearchOptimization + | describeSequence + | describeSessionPolicy + | describeShare + | describeStage + | describeStream + | describeTable + | describeTask + | describeTransaction + | describeUser + | describeView + | describeWarehouse + ; + +describeAlert: DESCRIBE ALERT id + ; + +describeDatabase: DESCRIBE DATABASE id + ; + +describeDynamicTable: DESCRIBE DYNAMIC TABLE id + ; + +describeEventTable: DESCRIBE EVENT TABLE id + ; + +describeExternalTable: DESCRIBE EXTERNAL? TABLE dotIdentifier (TYPE EQ (COLUMNS | STAGE))? + ; + +describeFileFormat: DESCRIBE FILE FORMAT id + ; + +describeFunction: DESCRIBE FUNCTION dotIdentifier argTypes + ; + +describeIntegration: DESCRIBE (API | NOTIFICATION | SECURITY | STORAGE)? 
INTEGRATION id + ; + +describeMaskingPolicy: DESCRIBE MASKING POLICY id + ; + +describeMaterializedView: DESCRIBE MATERIALIZED VIEW dotIdentifier + ; + +describeNetworkPolicy: DESCRIBE NETWORK POLICY id + ; + +describePipe: DESCRIBE PIPE dotIdentifier + ; + +describeProcedure: DESCRIBE PROCEDURE dotIdentifier argTypes + ; + +describeResult: DESCRIBE RESULT (string | LAST_QUERY_ID LPAREN RPAREN) + ; + +describeRowAccessPolicy: DESCRIBE ROW ACCESS POLICY id + ; + +describeSchema: DESCRIBE SCHEMA schemaName + ; + +describeSearchOptimization: DESCRIBE SEARCH OPTIMIZATION ON dotIdentifier + ; + +describeSequence: DESCRIBE SEQUENCE dotIdentifier + ; + +describeSessionPolicy: DESCRIBE SESSION POLICY id + ; + +describeShare: DESCRIBE SHARE id + ; + +describeStream: DESCRIBE STREAM dotIdentifier + ; + +describeTable: DESCRIBE TABLE dotIdentifier (TYPE EQ (COLUMNS | STAGE))? + ; + +describeTask: DESCRIBE TASK dotIdentifier + ; + +describeTransaction: DESCRIBE TRANSACTION INT + ; + +describeUser: DESCRIBE USER id + ; + +describeView: DESCRIBE VIEW dotIdentifier + ; + +describeWarehouse: DESCRIBE WAREHOUSE id + ; + +// show commands +showCommand + : showAlerts + | showChannels + | showColumns + | showConnections + | showDatabases + | showDatabasesInFailoverGroup + | showDatabasesInReplicationGroup + | showDelegatedAuthorizations + | showDynamicTables + | showEventTables + | showExternalFunctions + | showExternalTables + | showFailoverGroups + | showFileFormats + | showFunctions + | showGlobalAccounts + | showGrants + | showIntegrations + | showLocks + | showManagedAccounts + | showMaskingPolicies + | showMaterializedViews + | showNetworkPolicies + | showObjects + | showOrganizationAccounts + | showParameters + | showPipes + | showPrimaryKeys + | showProcedures + | showRegions + | showReplicationAccounts + | showReplicationDatabases + | showReplicationGroups + | showResourceMonitors + | showRoles + | showRowAccessPolicies + | showSchemas + | showSequences + | showSessionPolicies + | showShares + | showSharesInFailoverGroup + | showSharesInReplicationGroup + | showStages + | showStreams + | showTables + | showTags + | showTasks + | showTransactions + | showUserFunctions + | showUsers + | showVariables + | showViews + | showWarehouses + ; + +showAlerts + : SHOW TERSE? ALERTS likePattern? (IN ( ACCOUNT | DATABASE id? | SCHEMA schemaName?))? startsWith? limitRows? + ; + +showChannels + : SHOW CHANNELS likePattern? ( + IN (ACCOUNT | DATABASE id? | SCHEMA schemaName? | TABLE | TABLE? dotIdentifier) + )? + ; + +showColumns + : SHOW COLUMNS likePattern? ( + IN ( + ACCOUNT + | DATABASE id? + | SCHEMA schemaName? + | TABLE + | TABLE? dotIdentifier + | VIEW + | VIEW? dotIdentifier + ) + )? + ; + +showConnections: SHOW CONNECTIONS likePattern? + ; + +startsWith: STARTS WITH string + ; + +limitRows: LIMIT INT (FROM string)? + ; + +showDatabases: SHOW TERSE? DATABASES HISTORY? likePattern? startsWith? limitRows? + ; + +showDatabasesInFailoverGroup: SHOW DATABASES IN FAILOVER GROUP id + ; + +showDatabasesInReplicationGroup: SHOW DATABASES IN REPLICATION GROUP id + ; + +showDelegatedAuthorizations + : SHOW DELEGATED AUTHORIZATIONS + | SHOW DELEGATED AUTHORIZATIONS BY USER id + | SHOW DELEGATED AUTHORIZATIONS TO SECURITY INTEGRATION id + ; + +showDynamicTables + : SHOW DYNAMIC TABLES likePattern? (IN ( ACCOUNT | DATABASE id? | SCHEMA? schemaName?))? startsWith? limitRows? + ; + +showEventTables + : SHOW TERSE? EVENT TABLES likePattern? (IN ( ACCOUNT | DATABASE id? | SCHEMA? schemaName?))? startsWith? limitRows? 
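+    // Illustrative example of an input this rule matches (names are placeholders):
+    //   SHOW TERSE EVENT TABLES LIKE '%log%' IN DATABASE my_db STARTS WITH 'app' LIMIT 10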
+ ; + +showExternalFunctions: SHOW EXTERNAL FUNCTIONS likePattern? + ; + +showExternalTables + : SHOW TERSE? EXTERNAL TABLES likePattern? (IN ( ACCOUNT | DATABASE id? | SCHEMA? schemaName?))? startsWith? limitRows? + ; + +showFailoverGroups: SHOW FAILOVER GROUPS (IN ACCOUNT id)? + ; + +showFileFormats + : SHOW FILE FORMATS likePattern? ( + IN (ACCOUNT | DATABASE | DATABASE id | SCHEMA | SCHEMA schemaName | schemaName) + )? + ; + +showFunctions + : SHOW FUNCTIONS likePattern? (IN ( ACCOUNT | DATABASE | DATABASE id | SCHEMA | SCHEMA id | id))? + ; + +showGlobalAccounts: SHOW GLOBAL ACCOUNTS likePattern? + ; + +showGrants + : SHOW GRANTS showGrantsOpts? + | SHOW FUTURE GRANTS IN SCHEMA schemaName + | SHOW FUTURE GRANTS IN DATABASE id + ; + +showGrantsOpts + : ON ACCOUNT + | ON objectType dotIdentifier + | TO (ROLE id | USER id | SHARE id) + | OF ROLE id + | OF SHARE id + ; + +showIntegrations: SHOW (API | NOTIFICATION | SECURITY | STORAGE)? INTEGRATIONS likePattern? + ; + +showLocks: SHOW LOCKS (IN ACCOUNT)? + ; + +showManagedAccounts: SHOW MANAGED ACCOUNTS likePattern? + ; + +showMaskingPolicies: SHOW MASKING POLICIES likePattern? inObj? + ; + +inObj: IN (ACCOUNT | DATABASE | DATABASE id | SCHEMA | SCHEMA schemaName | schemaName) + ; + +inObj2: IN (ACCOUNT | DATABASE id? | SCHEMA schemaName? | TABLE | TABLE dotIdentifier) + ; + +showMaterializedViews: SHOW MATERIALIZED VIEWS likePattern? inObj? + ; + +showNetworkPolicies: SHOW NETWORK POLICIES + ; + +showObjects: SHOW OBJECTS likePattern? inObj? + ; + +showOrganizationAccounts: SHOW ORGANIZATION ACCOUNTS likePattern? + ; + +inFor: IN | FOR + ; + +showParameters + : SHOW PARAMETERS likePattern? ( + inFor ( + SESSION + | ACCOUNT + | USER id? + | ( WAREHOUSE | DATABASE | SCHEMA | TASK) id? + | TABLE dotIdentifier + ) + )? + ; + +showPipes: SHOW PIPES likePattern? inObj? + ; + +showPrimaryKeys: SHOW TERSE? PRIMARY KEYS inObj2? + ; + +showProcedures: SHOW PROCEDURES likePattern? inObj? + ; + +showRegions: SHOW REGIONS likePattern? + ; + +showReplicationAccounts: SHOW REPLICATION ACCOUNTS likePattern? + ; + +showReplicationDatabases: SHOW REPLICATION DATABASES likePattern? (WITH PRIMARY id DOT id)? + ; + +showReplicationGroups: SHOW REPLICATION GROUPS (IN ACCOUNT id)? + ; + +showResourceMonitors: SHOW RESOURCE MONITORS likePattern? + ; + +showRoles: SHOW ROLES likePattern? + ; + +showRowAccessPolicies: SHOW ROW ACCESS POLICIES likePattern? inObj? + ; + +showSchemas + : SHOW TERSE? SCHEMAS HISTORY? likePattern? (IN ( ACCOUNT | DATABASE id?))? startsWith? limitRows? + ; + +showSequences: SHOW SEQUENCES likePattern? inObj? + ; + +showSessionPolicies: SHOW SESSION POLICIES + ; + +showShares: SHOW SHARES likePattern? + ; + +showSharesInFailoverGroup: SHOW SHARES IN FAILOVER GROUP id + ; + +showSharesInReplicationGroup: SHOW SHARES IN REPLICATION GROUP id + ; + +showStreams: SHOW STREAMS likePattern? inObj? + ; + +showTables: SHOW TABLES likePattern? inObj? + ; + +showTags + : SHOW TAGS likePattern? ( + IN ACCOUNT + | DATABASE + | DATABASE id + | SCHEMA + | SCHEMA schemaName + | schemaName + )? + ; + +showTasks + : SHOW TERSE? TASKS likePattern? (IN ( ACCOUNT | DATABASE id? | SCHEMA? schemaName?))? startsWith? limitRows? + ; + +showTransactions: SHOW TRANSACTIONS (IN ACCOUNT)? + ; + +showUserFunctions: SHOW USER FUNCTIONS likePattern? inObj? + ; + +showUsers: SHOW TERSE? USERS likePattern? (STARTS WITH string)? (LIMIT INT)? (FROM string)? + ; + +showVariables: SHOW VARIABLES likePattern? + ; + +showViews + : SHOW TERSE? VIEWS likePattern? 
(IN ( ACCOUNT | DATABASE id? | SCHEMA? schemaName?))? startsWith? limitRows?
+    ;
+
+showWarehouses: SHOW WAREHOUSES likePattern?
+    ;
+
+likePattern: LIKE string
+    ;
+
+// TODO: Fix this - it is just a dotIdentifier - if there are too many dots, it's not the parser's problem
+schemaName: d = id DOT s = id | s = id
+    ;
+
+objectType
+    : ACCOUNT PARAMETERS
+    | DATABASES
+    | INTEGRATIONS
+    | NETWORK POLICIES
+    | RESOURCE MONITORS
+    | ROLES
+    | SHARES
+    | USERS
+    | WAREHOUSES
+    ;
+
+objectTypeList: objectType (COMMA objectType)*
+    ;
+
+// Strings are not a single token match but a stream of parts which may consist
+// of variable references as well as plain text.
+string: STRING_START stringPart* STRING_END | DOLLAR_STRING
+    ;
+
+stringPart
+    : (
+        VAR_SIMPLE
+        | VAR_COMPLEX
+        | STRING_CONTENT
+        | STRING_UNICODE
+        | STRING_ESCAPE
+        | STRING_SQUOTE
+        | STRING_AMPAMP
+    )
+    ;
+
+stringList: string (COMMA string)*
+    ;
+
+idFn: id | IDENTIFIER LPAREN id RPAREN
+    ;
+
+id
+    : ID
+    | LOCAL_ID
+    | DOUBLE_QUOTE_ID
+    | AMP LCB? ID RCB? // Snowflake variables from CLI or injection - we rely on valid input
+    | keyword // almost any keyword can be used as an id :(
+    ;
+
+pattern: PATTERN EQ string
+    ;
+
+columnName: (id DOT)? id
+    ;
+
+columnList: columnName (COMMA columnName)*
+    ;
+
+columnListWithComment: columnName (COMMENT string)? (COMMA columnName (COMMENT string)?)*
+    ;
+
+dotIdentifier: id (DOT id)*
+    ;
+
+dotIdentifierOrIdent: dotIdentifier | IDENTIFIER LPAREN string RPAREN
+    ;
+
+/*** expressions ***/
+exprList: expr (COMMA expr)*
+    ;
+
+// Snowflake stupidly allows AND and OR in any expression that results in a purely logical
+// TRUE/FALSE result and is not involved in, say, a predicate. So we must also allow that,
+// even though it messes with precedence somewhat. However, as we only see queries that
+// parse/work in Snowflake, there is no practical effect on correct parsing. It is a PITA
+// that we have had to rename rules to make it make sense though.
+expr
+    : op = NOT+ expr # exprNot
+    | expr AND expr # exprAnd
+    | expr OR expr # exprOr
+    | expression # nonLogicalExpression
+    ;
+
+// Use this entry point into expression when allowing AND and OR would be ambiguous, such as in
+// searchConditions.
+expression
+    : LPAREN expression RPAREN # exprPrecedence
+    | dotIdentifier DOT NEXTVAL # exprNextval
+    | expression DOT expression # exprDot
+    | expression COLON expression # exprColon
+    | expression COLLATE string # exprCollate
+    | caseExpression # exprCase
+    | iffExpr # exprIff
+    | sign expression # exprSign
+    | expression op = (STAR | DIVIDE | MODULE) expression # exprPrecedence0
+    | expression op = (PLUS | MINUS | PIPE_PIPE) expression # exprPrecedence1
+    | expression comparisonOperator expression # exprComparison
+    | expression COLON_COLON dataType # exprAscribe
+    | expression withinGroup # exprWithinGroup
+    | expression overClause # exprOver
+    | castExpr # exprCast
+    | functionCall # exprFuncCall
+    | DISTINCT expression # exprDistinct
+    | LPAREN subquery RPAREN # exprSubquery
+    | primitiveExpression # exprPrimitive
+    ;
+
+withinGroup: WITHIN GROUP LPAREN orderByClause RPAREN
+    ;
+
+iffExpr: IFF LPAREN searchCondition COMMA expr COMMA expr RPAREN
+    ;
+
+castExpr: castOp = (TRY_CAST | CAST) LPAREN expr AS dataType RPAREN | INTERVAL expr
+    ;
+
+jsonLiteral: LCB kvPair (COMMA kvPair)* RCB | LCB RCB
+    ;
+
+kvPair: key = string COLON literal
+    ;
+
+arrayLiteral: LSB expr (COMMA expr)* RSB | LSB RSB
+    ;
+
+dataType
+    : OBJECT (LPAREN objectField (COMMA objectField)* RPAREN)?
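+    // The alternatives here cover plain and parameterized type forms alike, e.g. (illustrative):
+    //   OBJECT, OBJECT(city VARCHAR, zip NUMBER), ARRAY(NUMBER), VARCHAR(16777216), NUMBER(38, 0)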
+ | ARRAY (LPAREN dataType RPAREN)? + | id discard? (LPAREN INT (COMMA INT)? RPAREN)? + ; + +// Caters for things like DOUBLE PRECISION where the PRECISION isn't needed +// and hrow awaya keywords that are also verbose for no good reason +discard: VARYING | id + ; + +objectField: id dataType + ; + +primitiveExpression + : DEFAULT # primExprDefault //? + | id LSB INT RSB # primArrayAccess + | id LSB string RSB # primObjectAccess + | id # primExprColumn + | literal # primExprLiteral + | COLON id # primVariable // TODO: This needs to move to main expression as expression COLON expression when JSON is implemented + ; + +overClause: OVER LPAREN (PARTITION BY expr (COMMA expr)*)? windowOrderingAndFrame? RPAREN + ; + +windowOrderingAndFrame: orderByClause rowOrRangeClause? + ; + +rowOrRangeClause: (ROWS | RANGE) windowFrameExtent + ; + +windowFrameExtent: BETWEEN windowFrameBound AND windowFrameBound + ; + +windowFrameBound: UNBOUNDED (PRECEDING | FOLLOWING) | INT (PRECEDING | FOLLOWING) | CURRENT ROW + ; + +functionCall: builtinFunction | standardFunction | rankingWindowedFunction | aggregateFunction + ; + +builtinFunction: EXTRACT LPAREN (string | ID) FROM expr RPAREN # builtinExtract + ; + +standardFunction + : functionOptionalBrackets (LPAREN exprList? RPAREN)? + | functionName LPAREN (exprList | paramAssocList)? RPAREN + ; + +functionName: id | nonReservedFunctionName + ; + +nonReservedFunctionName + : LEFT + | RIGHT // keywords that cannot be used as id, but can be used as function names + ; + +functionOptionalBrackets + : CURRENT_DATE // https://docs.snowflake.com/en/sql-reference/functions/current_date + | CURRENT_TIMESTAMP // https://docs.snowflake.com/en/sql-reference/functions/current_timestamp + | CURRENT_TIME // https://docs.snowflake.com/en/sql-reference/functions/current_time + | LOCALTIME // https://docs.snowflake.com/en/sql-reference/functions/localtime + | LOCALTIMESTAMP // https://docs.snowflake.com/en/sql-reference/functions/localtimestamp + ; + +paramAssocList: paramAssoc (COMMA paramAssoc)* + ; + +paramAssoc: assocId ASSOC expr + ; + +assocId + : id + | OUTER // Outer is stupidly used as a parameter name in FLATTEN() - but we don't want it as an id + ; + +ignoreOrRepectNulls: (IGNORE | RESPECT) NULLS + ; + +rankingWindowedFunction: standardFunction ignoreOrRepectNulls? overClause + ; + +aggregateFunction + : op = (LISTAGG | ARRAY_AGG) LPAREN DISTINCT? expr (COMMA string)? RPAREN ( + WITHIN GROUP LPAREN orderByClause RPAREN + )? # aggFuncList + | id LPAREN DISTINCT? exprList RPAREN # aggFuncExprList + | id LPAREN STAR RPAREN # aggFuncStar + ; + +literal + : TIMESTAMP string + | string + | sign? INT + | sign? (REAL | FLOAT) + | trueFalse + | jsonLiteral + | arrayLiteral + | NULL + | PARAM // A question mark can be used as a placeholder for a prepared statement that will use binding. + | id string // DATE and any other additions + ; + +constant: string | INT | REAL | FLOAT + ; + +sign: PLUS | MINUS + ; + +caseExpression + : CASE expr switchSection+ (ELSE expr)? END + | CASE switchSearchConditionSection+ (ELSE expr)? END + ; + +switchSearchConditionSection: WHEN searchCondition THEN expr + ; + +switchSection: WHEN expr THEN expr + ; + +queryStatement: withExpression? queryExpression + ; + +withExpression: WITH RECURSIVE? commonTableExpression (COMMA commonTableExpression)* + ; + +commonTableExpression + : tableName = id (LPAREN columnList RPAREN)? 
AS LPAREN queryExpression RPAREN # CTETable + | id AS LPAREN expr RPAREN # CTEColumn + ; + +queryExpression + // INTERSECT has higher precedence than EXCEPT and UNION ALL. + // MINUS is an alias for EXCEPT + // Reference: https://docs.snowflake.com/en/sql-reference/operators-query.html#:~:text=precedence + : LPAREN queryExpression RPAREN # queryInParenthesis + | queryExpression INTERSECT queryExpression # queryIntersect + | queryExpression (UNION ALL? | EXCEPT | MINUS_) queryExpression # queryUnion + | selectStatement # querySimple + ; + +selectStatement + : selectClause selectOptionalClauses limitClause? + | selectTopClause selectOptionalClauses //TOP and LIMIT are not allowed together + | LPAREN selectStatement RPAREN + ; + +selectOptionalClauses + : intoClause? fromClause? whereClause? (groupByClause | havingClause)? qualifyClause? orderByClause? + ; + +selectClause: SELECT selectListNoTop + ; + +selectTopClause: SELECT selectListTop + ; + +selectListNoTop: allDistinct? selectList + ; + +selectListTop: allDistinct? topClause? selectList + ; + +selectList: selectListElem (COMMA selectListElem)* + ; + +selectListElem + : expressionElem asAlias? + | columnElem asAlias? + | columnElemStar + // | udtElem + ; + +columnElemStar: (dotIdentifier DOT)? STAR + ; + +columnElem: (dotIdentifier DOT)? columnName | (dotIdentifier DOT)? DOLLAR columnPosition + ; + +asAlias: AS? alias + ; + +expressionElem: searchCondition | expr + ; + +columnPosition: INT + ; + +allDistinct: ALL | DISTINCT + ; + +topClause: TOP expr + ; + +intoClause: INTO varList + ; + +varList: var (COMMA var)* + ; + +var: COLON id + ; + +fromClause + : FROM tableSources // objectRef joinClause* + ; + +tableSources: tableSource (COMMA tableSource)* + ; + +tableSource: tableSourceItemJoined sample? | LPAREN tableSource RPAREN + ; + +tableSourceItemJoined: objectRef joinClause* | LPAREN tableSourceItemJoined RPAREN joinClause* + ; + +objectRef + : dotIdentifier atBefore? changes? matchRecognize? pivotUnpivot? tableAlias? # objRefDefault + | TABLE LPAREN functionCall RPAREN pivotUnpivot? tableAlias? # objRefTableFunc + | valuesTable tableAlias? # objRefValues + | LATERAL? (functionCall | (LPAREN subquery RPAREN)) pivotUnpivot? tableAlias? # objRefSubquery + | dotIdentifier START WITH searchCondition CONNECT BY priorList? # objRefStartWith + ; + +tableAlias: AS? alias (LPAREN id (COMMA id)* RPAREN)? + ; + +priorList: priorItem (COMMA priorItem)* + ; + +priorItem: PRIOR? id EQ PRIOR? id + ; + +outerJoin: (LEFT | RIGHT | FULL) OUTER? + ; + +joinType: INNER | outerJoin + ; + +joinClause + : joinType? JOIN objectRef ((ON searchCondition) | (USING LPAREN columnList RPAREN))? + | NATURAL outerJoin? JOIN objectRef + | CROSS JOIN objectRef + ; + +atBefore + : AT_KEYWORD LPAREN ( + TIMESTAMP ASSOC expr + | OFFSET ASSOC expr + | STATEMENT ASSOC string + | STREAM ASSOC string + ) RPAREN + | BEFORE LPAREN STATEMENT ASSOC string RPAREN + ; + +end: END LPAREN ( TIMESTAMP ASSOC expr | OFFSET ASSOC expr | STATEMENT ASSOC string) RPAREN + ; + +changes: CHANGES LPAREN INFORMATION ASSOC defaultAppendOnly RPAREN atBefore end? + ; + +defaultAppendOnly: DEFAULT | APPEND_ONLY + ; + +partitionBy: PARTITION BY exprList + ; + +alias: id + ; + +exprAliasList: expr AS? alias (COMMA expr AS? alias)* + ; + +measures: MEASURES exprAliasList + ; + +matchOpts: SHOW EMPTY MATCHES | OMIT EMPTY MATCHES | WITH UNMATCHED ROWS + ; + +rowMatch: (ONE ROW PER MATCH | ALL ROWS PER MATCH) matchOpts? 
+ ; + +firstLast: FIRST | LAST + ; + +// TODO: This syntax is unfinished and needs to be completed - DUMMY is just a placeholder from the original author +symbol: DUMMY + ; + +afterMatch: AFTER MATCH KWSKIP (PAST LAST ROW | TO NEXT ROW | TO firstLast? symbol) + ; + +symbolList: symbol AS expr (COMMA symbol AS expr)* + ; + +define: DEFINE symbolList + ; + +matchRecognize + : MATCH_RECOGNIZE LPAREN partitionBy? orderByClause? measures? rowMatch? afterMatch? pattern? define? RPAREN + ; + +pivotUnpivot + : PIVOT LPAREN aggregateFunc = id LPAREN pivotColumn = id RPAREN FOR valueColumn = id IN LPAREN values += literal ( + COMMA values += literal + )* RPAREN RPAREN (asAlias columnAliasListInBrackets?)? + | UNPIVOT LPAREN valueColumn = id FOR nameColumn = id IN LPAREN columnList RPAREN RPAREN + ; + +columnAliasListInBrackets: LPAREN id (COMMA id)* RPAREN + ; + +exprListInParentheses: LPAREN exprList RPAREN + ; + +valuesTable: LPAREN valuesTableBody RPAREN | valuesTableBody + ; + +valuesTableBody: VALUES exprListInParentheses (COMMA exprListInParentheses)* + ; + +sampleMethod + : (SYSTEM | BLOCK) LPAREN INT RPAREN # sampleMethodBlock + | (BERNOULLI | ROW)? LPAREN INT ROWS RPAREN # sampleMethodRowFixed + | (BERNOULLI | ROW)? LPAREN INT RPAREN # sampleMethodRowProba + ; + +sample: (SAMPLE | TABLESAMPLE) sampleMethod sampleSeed? + ; + +sampleSeed: (REPEATABLE | SEED) LPAREN INT RPAREN + ; + +comparisonOperator: EQ | GT | LT | LE | GE | LTGT | NE + ; + +subquery: queryStatement + ; + +searchCondition + : LPAREN searchCondition RPAREN # scPrec + | NOT searchCondition # scNot + | searchCondition AND searchCondition # scAnd + | searchCondition OR searchCondition # scOr + | predicate # scPred + ; + +predicate + : EXISTS LPAREN subquery RPAREN # predExists + | expression comparisonOperator expression # predBinop + | expression comparisonOperator (ALL | SOME | ANY) LPAREN subquery RPAREN # predASA + | expression IS NOT? NULL # predIsNull + | expression NOT? IN LPAREN (subquery | exprList) RPAREN # predIn + | expression NOT? BETWEEN expression AND expression # predBetween + | expression NOT? op = (LIKE | ILIKE) expression (ESCAPE expression)? # predLikeSinglePattern + | expression NOT? op = (LIKE | ILIKE) (ANY | ALL) exprListInParentheses (ESCAPE expression)? # predLikeMultiplePatterns + | expression NOT? RLIKE expression # predRLike + | expression # predExpr + ; + +whereClause: WHERE searchCondition + ; + +groupByElem: columnElem | INT | expressionElem + ; + +groupByList: groupByElem (COMMA groupByElem)* + ; + +groupByClause + : GROUP BY groupByList havingClause? + | GROUP BY (GROUPING SETS | id) LPAREN groupByList RPAREN + | GROUP BY ALL + ; + +havingClause: HAVING searchCondition + ; + +qualifyClause: QUALIFY expr + ; + +orderItem: expr (ASC | DESC)? (NULLS ( FIRST | LAST))? + ; + +orderByClause: ORDER BY orderItem (COMMA orderItem)* + ; + +limitClause + : LIMIT expr (OFFSET expr)? + | (OFFSET expr)? (ROW | ROWS)? FETCH (FIRST | NEXT)? expr (ROW | ROWS)? ONLY? 
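+    // Both forms are accepted, e.g. (illustrative): LIMIT 10 OFFSET 20, or OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY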
+ ; diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/.gitignore b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/.gitignore new file mode 100644 index 0000000000..4f62b849d5 --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/.gitignore @@ -0,0 +1 @@ +gen diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/TSqlLexer.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/TSqlLexer.g4 new file mode 100644 index 0000000000..b631ba9147 --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/TSqlLexer.g4 @@ -0,0 +1,115 @@ +/* +T-SQL (Transact-SQL, MSSQL) grammar. +The MIT License (MIT). +Copyright (c) 2017, Mark Adams (madams51703@gmail.com) +Copyright (c) 2015-2017, Ivan Kochurkin (kvanttt@gmail.com), Positive Technologies. +Copyright (c) 2016, Scott Ure (scott@redstormsoftware.com). +Copyright (c) 2016, Rui Zhang (ruizhang.ccs@gmail.com). +Copyright (c) 2016, Marcus Henriksson (kuseman80@gmail.com). +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +// ================================================================================= +// Please reformat the grammr file before a change commit. See remorph/core/README.md +// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md + +// $antlr-format alignTrailingComments true +// $antlr-format columnLimit 150 +// $antlr-format maxEmptyLinesToKeep 1 +// $antlr-format reflowComments false +// $antlr-format useTab false +// $antlr-format allowShortRulesOnASingleLine true +// $antlr-format allowShortBlocksOnASingleLine true +// $antlr-format minEmptyLines 0 +// $antlr-format alignSemicolons ownLine +// $antlr-format alignColons trailing +// $antlr-format singleLineOverrulesHangingColon true +// $antlr-format alignLexerCommands true +// $antlr-format alignLabels true +// $antlr-format alignTrailers true +// ================================================================================= +lexer grammar TSqlLexer; + +import commonlex; + +options { + caseInsensitive = true; +} + +@members { + private static int TSQL_DIALECT = 1; + private static int SNOWFLAKE_DIALECT = 2; + private static int dialect = SNOWFLAKE_DIALECT; +} + +// Specials for graph nodes +NODEID: '$NODE_ID'; + +DOLLAR_ACTION: '$ACTION'; + +// Functions starting with double at signs +AAPSEUDO: '@@' ID; + +STRING options { + caseInsensitive = false; +}: 'N'? '\'' ('\\' . 
| '\'\'' | ~['])* '\''; + +HEX : '0X' HexDigit*; +INT : [0-9]+; +FLOAT : DEC_DOT_DEC; +REAL : (INT | DEC_DOT_DEC) ('E' [+-]? [0-9]+); +MONEY : '$' (INT | FLOAT); + +EQ : '='; +GT : '>'; +LT : '<'; +BANG : '!'; +PE : '+='; +ME : '-='; +SE : '*='; +DE : '/='; +MEA : '%='; +AND_ASSIGN : '&='; +XOR_ASSIGN : '^='; +OR_ASSIGN : '|='; + +DOUBLE_BAR : '||'; +DOT : '.'; +AT : '@'; +DOLLAR : '$'; +LPAREN : '('; +RPAREN : ')'; +COMMA : ','; +SEMI : ';'; +COLON : ':'; +DOUBLE_COLON : '::'; +STAR : '*'; +DIV : '/'; +MOD : '%'; +PLUS : '+'; +MINUS : '-'; +BIT_NOT : '~'; +BIT_OR : '|'; +BIT_AND : '&'; +BIT_XOR : '^'; + +PLACEHOLDER: '?'; + +// TSQL specific +SQUARE_BRACKET_ID : '[' (~']' | ']' ']')* ']'; +TEMP_ID : '#' ([A-Z_$@#0-9] | FullWidthLetter)*; +LOCAL_ID : '@' ([A-Z_$@#0-9] | FullWidthLetter)*; \ No newline at end of file diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/TSqlParser.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/TSqlParser.g4 new file mode 100644 index 0000000000..1667e46c4f --- /dev/null +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/tsql/TSqlParser.g4 @@ -0,0 +1,3353 @@ +/* +T-SQL (Transact-SQL, MSSQL) grammar. +The MIT License (MIT). +Copyright (c) 2017, Mark Adams (madams51703@gmail.com) +Copyright (c) 2015-2017, Ivan Kochurkin (kvanttt@gmail.com), Positive Technologies. +Copyright (c) 2016, Scott Ure (scott@redstormsoftware.com). +Copyright (c) 2016, Rui Zhang (ruizhang.ccs@gmail.com). +Copyright (c) 2016, Marcus Henriksson (kuseman80@gmail.com). +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +// ================================================================================= +// Please reformat the grammr file before a change commit. See remorph/core/README.md +// For formatting, see: https://github.com/mike-lischke/antlr-format/blob/main/doc/formatting.md + +// $antlr-format alignColons hanging +// $antlr-format columnLimit 150 +// $antlr-format alignSemicolons hanging +// $antlr-format alignTrailingComments true +// ================================================================================= + +parser grammar TSqlParser; + +import procedure, commonparse, jinja; + +options { + tokenVocab = TSqlLexer; +} + +// ============== Dialect compatibiltiy rules ============== +// The following rules provide substitutes for grammar rules referenced in the procedure.g4 grammar, that +// we do not have real equivalents for in this gramamr. 
+// Over time, as we homogonize more and more of the dialect grammars, these rules will be removed. +// Note that these rules will not be visited by the TSQL transpiler, as they are not part of +// TSQL and we are expecting syntacticly and semanticly sound input. + +// string and stringList will eventually expand to composite token sequences for macro substitution, so they +// are not redundant here. +string: STRING + ; +stringList: string (COMMA string)* + ; + +// expr is just an alias for expression, for Snowflake compatibility +// TODO: Change Snowflake to use the rule name expression instead of expr as this is what the Spark parser uses +expr: expression + ; +// ====================================================== + +// ================= TSQL Specific Rules ======================================== + +tSqlFile: SEMI* batch? EOF + ; + +batch: executeBodyBatch? SEMI* (sqlClauses SEMI*)+ + ; + +// TODO: Properly sort out SEMI colons, which have been haphazzardly added in some +// places and not others. + +sqlClauses + : dmlClause SEMI* + | cflStatement SEMI* + | anotherStatement SEMI* + | ddlClause SEMI* + | dbccClause SEMI* + | backupStatement SEMI* + | createOrAlterFunction SEMI* + | createOrAlterProcedure SEMI* + | createOrAlterTrigger SEMI* + | createView SEMI* + | goStatement SEMI* + | jinjaTemplate SEMI* + ; + +dmlClause: withExpression? ( selectStatement | merge | delete | insert | update | bulkStatement) + ; + +ddlClause + : alterApplicationRole + | alterAssembly + | alterAsymmetricKey + | alterAuthorization + | alterAvailabilityGroup + | alterCertificate + | alterColumnEncryptionKey + | alterCredential + | alterCryptographicProvider + | alterDatabase + | alterDatabaseAuditSpecification + | alterDbRole + | alterEndpoint + | alterExternalDataSource + | alterExternalLibrary + | alterExternalResourcePool + | alterFulltextCatalog + | alterFulltextStoplist + | alterIndex + | alterLoginAzureSql + | alterLoginAzureSqlDwAndPdw + | alterLoginSqlServer + | alterMasterKeyAzureSql + | alterMasterKeySqlServer + | alterMessageType + | alterPartitionFunction + | alterPartitionScheme + | alterRemoteServiceBinding + | alterResourceGovernor + | alterSchemaAzureSqlDwAndPdw + | alterSchemaSql + | alterSequence + | alterServerAudit + | alterServerAuditSpecification + | alterServerConfiguration + | alterServerRole + | alterServerRolePdw + | alterService + | alterServiceMasterKey + | alterSymmetricKey + | alterTable + | alterUser + | alterUserAzureSql + | alterWorkloadGroup + | alterXmlSchemaCollection + | createApplicationRole + | createAssembly + | createAsymmetricKey + | createColumnEncryptionKey + | createColumnMasterKey + | createColumnstoreIndex + | createCredential + | createCryptographicProvider + | createDatabaseScopedCredential + | createDatabase + | createDatabaseAuditSpecification + | createDbRole + | createEndpoint + | createEventNotification + | createExternalLibrary + | createExternalResourcePool + | createExternalDataSource + | createFulltextCatalog + | createFulltextStoplist + | createIndex + | createLoginAzureSql + | createLoginPdw + | createLoginSqlServer + | createMasterKeyAzureSql + | createMasterKeySqlServer + | createNonclusteredColumnstoreIndex + | createOrAlterBrokerPriority + | createOrAlterEventSession + | createPartitionFunction + | createPartitionScheme + | createRemoteServiceBinding + | createResourcePool + | createRoute + | createRule + | createSchema + | createSchemaAzureSqlDwAndPdw + | createSearchPropertyList + | createSecurityPolicy + | createSequence + | 
createServerAudit + | createServerAuditSpecification + | createServerRole + | createService + | createStatistics + | createSynonym + | createTable + | createType + | createUser + | createUserAzureSqlDw + | createWorkloadGroup + | createXmlIndex + | createXmlSchemaCollection + | triggerDisEn + | dropAggregate + | dropApplicationRole + | dropAssembly + | dropAsymmetricKey + | dropAvailabilityGroup + | dropBrokerPriority + | dropCertificate + | dropColumnEncryptionKey + | dropColumnMasterKey + | dropContract + | dropCredential + | dropCryptograhicProvider + | dropDatabase + | dropDatabaseAuditSpecification + | dropDatabaseEncryptionKey + | dropDatabaseScopedCredential + | dropDbRole + | dropDefault + | dropEndpoint + | dropEventNotifications + | dropEventSession + | dropExternalDataSource + | dropExternalFileFormat + | dropExternalLibrary + | dropExternalResourcePool + | dropExternalTable + | dropFulltextCatalog + | dropFulltextIndex + | dropFulltextStoplist + | dropFunction + | dropIndex + | dropLogin + | dropMasterKey + | dropMessageType + | dropPartitionFunction + | dropPartitionScheme + | dropProcedure + | dropQueue + | dropRemoteServiceBinding + | dropResourcePool + | dropRoute + | dropRule + | dropSchema + | dropSearchPropertyList + | dropSecurityPolicy + | dropSequence + | dropServerAudit + | dropServerAuditSpecification + | dropServerRole + | dropService + | dropSignature + | dropStatistics + | dropStatisticsNameAzureDwAndPdw + | dropSymmetricKey + | dropSynonym + | dropTable + | dropTrigger + | dropType + | dropUser + | dropView + | dropWorkloadGroup + | dropXmlSchemaCollection + | triggerDisEn + | lockTable + | truncateTable + | updateStatistics + ; + +backupStatement + : backupDatabase + | backupLog + | backupCertificate + | backupMasterKey + | backupServiceMasterKey + ; + +cflStatement + : blockStatement + | breakStatement + | continueStatement + | gotoStatement + | ifStatement + | printStatement + | raiseerrorStatement + | returnStatement + | throwStatement + | tryCatchStatement + | waitforStatement + | whileStatement + | receiveStatement + ; + +blockStatement: BEGIN SEMI? sqlClauses* END SEMI? + ; + +breakStatement: BREAK SEMI? + ; + +continueStatement: CONTINUE SEMI? + ; + +gotoStatement: GOTO id COLON? SEMI? + ; + +ifStatement: IF searchCondition sqlClauses (ELSE sqlClauses)? SEMI? + ; + +throwStatement: THROW ( intLocal COMMA stringLocal COMMA intLocal)? SEMI? + ; + +stringLocal: STRING | LOCAL_ID + ; + +intLocal: INT | LOCAL_ID + ; + +tryCatchStatement + : BEGIN TRY SEMI? sqlClauses+ END TRY SEMI? BEGIN CATCH SEMI? sqlClauses* END CATCH SEMI? + ; + +waitforStatement + : WAITFOR ( + DELAY STRING + | id STRING // TIME + | receiveStatement? COMMA? (id expression)? expression? + ) SEMI? + ; + +whileStatement: WHILE searchCondition ( sqlClauses | BREAK SEMI? | CONTINUE SEMI?) + ; + +printStatement: PRINT (expression | DOUBLE_QUOTE_ID) (COMMA LOCAL_ID)* SEMI? + ; + +raiseerrorStatement + : RAISERROR LPAREN (INT | STRING | LOCAL_ID) COMMA constant_LOCAL_ID COMMA constant_LOCAL_ID ( + COMMA (constant_LOCAL_ID | NULL) + )* RPAREN (WITH genericOption)? SEMI? 
+ | RAISERROR INT formatstring = (STRING | LOCAL_ID | DOUBLE_QUOTE_ID) ( + COMMA args += (INT | STRING | LOCAL_ID) + )* // Discontinued in SQL Server 2014 on + ; + +anotherStatement + : alterQueue + | checkpointStatement + | conversationStatement + | createContract + | createQueue + | cursorStatement + | declareStatement + | executeStatement + | killStatement + | messageStatement + | reconfigureStatement + | securityStatement + | setStatement + | setuserStatement + | shutdownStatement + | transactionStatement + | useStatement + ; + +alterApplicationRole + : ALTER APPLICATION ROLE id WITH (COMMA? NAME EQ id)? (COMMA? PASSWORD EQ STRING)? ( + COMMA? DEFAULT_SCHEMA EQ id + )? + ; + +alterXmlSchemaCollection: ALTER XML SCHEMA COLLECTION dotIdentifier ADD STRING + ; + +createApplicationRole: CREATE APPLICATION ROLE id WITH optionList + ; + +dropAggregate: DROP AGGREGATE (IF EXISTS)? dotIdentifier? + ; + +dropApplicationRole: DROP APPLICATION ROLE id + ; + +alterAssembly: ALTER ASSEMBLY id alterAssemblyClause + ; + +alterAssemblyClause + : (FROM (STRING | AS id))? (WITH optionList)? (DROP optionList)? ( + ADD FILE FROM STRING (AS id)? + )? + ; + +createAssembly + : CREATE ASSEMBLY id genericOption? FROM (COMMA? (STRING | HEX))+ (WITH genericOption)? + ; + +dropAssembly: DROP ASSEMBLY (IF EXISTS)? (COMMA? id)+ ( WITH genericOption)? + ; + +alterAsymmetricKey: ALTER ASYMMETRIC KEY id ( asymmetricKeyOption | REMOVE PRIVATE KEY) + ; + +asymmetricKeyOption + : WITH PRIVATE KEY LPAREN asymmetricKeyPasswordChangeOption ( + COMMA asymmetricKeyPasswordChangeOption + )? RPAREN + ; + +asymmetricKeyPasswordChangeOption: DECRYPTION BY genericOption | ENCRYPTION BY genericOption + ; + +createAsymmetricKey + : CREATE ASYMMETRIC KEY id genericOption? (FROM genericOption)? (WITH genericOption)? ( + ENCRYPTION BY genericOption + )? + ; + +dropAsymmetricKey: DROP ASYMMETRIC KEY id (REMOVE PROVIDER KEY)? + ; + +alterAuthorization + : ALTER AUTHORIZATION ON (id id? id? DOUBLE_COLON)? dotIdentifier TO genericOption + ; + +classTypeForGrant + : APPLICATION ROLE + | ASSEMBLY + | ASYMMETRIC KEY + | AUDIT + | AVAILABILITY GROUP + | BROKER PRIORITY + | CERTIFICATE + | COLUMN ( ENCRYPTION | MASTER) KEY + | CONTRACT + | CREDENTIAL + | CRYPTOGRAPHIC PROVIDER + | DATABASE ( + AUDIT SPECIFICATION + | ENCRYPTION KEY + | EVENT SESSION + | SCOPED (CONFIGURATION | CREDENTIAL | RESOURCE GOVERNOR) + )? + | ENDPOINT + | EVENT SESSION + | NOTIFICATION (DATABASE | OBJECT | SERVER) + | EXTERNAL (DATA SOURCE | FILE FORMAT | LIBRARY | RESOURCE POOL | TABLE | CATALOG | STOPLIST) + | LOGIN + | MASTER KEY + | MESSAGE TYPE + | OBJECT + | PARTITION ( FUNCTION | SCHEME) + | REMOTE SERVICE BINDING + | RESOURCE GOVERNOR + | ROLE + | ROUTE + | SCHEMA + | SEARCH PROPERTY LIST + | SERVER ( ( AUDIT SPECIFICATION?) | ROLE)? + | SERVICE + | id LOGIN + | SYMMETRIC KEY + | TRIGGER ( DATABASE | SERVER) + | TYPE + | USER + | XML SCHEMA COLLECTION + ; + +dropAvailabilityGroup: DROP AVAILABILITY GROUP id + ; + +alterAvailabilityGroup: alterAvailabilityGroupStart alterAvailabilityGroupOptions + ; + +alterAvailabilityGroupStart: ALTER AVAILABILITY GROUP id + ; + +// TODO: Consolodate all this junk and remove many lexer tokens! 
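+// Illustrative fragments matched below, following ALTER AVAILABILITY GROUP <name> (names are placeholders):
+//   SET ( HEALTH_CHECK_TIMEOUT = 30000 )
+//   ADD DATABASE sales_db
+//   REMOVE REPLICA ON 'SQLSERVER1'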
+alterAvailabilityGroupOptions + : SET LPAREN ( + ( + AUTOMATED_BACKUP_PREFERENCE EQ (PRIMARY | SECONDARY_ONLY | SECONDARY | NONE) + | FAILURE_CONDITION_LEVEL EQ INT + | HEALTH_CHECK_TIMEOUT EQ INT + | DB_FAILOVER EQ ( ON | OFF) + | REQUIRED_SYNCHRONIZED_SECONDARIES_TO_COMMIT EQ INT + ) RPAREN + ) + | ADD DATABASE id + | REMOVE DATABASE id + | ADD REPLICA ON STRING ( + WITH LPAREN ( + (ENDPOINT_URL EQ STRING)? ( + COMMA? AVAILABILITY_MODE EQ (SYNCHRONOUS_COMMIT | ASYNCHRONOUS_COMMIT) + )? (COMMA? FAILOVER_MODE EQ (AUTOMATIC | MANUAL))? ( + COMMA? SEEDING_MODE EQ (AUTOMATIC | MANUAL) + )? (COMMA? BACKUP_PRIORITY EQ INT)? ( + COMMA? PRIMARY_ROLE LPAREN ALLOW_CONNECTIONS EQ (READ_WRITE | ALL) RPAREN + )? (COMMA? SECONDARY_ROLE LPAREN ALLOW_CONNECTIONS EQ ( READ_ONLY) RPAREN)? + ) + ) RPAREN + | SECONDARY_ROLE LPAREN ( + ALLOW_CONNECTIONS EQ (NO | READ_ONLY | ALL) + | READ_ONLY_ROUTING_LIST EQ (LPAREN ( ( STRING)) RPAREN) + ) + | PRIMARY_ROLE LPAREN ( + ALLOW_CONNECTIONS EQ (NO | READ_ONLY | ALL) + | READ_ONLY_ROUTING_LIST EQ ( LPAREN ((COMMA? STRING)* | NONE) RPAREN) + | SESSION_TIMEOUT EQ INT + ) + | MODIFY REPLICA ON STRING ( + WITH LPAREN ( + ENDPOINT_URL EQ STRING + | AVAILABILITY_MODE EQ ( SYNCHRONOUS_COMMIT | ASYNCHRONOUS_COMMIT) + | FAILOVER_MODE EQ (AUTOMATIC | MANUAL) + | SEEDING_MODE EQ (AUTOMATIC | MANUAL) + | BACKUP_PRIORITY EQ INT + ) + | SECONDARY_ROLE LPAREN ( + ALLOW_CONNECTIONS EQ (NO | READ_ONLY | ALL) + | READ_ONLY_ROUTING_LIST EQ ( LPAREN ( ( STRING)) RPAREN) + ) + | PRIMARY_ROLE LPAREN ( + ALLOW_CONNECTIONS EQ (NO | READ_ONLY | ALL) + | READ_ONLY_ROUTING_LIST EQ ( LPAREN ((COMMA? STRING)* | NONE) RPAREN) + | SESSION_TIMEOUT EQ INT + ) + ) RPAREN + | REMOVE REPLICA ON STRING + | JOIN + | JOIN AVAILABILITY GROUP ON ( + COMMA? STRING WITH LPAREN ( + LISTENER_URL EQ STRING COMMA AVAILABILITY_MODE EQ ( + SYNCHRONOUS_COMMIT + | ASYNCHRONOUS_COMMIT + ) COMMA FAILOVER_MODE EQ MANUAL COMMA SEEDING_MODE EQ (AUTOMATIC | MANUAL) RPAREN + ) + )+ + | MODIFY AVAILABILITY GROUP ON ( + COMMA? STRING WITH LPAREN ( + LISTENER_URL EQ STRING ( + COMMA? AVAILABILITY_MODE EQ (SYNCHRONOUS_COMMIT | ASYNCHRONOUS_COMMIT) + )? (COMMA? FAILOVER_MODE EQ MANUAL)? (COMMA? SEEDING_MODE EQ (AUTOMATIC | MANUAL))? RPAREN + ) + )+ + | GRANT CREATE ANY DATABASE + | DENY CREATE ANY DATABASE + | FAILOVER + | FORCE_FAILOVER_ALLOW_DATA_LOSS + | ADD LISTENER STRING LPAREN ( + WITH DHCP (ON LPAREN STRING STRING RPAREN) + | WITH IP LPAREN ( + (COMMA? LPAREN ( STRING (COMMA STRING)?) RPAREN)+ RPAREN (COMMA PORT EQ INT)? + ) + ) RPAREN + | MODIFY LISTENER (ADD IP LPAREN ( STRING (COMMA STRING)?) RPAREN | PORT EQ INT) + | RESTART LISTENER STRING + | REMOVE LISTENER STRING + | OFFLINE + | WITH LPAREN DTC_SUPPORT EQ PER_DB RPAREN + ; + +createOrAlterBrokerPriority + : (CREATE | ALTER) BROKER PRIORITY id FOR CONVERSATION SET LPAREN ( + CONTRACT_NAME EQ ( id | ANY) COMMA? + )? (LOCAL_SERVICE_NAME EQ (DOUBLE_FORWARD_SLASH? id | ANY) COMMA?)? ( + REMOTE_SERVICE_NAME EQ (STRING | ANY) COMMA? + )? (PRIORITY_LEVEL EQ ( INT | DEFAULT))? RPAREN + ; + +dropBrokerPriority: DROP BROKER PRIORITY id + ; + +alterCertificate + : ALTER CERTIFICATE id ( + REMOVE PRIVATE_KEY + | WITH PRIVATE KEY LPAREN ( + FILE EQ STRING COMMA? + | DECRYPTION BY PASSWORD EQ STRING COMMA? + | ENCRYPTION BY PASSWORD EQ STRING COMMA? 
+ )+ RPAREN + | WITH ACTIVE FOR BEGIN_DIALOG EQ ( ON | OFF) + ) + ; + +alterColumnEncryptionKey + : ALTER COLUMN ENCRYPTION KEY id (ADD | DROP) VALUE LPAREN COLUMN_MASTER_KEY EQ id ( + COMMA ALGORITHM EQ STRING COMMA ENCRYPTED_VALUE EQ HEX + )? RPAREN + ; + +createColumnEncryptionKey + : CREATE COLUMN ENCRYPTION KEY id WITH VALUES ( + LPAREN COMMA? COLUMN_MASTER_KEY EQ id COMMA ALGORITHM EQ STRING COMMA ENCRYPTED_VALUE EQ HEX RPAREN COMMA? + )+ + ; + +dropCertificate: DROP CERTIFICATE id + ; + +dropColumnEncryptionKey: DROP COLUMN ENCRYPTION KEY id + ; + +dropColumnMasterKey: DROP COLUMN MASTER KEY id + ; + +dropContract: DROP CONTRACT id + ; + +dropCredential: DROP CREDENTIAL id + ; + +dropCryptograhicProvider: DROP CRYPTOGRAPHIC PROVIDER id + ; + +dropDatabase: DROP DATABASE (IF EXISTS)? ( COMMA? id)+ + ; + +dropDatabaseAuditSpecification: DROP DATABASE AUDIT SPECIFICATION id + ; + +dropDatabaseEncryptionKey: DROP DATABASE ENCRYPTION KEY + ; + +dropDatabaseScopedCredential: DROP DATABASE SCOPED CREDENTIAL id + ; + +dropDefault: DROP DEFAULT (IF EXISTS)? dotIdentifier (COMMA dotIdentifier)* + ; + +dropEndpoint: DROP ENDPOINT id + ; + +dropExternalDataSource: DROP EXTERNAL DATA SOURCE id + ; + +dropExternalFileFormat: DROP EXTERNAL FILE FORMAT id + ; + +dropExternalLibrary: DROP EXTERNAL LIBRARY id ( AUTHORIZATION id)? + ; + +dropExternalResourcePool: DROP EXTERNAL RESOURCE POOL id + ; + +dropExternalTable: DROP EXTERNAL TABLE dotIdentifier + ; + +dropEventNotifications: DROP EVENT NOTIFICATION id (COMMA id)* ON ( SERVER | DATABASE | QUEUE id) + ; + +dropEventSession: DROP EVENT SESSION id ON SERVER + ; + +dropFulltextCatalog: DROP FULLTEXT CATALOG id + ; + +dropFulltextIndex: DROP FULLTEXT INDEX ON dotIdentifier + ; + +dropFulltextStoplist: DROP FULLTEXT STOPLIST id + ; + +dropLogin: DROP LOGIN id + ; + +dropMasterKey: DROP MASTER KEY + ; + +dropMessageType: DROP MESSAGE TYPE id + ; + +dropPartitionFunction: DROP PARTITION FUNCTION id + ; + +dropPartitionScheme: DROP PARTITION SCHEME id + ; + +dropQueue: DROP QUEUE dotIdentifier + ; + +dropRemoteServiceBinding: DROP REMOTE SERVICE BINDING id + ; + +dropResourcePool: DROP RESOURCE POOL id + ; + +dropDbRole: DROP ROLE (IF EXISTS)? id + ; + +dropRoute: DROP ROUTE id + ; + +dropRule: DROP RULE (IF EXISTS)? dotIdentifier (COMMA dotIdentifier)* + ; + +dropSchema: DROP SCHEMA (IF EXISTS)? id + ; + +dropSearchPropertyList: DROP SEARCH PROPERTY LIST id + ; + +dropSecurityPolicy: DROP SECURITY POLICY (IF EXISTS)? dotIdentifier + ; + +dropSequence: DROP SEQUENCE (IF EXISTS)? dotIdentifier (COMMA dotIdentifier)* + ; + +dropServerAudit: DROP SERVER AUDIT id + ; + +dropServerAuditSpecification: DROP SERVER AUDIT SPECIFICATION id + ; + +dropServerRole: DROP SERVER ROLE id + ; + +dropService: DROP SERVICE id + ; + +dropSignature + : DROP (COUNTER)? SIGNATURE FROM dotIdentifier BY ( + COMMA? CERTIFICATE cert += id + | COMMA? ASYMMETRIC KEY id + )+ + ; + +dropStatisticsNameAzureDwAndPdw: DROP STATISTICS dotIdentifier + ; + +dropSymmetricKey: DROP SYMMETRIC KEY id ( REMOVE PROVIDER KEY)? + ; + +dropSynonym: DROP SYNONYM (IF EXISTS)? dotIdentifier + ; + +dropUser: DROP USER (IF EXISTS)? id + ; + +dropWorkloadGroup: DROP WORKLOAD GROUP id + ; + +triggerDisEn + : (DISABLE | ENABLE) TRIGGER ( + triggers += dotIdentifier (COMMA triggers += dotIdentifier)* + | ALL + ) ON (dotIdentifier | DATABASE | ALL SERVER) + ; + +lockTable + : LOCK TABLE tableName IN (SHARE | EXCLUSIVE) MODE (id /* WAIT INT | NOWAIT */ INT)? SEMI? 
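+    // e.g. (illustrative; the table name is a placeholder): LOCK TABLE inventory IN EXCLUSIVE MODE WAIT 10;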
+ ; + +truncateTable + : TRUNCATE TABLE tableName ( + WITH LPAREN PARTITIONS LPAREN (COMMA? (INT | INT TO INT))+ RPAREN RPAREN + )? + ; + +createColumnMasterKey + : CREATE COLUMN MASTER KEY id WITH LPAREN KEY_STORE_PROVIDER_NAME EQ STRING COMMA KEY_PATH EQ STRING RPAREN + ; + +alterCredential: ALTER CREDENTIAL id WITH IDENTITY EQ STRING ( COMMA SECRET EQ STRING)? + ; + +createCredential + : CREATE CREDENTIAL id WITH IDENTITY EQ STRING (COMMA SECRET EQ STRING)? ( + FOR CRYPTOGRAPHIC PROVIDER id + )? + ; + +alterCryptographicProvider + : ALTER CRYPTOGRAPHIC PROVIDER id (FROM FILE EQ STRING)? (ENABLE | DISABLE)? + ; + +createCryptographicProvider: CREATE CRYPTOGRAPHIC PROVIDER id FROM FILE EQ STRING + ; + +createEndpoint + : CREATE ENDPOINT id (AUTHORIZATION id)? (STATE EQ (STARTED | STOPPED | DISABLED))? AS TCP LPAREN endpointListenerClause RPAREN ( + FOR TSQL LPAREN RPAREN + | FOR SERVICE_BROKER LPAREN endpointAuthenticationClause ( + COMMA? endpointEncryptionAlogorithmClause + )? (COMMA? MESSAGE_FORWARDING EQ (ENABLED | DISABLED))? ( + COMMA? MESSAGE_FORWARD_SIZE EQ INT + )? RPAREN + | FOR DATABASE_MIRRORING LPAREN endpointAuthenticationClause ( + COMMA? endpointEncryptionAlogorithmClause + )? COMMA? ROLE EQ (WITNESS | PARTNER | ALL) RPAREN + ) + ; + +endpointEncryptionAlogorithmClause + : ENCRYPTION EQ (DISABLED | SUPPORTED | REQUIRED) (ALGORITHM (AES RC4? | RC4 AES?))? + ; + +endpointAuthenticationClause + : AUTHENTICATION EQ ( + WINDOWS (NTLM | KERBEROS | NEGOTIATE)? ( CERTIFICATE id)? + | CERTIFICATE id WINDOWS? ( NTLM | KERBEROS | NEGOTIATE)? + ) + ; + +endpointListenerClause + : LISTENER_PORT EQ INT ( + COMMA LISTENER_IP EQ (ALL | LPAREN ( INT DOT INT DOT INT DOT INT | STRING) RPAREN) + )? + ; + +createEventNotification + : CREATE EVENT NOTIFICATION id ON (SERVER | DATABASE | QUEUE id) (WITH FAN_IN)? FOR ( + COMMA? etg += id + )+ TO SERVICE STRING COMMA STRING + ; + +addDropEvent + : ADD EVENT ( + dotIdentifier ( + LPAREN (SET ( COMMA? id EQ ( INT | STRING))*)? ( + ACTION LPAREN dotIdentifier (COMMA dotIdentifier)* RPAREN + )+ (WHERE eventSessionPredicateExpression)? RPAREN + )* + ) + | DROP EVENT dotIdentifier + ; + +addDropEventTarget + : ADD TARGET dotIdentifier (LPAREN SET (COMMA? id EQ ( LPAREN? INT RPAREN? | STRING))+ RPAREN) + | DROP TARGET dotIdentifier + ; + +addDropEventOrTarget: addDropEvent | addDropEventTarget + ; + +createOrAlterEventSession + : (CREATE | ALTER) EVENT SESSION id ON (SERVER | DATABASE) addDropEventOrTarget ( + COMMA addDropEventOrTarget + )* ( + WITH LPAREN (COMMA? MAX_MEMORY EQ INT (KB | MB))? ( + COMMA? EVENT_RETENTION_MODE EQ ( + ALLOW_SINGLE_EVENT_LOSS + | ALLOW_MULTIPLE_EVENT_LOSS + | NO_EVENT_LOSS + ) + )? (COMMA? MAX_DISPATCH_LATENCY EQ ( INT SECONDS | INFINITE))? ( + COMMA? MAX_EVENT_SIZE EQ INT (KB | MB) + )? (COMMA? MEMORY_PARTITION_MODE EQ ( NONE | PER_NODE | PER_CPU))? ( + COMMA? TRACK_CAUSALITY EQ (ON | OFF) + )? (COMMA? STARTUP_STATE EQ (ON | OFF))? RPAREN + )? (STATE EQ (START | STOP))? + ; + +eventSessionPredicateExpression + : ( + COMMA? (AND | OR)? NOT? ( + eventSessionPredicateFactor + | LPAREN eventSessionPredicateExpression RPAREN + ) + )+ + ; + +eventSessionPredicateFactor + : eventSessionPredicateLeaf + | LPAREN eventSessionPredicateExpression RPAREN + ; + +eventSessionPredicateLeaf + : dotIdentifier ((EQ | (LT GT) | (BANG EQ) | GT | (GT EQ) | LT | LT EQ) (INT | STRING))? + | dotIdentifier LPAREN ( dotIdentifier COMMA (INT | STRING)) RPAREN + ; + +createExternalDataSource + : CREATE EXTERNAL DATA SOURCE id WITH LPAREN (COMMA? 
( genericOption | connectionOptions))* RPAREN SEMI? + ; + +connectionOptions: id EQ STRING (COMMA STRING)* + ; + +alterExternalDataSource + : ALTER EXTERNAL DATA SOURCE id ( + SET ( + LOCATION EQ STRING COMMA? + | RESOURCE_MANAGER_LOCATION EQ STRING COMMA? + | CREDENTIAL EQ id + )+ + | WITH LPAREN TYPE EQ BLOB_STORAGE COMMA LOCATION EQ STRING (COMMA CREDENTIAL EQ id)? RPAREN + ) + ; + +alterExternalLibrary + : ALTER EXTERNAL LIBRARY id (AUTHORIZATION id)? (SET | ADD) ( + LPAREN CONTENT EQ (STRING | HEX | NONE) (COMMA PLATFORM EQ (WINDOWS | LINUX)? RPAREN) WITH ( + COMMA? LANGUAGE EQ id + | DATA_SOURCE EQ id + )+ RPAREN + ) + ; + +createExternalLibrary + : CREATE EXTERNAL LIBRARY id (AUTHORIZATION id)? FROM ( + COMMA? LPAREN? (CONTENT EQ)? (STRING | HEX | NONE) ( + COMMA PLATFORM EQ (WINDOWS | LINUX)? RPAREN + )? + ) (WITH ( COMMA? LANGUAGE EQ id | DATA_SOURCE EQ id)+ RPAREN)? + ; + +alterExternalResourcePool + : ALTER EXTERNAL RESOURCE POOL (id | DEFAULT_DOUBLE_QUOTE) WITH LPAREN MAX_CPU_PERCENT EQ INT ( + COMMA? AFFINITY CPU EQ ( AUTO | (COMMA? INT TO INT | COMMA INT)+) + | NUMANODE EQ (COMMA? INT TO INT | COMMA? INT)+ + ) (COMMA? MAX_MEMORY_PERCENT EQ INT)? (COMMA? MAX_PROCESSES EQ INT)? RPAREN + ; + +createExternalResourcePool + : CREATE EXTERNAL RESOURCE POOL id WITH LPAREN MAX_CPU_PERCENT EQ INT ( + COMMA? AFFINITY CPU EQ ( AUTO | (COMMA? INT TO INT | COMMA INT)+) + | NUMANODE EQ (COMMA? INT TO INT | COMMA? INT)+ + ) (COMMA? MAX_MEMORY_PERCENT EQ INT)? (COMMA? MAX_PROCESSES EQ INT)? RPAREN + ; + +alterFulltextCatalog + : ALTER FULLTEXT CATALOG id ( + REBUILD (WITH ACCENT_SENSITIVITY EQ (ON | OFF))? + | REORGANIZE + | AS DEFAULT + ) + ; + +createFulltextCatalog + : CREATE FULLTEXT CATALOG id (ON FILEGROUP id)? (IN PATH STRING)? ( + WITH ACCENT_SENSITIVITY EQ (ON | OFF) + )? (AS DEFAULT)? (AUTHORIZATION id)? + ; + +alterFulltextStoplist + : ALTER FULLTEXT STOPLIST id ( + ADD STRING LANGUAGE (STRING | INT | HEX) + | DROP (STRING LANGUAGE (STRING | INT | HEX) | ALL (STRING | INT | HEX) | ALL) + ) + ; + +createFulltextStoplist + : CREATE FULLTEXT STOPLIST id (FROM (dotIdentifier | SYSTEM STOPLIST))? (AUTHORIZATION id)? + ; + +alterLoginSqlServer + : ALTER LOGIN id ( + (ENABLE | DISABLE)? + | WITH ((PASSWORD EQ ( STRING | HEX HASHED)) (MUST_CHANGE | UNLOCK)*)? ( + OLD_PASSWORD EQ STRING ( MUST_CHANGE | UNLOCK)* + )? (DEFAULT_DATABASE EQ id)? (DEFAULT_LANGUAGE EQ id)? (NAME EQ id)? ( + CHECK_POLICY EQ (ON | OFF) + )? (CHECK_EXPIRATION EQ (ON | OFF))? (CREDENTIAL EQ id)? (NO CREDENTIAL)? + | (ADD | DROP) CREDENTIAL id + ) + ; + +createLoginSqlServer + : CREATE LOGIN id ( + WITH ((PASSWORD EQ ( STRING | HEX HASHED)) (MUST_CHANGE | UNLOCK)*)? (COMMA? SID EQ HEX)? ( + COMMA? DEFAULT_DATABASE EQ id + )? (COMMA? DEFAULT_LANGUAGE EQ id)? (COMMA? CHECK_EXPIRATION EQ (ON | OFF))? ( + COMMA? CHECK_POLICY EQ (ON | OFF) + )? (COMMA? CREDENTIAL EQ id)? + | ( + FROM ( + WINDOWS ( + WITH (COMMA? DEFAULT_DATABASE EQ id)? (COMMA? DEFAULT_LANGUAGE EQ STRING)? + ) + | CERTIFICATE id + | ASYMMETRIC KEY id + ) + ) + ) + ; + +alterLoginAzureSql + : ALTER LOGIN id ( + (ENABLE | DISABLE)? + | WITH ( PASSWORD EQ STRING (OLD_PASSWORD EQ STRING)? | NAME EQ id) + ) + ; + +createLoginAzureSql: CREATE LOGIN id WITH PASSWORD EQ STRING ( SID EQ HEX)? + ; + +alterLoginAzureSqlDwAndPdw + : ALTER LOGIN id ( + (ENABLE | DISABLE)? + | WITH (PASSWORD EQ STRING ( OLD_PASSWORD EQ STRING (MUST_CHANGE | UNLOCK)*)? | NAME EQ id) + ) + ; + +createLoginPdw + : CREATE LOGIN id ( + WITH (PASSWORD EQ STRING (MUST_CHANGE)? 
( CHECK_POLICY EQ (ON | OFF)?)?) + | FROM WINDOWS + ) + ; + +alterMasterKeySqlServer + : ALTER MASTER KEY ( + (FORCE)? REGENERATE WITH ENCRYPTION BY PASSWORD EQ STRING + | (ADD | DROP) ENCRYPTION BY (SERVICE MASTER KEY | PASSWORD EQ STRING) + ) + ; + +createMasterKeySqlServer: CREATE MASTER KEY ENCRYPTION BY PASSWORD EQ STRING + ; + +alterMasterKeyAzureSql + : ALTER MASTER KEY ( + (FORCE)? REGENERATE WITH ENCRYPTION BY PASSWORD EQ STRING + | ADD ENCRYPTION BY (SERVICE MASTER KEY | PASSWORD EQ STRING) + | DROP ENCRYPTION BY PASSWORD EQ STRING + ) + ; + +createMasterKeyAzureSql: CREATE MASTER KEY ( ENCRYPTION BY PASSWORD EQ STRING)? + ; + +alterMessageType + : ALTER MESSAGE TYPE id VALIDATION EQ ( + NONE + | EMPTY + | WELL_FORMED_XML + | VALID_XML WITH SCHEMA COLLECTION id + ) + ; + +alterPartitionFunction + : ALTER PARTITION FUNCTION id LPAREN RPAREN (SPLIT | MERGE) RANGE LPAREN INT RPAREN + ; + +alterPartitionScheme: ALTER PARTITION SCHEME id NEXT USED (id)? + ; + +alterRemoteServiceBinding + : ALTER REMOTE SERVICE BINDING id WITH (USER EQ id)? (COMMA ANONYMOUS EQ (ON | OFF))? + ; + +createRemoteServiceBinding + : CREATE REMOTE SERVICE BINDING id (AUTHORIZATION id)? TO SERVICE STRING WITH (USER EQ id)? ( + COMMA ANONYMOUS EQ (ON | OFF) + )? + ; + +createResourcePool + : CREATE RESOURCE POOL id ( + WITH LPAREN (COMMA? MIN_CPU_PERCENT EQ INT)? (COMMA? MAX_CPU_PERCENT EQ INT)? ( + COMMA? CAP_CPU_PERCENT EQ INT + )? ( + COMMA? AFFINITY SCHEDULER EQ ( + AUTO + | LPAREN (COMMA? (INT | INT TO INT))+ RPAREN + | NUMANODE EQ LPAREN (COMMA? (INT | INT TO INT))+ RPAREN + ) + )? (COMMA? MIN_MEMORY_PERCENT EQ INT)? (COMMA? MAX_MEMORY_PERCENT EQ INT)? ( + COMMA? MIN_IOPS_PER_VOLUME EQ INT + )? (COMMA? MAX_IOPS_PER_VOLUME EQ INT)? RPAREN + )? + ; + +alterResourceGovernor + : ALTER RESOURCE GOVERNOR ( + (DISABLE | RECONFIGURE) + | WITH LPAREN CLASSIFIER_FUNCTION EQ (dotIdentifier | NULL) RPAREN + | RESET STATISTICS + | WITH LPAREN MAX_OUTSTANDING_IO_PER_VOLUME EQ INT RPAREN + ) + ; + +alterDatabaseAuditSpecification + : ALTER DATABASE AUDIT SPECIFICATION id (FOR SERVER AUDIT id)? ( + auditActionSpecGroup (COMMA auditActionSpecGroup)* + )? (WITH LPAREN STATE EQ (ON | OFF) RPAREN)? + ; + +auditActionSpecGroup: (ADD | DROP) LPAREN (auditActionSpecification | id) RPAREN + ; + +auditActionSpecification + : actionSpecification (COMMA actionSpecification)* ON (auditClassName COLON COLON)? dotIdentifier BY principalId ( + COMMA principalId + )* + ; + +actionSpecification: SELECT | INSERT | UPDATE | DELETE | EXECUTE | RECEIVE | REFERENCES + ; + +auditClassName: OBJECT | SCHEMA | TABLE + ; + +alterDbRole: ALTER ROLE id ( (ADD | DROP) MEMBER id | WITH NAME EQ id) + ; + +createDatabaseAuditSpecification + : CREATE DATABASE AUDIT SPECIFICATION id (FOR SERVER AUDIT id)? ( + auditActionSpecGroup (COMMA auditActionSpecGroup)* + )? (WITH LPAREN STATE EQ (ON | OFF) RPAREN)? + ; + +createDbRole: CREATE ROLE id (AUTHORIZATION id)? + ; + +createRoute + : CREATE ROUTE id (AUTHORIZATION id)? WITH (COMMA? SERVICE_NAME EQ STRING)? ( + COMMA? BROKER_INSTANCE EQ STRING + )? (COMMA? LIFETIME EQ INT)? COMMA? ADDRESS EQ STRING (COMMA MIRROR_ADDRESS EQ STRING)? + ; + +createRule: CREATE RULE (id DOT)? id AS searchCondition + ; + +alterSchemaSql + : ALTER SCHEMA id TRANSFER ((OBJECT | TYPE | XML SCHEMA COLLECTION) DOUBLE_COLON)? id (DOT id)? + ; + +createSchema + : CREATE SCHEMA (id | AUTHORIZATION id | id AUTHORIZATION id) ( + createTable + | createView + | (GRANT | DENY) (SELECT | INSERT | DELETE | UPDATE) ON (SCHEMA DOUBLE_COLON)? 
id TO id + | REVOKE (SELECT | INSERT | DELETE | UPDATE) ON (SCHEMA DOUBLE_COLON)? id FROM id + )* + ; + +createSchemaAzureSqlDwAndPdw: CREATE SCHEMA id ( AUTHORIZATION id)? + ; + +alterSchemaAzureSqlDwAndPdw: ALTER SCHEMA id TRANSFER (OBJECT DOUBLE_COLON)? dotIdentifier? + ; + +createSearchPropertyList + : CREATE SEARCH PROPERTY LIST id (FROM dotIdentifier)? (AUTHORIZATION id)? + ; + +createSecurityPolicy + : CREATE SECURITY POLICY dotIdentifier ( + COMMA? ADD (FILTER | BLOCK)? PREDICATE dotIdentifier LPAREN (COMMA? id)+ RPAREN ON dotIdentifier ( + COMMA? AFTER (INSERT | UPDATE) + | COMMA? BEFORE (UPDATE | DELETE) + )* + )+ (WITH LPAREN STATE EQ (ON | OFF) ( SCHEMABINDING (ON | OFF))? RPAREN)? (NOT FOR REPLICATION)? + ; + +alterSequence + : ALTER SEQUENCE dotIdentifier (RESTART (WITH INT)?)? (INCREMENT BY INT)? ( + MINVALUE INT + | NO MINVALUE + )? (MAXVALUE INT | NO MAXVALUE)? (CYCLE | NO CYCLE)? (CACHE INT | NO CACHE)? + ; + +createSequence + : CREATE SEQUENCE dotIdentifier (AS dataType)? (START WITH INT)? (INCREMENT BY MINUS? INT)? ( + MINVALUE (MINUS? INT)? + | NO MINVALUE + )? (MAXVALUE (MINUS? INT)? | NO MAXVALUE)? (CYCLE | NO CYCLE)? (CACHE INT? | NO CACHE)? + ; + +alterServerAudit + : ALTER SERVER AUDIT id ( + ( + TO ( + FILE ( + LPAREN ( + COMMA? FILEPATH EQ STRING + | COMMA? MAXSIZE EQ ( INT (MB | GB | TB) | UNLIMITED) + | COMMA? MAX_ROLLOVER_FILES EQ ( INT | UNLIMITED) + | COMMA? MAX_FILES EQ INT + | COMMA? RESERVE_DISK_SPACE EQ (ON | OFF) + )* RPAREN + ) + | APPLICATION_LOG + | SECURITY_LOG + ) + )? ( + WITH LPAREN ( + COMMA? QUEUE_DELAY EQ INT + | COMMA? ON_FAILURE EQ ( CONTINUE | SHUTDOWN | FAIL_OPERATION) + | COMMA? STATE EQ (ON | OFF) + )* RPAREN + )? ( + WHERE ( + COMMA? (NOT?) id (EQ | (LT GT) | (BANG EQ) | GT | (GT EQ) | LT | LT EQ) ( + INT + | STRING + ) + | COMMA? (AND | OR) NOT? (EQ | (LT GT) | (BANG EQ) | GT | (GT EQ) | LT | LT EQ) ( + INT + | STRING + ) + ) + )? + | REMOVE WHERE + | MODIFY NAME EQ id + ) + ; + +createServerAudit + : CREATE SERVER AUDIT id ( + ( + TO ( + FILE ( + LPAREN ( + COMMA? FILEPATH EQ STRING + | COMMA? MAXSIZE EQ ( INT (MB | GB | TB) | UNLIMITED) + | COMMA? MAX_ROLLOVER_FILES EQ ( INT | UNLIMITED) + | COMMA? MAX_FILES EQ INT + | COMMA? RESERVE_DISK_SPACE EQ (ON | OFF) + )* RPAREN + ) + | APPLICATION_LOG + | SECURITY_LOG + ) + )? ( + WITH LPAREN ( + COMMA? QUEUE_DELAY EQ INT + | COMMA? ON_FAILURE EQ ( CONTINUE | SHUTDOWN | FAIL_OPERATION) + | COMMA? STATE EQ (ON | OFF) + | COMMA? AUDIT_GUID EQ id + )* RPAREN + )? ( + WHERE ( + COMMA? (NOT?) id (EQ | (LT GT) | (BANG EQ) | GT | (GT EQ) | LT | LT EQ) ( + INT + | STRING + ) + | COMMA? (AND | OR) NOT? (EQ | (LT GT) | (BANG EQ) | GT | (GT EQ) | LT | LT EQ) ( + INT + | STRING + ) + ) + )? + | REMOVE WHERE + | MODIFY NAME EQ id + ) + ; + +alterServerAuditSpecification + : ALTER SERVER AUDIT SPECIFICATION id (FOR SERVER AUDIT id)? ((ADD | DROP) LPAREN id RPAREN)* ( + WITH LPAREN STATE EQ (ON | OFF) RPAREN + )? + ; + +createServerAuditSpecification + : CREATE SERVER AUDIT SPECIFICATION id (FOR SERVER AUDIT id)? (ADD LPAREN id RPAREN)* ( + WITH LPAREN STATE EQ (ON | OFF) RPAREN + )? + ; + +alterServerConfiguration + : ALTER SERVER CONFIGURATION SET ( + ( + PROCESS AFFINITY ( + CPU EQ (AUTO | (COMMA? INT | COMMA? INT TO INT)+) + | NUMANODE EQ ( COMMA? INT | COMMA? 
INT TO INT)+ + ) + | DIAGNOSTICS id /* LOG */ ( + onOff + | PATH EQ (STRING | DEFAULT) + | MAX_SIZE EQ (INT MB | DEFAULT) + | MAX_FILES EQ (INT | DEFAULT) + ) + | FAILOVER CLUSTER PROPERTY ( + VERBOSELOGGING EQ (STRING | DEFAULT) + | SQLDUMPERFLAGS EQ (STRING | DEFAULT) + | SQLDUMPERPATH EQ (STRING | DEFAULT) + | SQLDUMPERTIMEOUT (STRING | DEFAULT) + | FAILURECONDITIONLEVEL EQ (STRING | DEFAULT) + | HEALTHCHECKTIMEOUT EQ (INT | DEFAULT) + ) + | HADR CLUSTER CONTEXT EQ (STRING | LOCAL) + | BUFFER POOL EXTENSION ( + ON LPAREN FILENAME EQ STRING COMMA SIZE EQ INT (KB | MB | GB) RPAREN + | OFF + ) + | SET SOFTNUMA (ON | OFF) + ) + ) + ; + +alterServerRole: ALTER SERVER ROLE id ( (ADD | DROP) MEMBER id | WITH NAME EQ id) + ; + +createServerRole: CREATE SERVER ROLE id ( AUTHORIZATION id)? + ; + +alterServerRolePdw: ALTER SERVER ROLE id (ADD | DROP) MEMBER id + ; + +alterService + : ALTER SERVICE id (ON QUEUE dotIdentifier)? (LPAREN optArgClause (COMMA optArgClause)* RPAREN)? + ; + +optArgClause: (ADD | DROP) CONTRACT id + ; + +createService + : CREATE SERVICE id (AUTHORIZATION id)? ON QUEUE dotIdentifier ( + LPAREN (COMMA? (id | DEFAULT))+ RPAREN + )? + ; + +alterServiceMasterKey + : ALTER SERVICE MASTER KEY ( + FORCE? REGENERATE + | ( + WITH ( + OLD_ACCOUNT EQ STRING COMMA OLD_PASSWORD EQ STRING + | NEW_ACCOUNT EQ STRING COMMA NEW_PASSWORD EQ STRING + )? + ) + ) + ; + +alterSymmetricKey + : ALTER SYMMETRIC KEY id ( + (ADD | DROP) ENCRYPTION BY ( + CERTIFICATE id + | PASSWORD EQ STRING + | SYMMETRIC KEY id + | ASYMMETRIC KEY id + ) + ) + ; + +createSynonym: CREATE SYNONYM dotIdentifier FOR dotIdentifier + ; + +alterUser + : ALTER USER id WITH ( + COMMA? NAME EQ id + | COMMA? DEFAULT_SCHEMA EQ ( id | NULL) + | COMMA? LOGIN EQ id + | COMMA? PASSWORD EQ STRING (OLD_PASSWORD EQ STRING)+ + | COMMA? DEFAULT_LANGUAGE EQ ( NONE | INT | id) + | COMMA? ALLOW_ENCRYPTED_VALUE_MODIFICATIONS EQ ( ON | OFF) + )+ + ; + +createUser + : CREATE USER id ((FOR | FROM) LOGIN id)? ( + WITH ( + COMMA? DEFAULT_SCHEMA EQ id + | COMMA? ALLOW_ENCRYPTED_VALUE_MODIFICATIONS EQ ( ON | OFF) + )* + )? + | CREATE USER ( + id ( + WITH ( + COMMA? DEFAULT_SCHEMA EQ id + | COMMA? DEFAULT_LANGUAGE EQ ( NONE | INT | id) + | COMMA? SID EQ HEX + | COMMA? ALLOW_ENCRYPTED_VALUE_MODIFICATIONS EQ ( ON | OFF) + )* + )? + | id WITH PASSWORD EQ STRING ( + COMMA? DEFAULT_SCHEMA EQ id + | COMMA? DEFAULT_LANGUAGE EQ ( NONE | INT | id) + | COMMA? SID EQ HEX + | COMMA? ALLOW_ENCRYPTED_VALUE_MODIFICATIONS EQ ( ON | OFF) + )* + | id FROM EXTERNAL PROVIDER + ) + | CREATE USER id ( + WITHOUT LOGIN ( + COMMA? DEFAULT_SCHEMA EQ id + | COMMA? ALLOW_ENCRYPTED_VALUE_MODIFICATIONS EQ ( ON | OFF) + )* + | (FOR | FROM) CERTIFICATE id + | (FOR | FROM) ASYMMETRIC KEY id + ) + | CREATE USER id + ; + +createUserAzureSqlDw + : CREATE USER id ((FOR | FROM) LOGIN id | WITHOUT LOGIN)? (WITH DEFAULT_SCHEMA EQ id)? + | CREATE USER id FROM EXTERNAL PROVIDER ( WITH DEFAULT_SCHEMA EQ id)? + ; + +alterUserAzureSql + : ALTER USER id WITH ( + COMMA? NAME EQ id + | COMMA? DEFAULT_SCHEMA EQ id + | COMMA? LOGIN EQ id + | COMMA? ALLOW_ENCRYPTED_VALUE_MODIFICATIONS EQ ( ON | OFF) + )+ + ; + +alterWorkloadGroup + : ALTER WORKLOAD GROUP (id | DEFAULT_DOUBLE_QUOTE) ( + WITH LPAREN ( + IMPORTANCE EQ (LOW | MEDIUM | HIGH) + | COMMA? REQUEST_MAX_MEMORY_GRANT_PERCENT EQ INT + | COMMA? REQUEST_MAX_CPU_TIME_SEC EQ INT + | REQUEST_MEMORY_GRANT_TIMEOUT_SEC EQ INT + | MAX_DOP EQ INT + | GROUP_MAX_REQUESTS EQ INT + )+ RPAREN + )? (USING (id | DEFAULT_DOUBLE_QUOTE))? 
+ ; + +createWorkloadGroup + : CREATE WORKLOAD GROUP id ( + WITH LPAREN ( + IMPORTANCE EQ (LOW | MEDIUM | HIGH) + | COMMA? REQUEST_MAX_MEMORY_GRANT_PERCENT EQ INT + | COMMA? REQUEST_MAX_CPU_TIME_SEC EQ INT + | REQUEST_MEMORY_GRANT_TIMEOUT_SEC EQ INT + | MAX_DOP EQ INT + | GROUP_MAX_REQUESTS EQ INT + )+ RPAREN + )? (USING (id | DEFAULT_DOUBLE_QUOTE)? ( COMMA? EXTERNAL id | DEFAULT_DOUBLE_QUOTE)?)? + ; + +createPartitionFunction + : CREATE PARTITION FUNCTION id LPAREN dataType RPAREN AS RANGE (LEFT | RIGHT)? FOR VALUES LPAREN expressionList RPAREN + ; + +createPartitionScheme + : CREATE PARTITION SCHEME id AS PARTITION id ALL? TO LPAREN fileGroupNames += id ( + COMMA fileGroupNames += id + )* RPAREN + ; + +createQueue: CREATE QUEUE (tableName | id) queueSettings? (ON id | DEFAULT)? + ; + +queueSettings + : WITH (STATUS EQ onOff COMMA?)? (RETENTION EQ onOff COMMA?)? ( + ACTIVATION LPAREN ( + ( + (STATUS EQ onOff COMMA?)? (PROCEDURE_NAME EQ dotIdentifier COMMA?)? ( + MAX_QUEUE_READERS EQ INT COMMA? + )? (EXECUTE AS (SELF | STRING | OWNER) COMMA?)? + ) + | DROP + ) RPAREN COMMA? + )? (POISON_MESSAGE_HANDLING LPAREN (STATUS EQ onOff) RPAREN)? + ; + +alterQueue: ALTER QUEUE (tableName | id) ( queueSettings | queueAction) + ; + +queueAction + : REBUILD (WITH LPAREN queueRebuildOptions RPAREN)? + | REORGANIZE (WITH LOB_COMPACTION EQ onOff)? + | MOVE TO (id | DEFAULT) + ; + +queueRebuildOptions: genericOption + ; + +createContract + : CREATE CONTRACT contractName (AUTHORIZATION id)? LPAREN ( + (id | DEFAULT) SENT BY (INITIATOR | TARGET | ANY) COMMA? + )+ RPAREN + ; + +conversationStatement + : beginConversationTimer + | beginConversationDialog + | endConversation + | getConversation + | sendConversation + | waitforConversation + ; + +messageStatement + : CREATE MESSAGE TYPE id (AUTHORIZATION id)? ( + VALIDATION EQ (NONE | EMPTY | WELL_FORMED_XML | VALID_XML WITH SCHEMA COLLECTION id) + ) + ; + +merge + : MERGE topClause? INTO? ddlObject withTableHints? asTableAlias? USING tableSources ON searchCondition whenMatch* outputClause? optionClause? SEMI + ? + ; + +whenMatch: WHEN NOT? MATCHED (BY (TARGET | SOURCE))? (AND searchCondition)? THEN mergeAction + ; + +mergeAction + : UPDATE SET updateElem (COMMA updateElem)* + | DELETE + | INSERT (LPAREN cols = expressionList RPAREN)? ( + VALUES LPAREN vals = expressionList RPAREN + | DEFAULT VALUES + ) + ; + +delete + : DELETE topClause? FROM? ddlObject withTableHints? outputClause? (FROM tableSources)? updateWhereClause? optionClause? SEMI? + ; + +bulkStatement + : BULK INSERT dotIdentifier FROM STRING ( + WITH LPAREN bulkInsertOption (COMMA? bulkInsertOption)* RPAREN + )? + ; + +bulkInsertOption: ORDER LPAREN bulkInsertCol (COMMA bulkInsertCol)* RPAREN | genericOption + ; + +bulkInsertCol: id (ASC | DESC)? + ; + +insert + : INSERT topClause? INTO? ddlObject withTableHints? (LPAREN expressionList RPAREN)? outputClause? insertStatementValue optionClause? SEMI? + ; + +insertStatementValue: derivedTable | executeStatement | DEFAULT VALUES + ; + +receiveStatement + : LPAREN? RECEIVE (ALL | DISTINCT | topClause | STAR | ((id | LOCAL_ID EQ expression) COMMA?)*) FROM tableName ( + INTO id (WHERE searchCondition) + )? RPAREN? + ; + +selectStatementStandalone: withExpression? selectStatement + ; + +selectStatement: queryExpression forClause? optionClause? SEMI? + ; + +update + : UPDATE topClause? ddlObject withTableHints? SET updateElem (COMMA updateElem)* outputClause? ( + FROM tableSources + )? updateWhereClause? optionClause? SEMI? 
+ ; + +updateWhereClause: WHERE (searchCondition | CURRENT OF ( GLOBAL? cursorName | LOCAL_ID)) + ; + +outputClause + : OUTPUT outputDmlListElem (COMMA outputDmlListElem)* ( + INTO ddlObject ( LPAREN columnNameList RPAREN)? + )? + ; + +outputDmlListElem: (expression | asterisk) (AS? columnAlias)? + ; + +createDatabase + : CREATE DATABASE id (CONTAINMENT EQ ( NONE | PARTIAL))? ( + ON (PRIMARY? databaseFileSpec ( COMMA databaseFileSpec)*)? ( + id /* LOG */ ON databaseFileSpec (COMMA databaseFileSpec)* + )? + )? (COLLATE id)? (WITH createDatabaseOption (COMMA createDatabaseOption)*)? SEMI? + ; + +createDatabaseScopedCredential + : CREATE DATABASE SCOPED CREDENTIAL id WITH IDENTITY EQ STRING (COMMA SECRET EQ STRING)? SEMI? + ; + +createDatabaseOption + : FILESTREAM (databaseFilestreamOption (COMMA databaseFilestreamOption)*) + | genericOption + ; + +createIndex + : CREATE UNIQUE? clustered? INDEX id ON tableName LPAREN columnNameListWithOrder RPAREN ( + INCLUDE LPAREN columnNameList RPAREN + )? (WHERE searchCondition)? (createIndexOptions)? (ON id)? SEMI? + ; + +createIndexOptions: WITH LPAREN relationalIndexOption ( COMMA relationalIndexOption)* RPAREN + ; + +relationalIndexOption: rebuildIndexOption | genericOption + ; + +alterIndex + : ALTER INDEX (id | ALL) ON tableName ( + DISABLE + | PAUSE + | ABORT + | RESUME resumableIndexOptions? + | reorganizePartition + | setIndexOptions + | rebuildPartition + ) + ; + +resumableIndexOptions: WITH LPAREN ( resumableIndexOption (COMMA resumableIndexOption)*) RPAREN + ; + +resumableIndexOption: genericOption | lowPriorityLockWait + ; + +reorganizePartition: REORGANIZE (PARTITION EQ INT)? reorganizeOptions? + ; + +reorganizeOptions: WITH LPAREN (reorganizeOption (COMMA reorganizeOption)*) RPAREN + ; + +reorganizeOption: LOB_COMPACTION EQ onOff | COMPRESS_ALL_ROW_GROUPS EQ onOff + ; + +setIndexOptions: SET LPAREN setIndexOption (COMMA setIndexOption)* RPAREN + ; + +setIndexOption: genericOption + ; + +rebuildPartition + : REBUILD (PARTITION EQ ALL)? rebuildIndexOptions? + | REBUILD PARTITION EQ INT singlePartitionRebuildIndexOptions? + ; + +rebuildIndexOptions: WITH LPAREN rebuildIndexOption (COMMA rebuildIndexOption)* RPAREN + ; + +rebuildIndexOption + : DATA_COMPRESSION EQ (NONE | ROW | PAGE | COLUMNSTORE | COLUMNSTORE_ARCHIVE) onPartitions? + | XML_COMPRESSION EQ onOff onPartitions? + | genericOption + ; + +singlePartitionRebuildIndexOptions + : WITH LPAREN singlePartitionRebuildIndexOption (COMMA singlePartitionRebuildIndexOption)* RPAREN + ; + +singlePartitionRebuildIndexOption + : genericOption + | DATA_COMPRESSION EQ (NONE | ROW | PAGE | COLUMNSTORE | COLUMNSTORE_ARCHIVE) onPartitions? + | XML_COMPRESSION EQ onOff onPartitions? + | ONLINE EQ (ON (LPAREN lowPriorityLockWait RPAREN)? | OFF) + ; + +onPartitions: ON PARTITIONS LPAREN INT (TO INT)? ( COMMA INT (TO INT)?)* RPAREN + ; + +createColumnstoreIndex + : CREATE CLUSTERED COLUMNSTORE INDEX id ON tableName createColumnstoreIndexOptions? (ON id)? SEMI? + ; + +createColumnstoreIndexOptions + : WITH LPAREN columnstoreIndexOption (COMMA columnstoreIndexOption)* RPAREN + ; + +columnstoreIndexOption + : genericOption + | DATA_COMPRESSION EQ (COLUMNSTORE | COLUMNSTORE_ARCHIVE) onPartitions? + ; + +createNonclusteredColumnstoreIndex + : CREATE NONCLUSTERED? COLUMNSTORE INDEX id ON tableName LPAREN columnNameListWithOrder RPAREN ( + WHERE searchCondition + )? createColumnstoreIndexOptions? (ON id)? SEMI? 
+ ; + +createOrAlterTrigger: createOrAlterDmlTrigger | createOrAlterDdlTrigger + ; + +createOrAlterDmlTrigger + : (CREATE (OR (ALTER | REPLACE))? | ALTER) TRIGGER dotIdentifier ON tableName ( + WITH dmlTriggerOption (COMMA dmlTriggerOption)* + )? (FOR | AFTER | INSTEAD OF) dmlTriggerOperation (COMMA dmlTriggerOperation)* (WITH APPEND)? ( + NOT FOR REPLICATION + )? AS sqlClauses+ + ; + +dmlTriggerOption: ENCRYPTION | executeAs + ; + +dmlTriggerOperation: (INSERT | UPDATE | DELETE) + ; + +createOrAlterDdlTrigger + : (CREATE (OR (ALTER | REPLACE))? | ALTER) TRIGGER dotIdentifier ON (ALL SERVER | DATABASE) ( + WITH dmlTriggerOption (COMMA dmlTriggerOption)* + )? (FOR | AFTER) ddlTriggerOperation (COMMA ddlTriggerOperation)* AS sqlClauses+ + ; + +ddlTriggerOperation: simpleId + ; + +createOrAlterFunction + : ((CREATE (OR ALTER)?) | ALTER) FUNCTION dotIdentifier ( + (LPAREN procedureParam (COMMA procedureParam)* RPAREN) + | LPAREN RPAREN + ) //must have (), but can be empty + (funcBodyReturnsSelect | funcBodyReturnsTable | funcBodyReturnsScalar) SEMI? + ; + +funcBodyReturnsSelect + : RETURNS TABLE (WITH functionOption (COMMA functionOption)*)? AS? ( + (EXTERNAL NAME dotIdentifier) + | RETURN (LPAREN selectStatementStandalone RPAREN | selectStatementStandalone) + ) + ; + +funcBodyReturnsTable + : RETURNS LOCAL_ID tableTypeDefinition (WITH functionOption (COMMA functionOption)*)? AS? ( + (EXTERNAL NAME dotIdentifier) + | BEGIN sqlClauses* RETURN SEMI? END SEMI? + ) + ; + +funcBodyReturnsScalar + : RETURNS dataType (WITH functionOption (COMMA functionOption)*)? AS? ( + (EXTERNAL NAME dotIdentifier) + | BEGIN sqlClauses* RETURN expression SEMI? END + ) + ; + +functionOption + : ENCRYPTION + | SCHEMABINDING + | RETURNS NULL ON NULL INPUT + | CALLED ON NULL INPUT + | executeAs + ; + +createStatistics + : CREATE STATISTICS id ON tableName LPAREN columnNameList RPAREN ( + WITH (FULLSCAN | SAMPLE INT (PERCENT | ROWS) | STATS_STREAM) (COMMA NORECOMPUTE)? ( + COMMA INCREMENTAL EQ onOff + )? + )? SEMI? + ; + +updateStatistics + : UPDATE STATISTICS tableName (id | LPAREN id ( COMMA id)* RPAREN)? updateStatisticsOptions? + ; + +updateStatisticsOptions: WITH updateStatisticsOption (COMMA updateStatisticsOption)* + ; + +updateStatisticsOption: RESAMPLE onPartitions? | optionList + ; + +createTable: CREATE (createExternal | createInternal) + ; + +createInternal + : TABLE tableName (LPAREN columnDefTableConstraints COMMA? RPAREN)? tableOptions? + // This sequence looks strange but alloes CTAS and normal CREATE TABLE to be parsed + createTableAs? tableOptions? (ON id | DEFAULT | onPartitionOrFilegroup)? ( + TEXTIMAGE_ON id + | DEFAULT + )? SEMI? + ; + +createExternal + : EXTERNAL TABLE tableName (LPAREN columnDefTableConstraints RPAREN)? WITH LPAREN optionList RPAREN ( + AS selectStatementStandalone + ) SEMI? + ; + +table: TABLE tableName (LPAREN columnDefTableConstraints? COMMA? RPAREN)? + ; + +createTableAs + : AS selectStatementStandalone + | AS FILETABLE WITH lparenOptionList + | AS (NODE | EDGE) + | AS /* CLONE */ id OF dotIdentifier (AT_KEYWORD STRING)? + ; + +tableIndices + : INDEX id UNIQUE? clustered? LPAREN columnNameListWithOrder RPAREN + | INDEX id CLUSTERED COLUMNSTORE + | INDEX id NONCLUSTERED? COLUMNSTORE LPAREN columnNameList RPAREN createTableIndexOptions? ( + ON id + )? 
+ ; + +tableOptions + : WITH (LPAREN tableOption (COMMA tableOption)* RPAREN | tableOption (COMMA tableOption)*) + ; + +distributionType: HASH LPAREN id (COMMA id)* RPAREN | ROUND_ROBIN | REPLICATE + ; + +tableOption + : DISTRIBUTION EQ distributionType + | CLUSTERED INDEX LPAREN id (ASC | DESC)? ( COMMA id (ASC | DESC)?)* RPAREN + | DATA_COMPRESSION EQ (NONE | ROW | PAGE) onPartitions? + | XML_COMPRESSION EQ onOff onPartitions? + | id EQ ( + OFF (LPAREN id RPAREN)? + | ON (LPAREN tableOptionElement (COMMA tableOptionElement)* RPAREN)? + ) + | genericOption + ; + +tableOptionElement: id EQ dotIdentifier LPAREN optionList RPAREN | genericOption + ; + +createTableIndexOptions + : WITH LPAREN createTableIndexOption (COMMA createTableIndexOption)* RPAREN + ; + +createTableIndexOption + : DATA_COMPRESSION EQ (NONE | ROW | PAGE | COLUMNSTORE | COLUMNSTORE_ARCHIVE) onPartitions? + | XML_COMPRESSION EQ onOff onPartitions? + | genericOption + ; + +createView + : (CREATE (OR ALTER)? | ALTER) VIEW dotIdentifier (LPAREN columnNameList RPAREN)? ( + WITH optionList + )? AS selectStatementStandalone (WITH genericOption)? SEMI? + ; + +alterTable + : ALTER TABLE tableName ( + alterTableColumn + | WITH genericOption // CHECK | NOCHECK + | alterTableAdd + | alterTableDrop + | REBUILD tableOptions + | SWITCH switchPartition + | SET LPAREN id EQ ON LPAREN optionList RPAREN RPAREN + | (SET LPAREN optionList RPAREN | genericOption)+ + ) SEMI? + ; + +alterTableDrop + : DROP ( + dropSet (COMMA dropSet)* + | WITH? (CHECK | NOCHECK) CONSTRAINT (ALL | id (COMMA id)*) + | (ENABLE | DISABLE) TRIGGER (ALL | id (COMMA id)*) + | (ENABLE | DISABLE) CHANGE_TRACKING (WITH LPAREN genericOption RPAREN)? + ) SEMI? + ; + +dropSet + : CONSTRAINT? (IF EXISTS)? dropId (COMMA dropId)* + | COLUMN (IF EXISTS)? id (COMMA id)* + | /* PERIOD */ id FOR /* SYSTEM_TIME */ id + ; + +dropId: id (WITH dropClusteredConstraintOption (COMMA dropClusteredConstraintOption)*?) + ; + +dropClusteredConstraintOption + : genericOption + | MOVE TO id (LPAREN dotIdentifier RPAREN | id | STRING) + ; + +alterTableAdd + : ADD ( + ( computedColumnDefinition | tableConstraint | columnSetDefinition)+ + | (alterGenerated (COMMA alterGenerated)*)? + ) + ; + +alterGenerated + : dotIdentifier id GENERATED ALWAYS AS (ROW | TRANSACTION_ID | SEQUENCE_NUMBER) (START | END) HIDDEN_KEYWORD? ( + NOT? NULL + )? (CONSTRAINT id)? DEFAULT expression (WITH VALUES)? + | /* PERIOD */ id FOR /* SYSTEM_TIME */ id LPAREN (dotIdentifier COMMA dotIdentifier) RPAREN + ; + +alterTableColumn + : ALTER COLUMN dotIdentifier ( + (LPAREN INT (COMMA INT)? RPAREN | xmlSchemaCollection) (COLLATE id)? (NULL | NOT NULL)? SPARSE? ( + WITH LPAREN genericOption RPAREN + )? + | (ADD | DROP) genericOption (WITH LPAREN genericOption RPAREN)? + ) + ; + +switchPartition + : (PARTITION? expression)? TO tableName (PARTITION expression)? (WITH lowPriorityLockWait)? + ; + +lowPriorityLockWait + : WAIT_AT_LOW_PRIORITY LPAREN MAX_DURATION EQ expression MINUTES? COMMA ABORT_AFTER_WAIT EQ abortAfterWait = ( + NONE + | SELF + | BLOCKERS + ) RPAREN + ; + +alterDatabase + : ALTER DATABASE (id | CURRENT) ( + MODIFY NAME EQ id + | COLLATE id + | SET databaseOptionspec (WITH termination)? + | addOrModifyFiles + | addOrModifyFilegroups + ) SEMI? + ; + +addOrModifyFiles + : ADD id? /* LOG */ FILE fileSpec (COMMA fileSpec)* (TO FILEGROUP id)? + | REMOVE FILE (id | fileSpec) + ; + +fileSpec + : LPAREN NAME EQ idOrString (COMMA NEWNAME EQ idOrString)? (COMMA FILENAME EQ STRING)? ( + COMMA SIZE EQ fileSize + )? 
(COMMA MAXSIZE EQ (fileSize) | UNLIMITED)? (COMMA FILEGROWTH EQ fileSize)? (COMMA OFFLINE)? RPAREN + ; + +addOrModifyFilegroups + : ADD FILEGROUP id (CONTAINS FILESTREAM | CONTAINS MEMORY_OPTIMIZED_DATA)? + | REMOVE FILEGROUP id + | MODIFY FILEGROUP id ( + filegroupUpdatabilityOption + | DEFAULT + | NAME EQ id + | AUTOGROW_SINGLE_FILE + | AUTOGROW_ALL_FILES + ) + ; + +filegroupUpdatabilityOption: READONLY | READWRITE | READ_ONLY | READ_WRITE + ; + +databaseOptionspec + : changeTrackingOption + | autoOption + | containmentOption + | cursorOption + | databaseMirroringOption + | dateCorrelationOptimizationOption + | dbEncryptionOption + | dbStateOption + | dbUpdateOption + | dbUserAccessOption + | delayedDurabilityOption + | externalAccessOption + | FILESTREAM databaseFilestreamOption + | hadrOptions + | mixedPageAllocationOption + | recoveryOption + | serviceBrokerOption + | snapshotOption + | sqlOption + | targetRecoveryTimeOption + | termination + | queryStoreOption + | genericOption + ; + +queryStoreOption + : QUERY_STORE ( + EQ ( + OFF (LPAREN /* FORCED */ id RPAREN)? + | ON (LPAREN queryStoreElementOpt (COMMA queryStoreElementOpt)* RPAREN)? + ) + | LPAREN queryStoreElementOpt RPAREN + | ALL + | id + ) + ; + +queryStoreElementOpt: id EQ LPAREN optionList RPAREN | genericOption + ; + +autoOption + : AUTO_CLOSE onOff + | AUTO_CREATE_STATISTICS OFF + | ON ( INCREMENTAL EQ ON | OFF) + | AUTO_SHRINK onOff + | AUTO_UPDATE_STATISTICS onOff + | AUTO_UPDATE_STATISTICS_ASYNC (ON | OFF) + ; + +changeTrackingOption + : CHANGE_TRACKING ( + EQ ( OFF | ON) + | LPAREN (changeTrackingOpt ( COMMA changeTrackingOpt)*) RPAREN + ) + ; + +changeTrackingOpt: AUTO_CLEANUP EQ onOff | CHANGE_RETENTION EQ INT ( DAYS | HOURS | MINUTES) + ; + +containmentOption: CONTAINMENT EQ (NONE | PARTIAL) + ; + +cursorOption: CURSOR_CLOSE_ON_COMMIT onOff | CURSOR_DEFAULT ( LOCAL | GLOBAL) + ; + +alterEndpoint + : ALTER ENDPOINT id (AUTHORIZATION id)? (STATE EQ state = (STARTED | STOPPED | DISABLED))? AS TCP LPAREN endpointListenerClause RPAREN ( + FOR TSQL LPAREN RPAREN + | FOR SERVICE_BROKER LPAREN endpointAuthenticationClause ( + COMMA? endpointEncryptionAlogorithmClause + )? (COMMA? MESSAGE_FORWARDING EQ (ENABLED | DISABLED))? ( + COMMA? MESSAGE_FORWARD_SIZE EQ INT + )? RPAREN + | FOR DATABASE_MIRRORING LPAREN endpointAuthenticationClause ( + COMMA? endpointEncryptionAlogorithmClause + )? COMMA? 
ROLE EQ (WITNESS | PARTNER | ALL) RPAREN + ) + ; + +databaseMirroringOption: mirroringSetOption + ; + +mirroringSetOption: mirroringPartner partnerOption | mirroringWitness witnessOption + ; + +mirroringPartner: PARTNER + ; + +mirroringWitness: WITNESS + ; + +witnessPartnerEqual: EQ + ; + +partnerOption + : witnessPartnerEqual partnerServer + | FAILOVER + | FORCE_SERVICE_ALLOW_DATA_LOSS + | OFF + | RESUME + | SAFETY (FULL | OFF) + | SUSPEND + | TIMEOUT INT + ; + +witnessOption: witnessPartnerEqual witnessServer | OFF + ; + +witnessServer: partnerServer + ; + +partnerServer: partnerServerTcpPrefix host mirroringHostPortSeperator portNumber + ; + +mirroringHostPortSeperator: COLON + ; + +partnerServerTcpPrefix: TCP COLON DOUBLE_FORWARD_SLASH + ; + +portNumber: INT + ; + +host: id DOT host | (id DOT | id) + ; + +dateCorrelationOptimizationOption: DATE_CORRELATION_OPTIMIZATION onOff + ; + +dbEncryptionOption: ENCRYPTION onOff + ; + +dbStateOption: (ONLINE | OFFLINE | EMERGENCY) + ; + +dbUpdateOption: READ_ONLY | READ_WRITE + ; + +dbUserAccessOption: SINGLE_USER | RESTRICTED_USER | MULTI_USER + ; + +delayedDurabilityOption: genericOption + ; + +externalAccessOption + : DB_CHAINING onOff + | TRUSTWORTHY onOff + | DEFAULT_LANGUAGE EQ ( id | STRING) + | DEFAULT_FULLTEXT_LANGUAGE EQ ( id | STRING) + | NESTED_TRIGGERS EQ ( OFF | ON) + | TRANSFORM_NOISE_WORDS EQ ( OFF | ON) + | TWO_DIGIT_YEAR_CUTOFF EQ INT + ; + +hadrOptions: HADR ( (AVAILABILITY GROUP EQ id | OFF) | (SUSPEND | RESUME)) + ; + +mixedPageAllocationOption: MIXED_PAGE_ALLOCATION (OFF | ON) + ; + +recoveryOption: genericOption + ; + +serviceBrokerOption + : ENABLE_BROKER + | DISABLE_BROKER + | NEW_BROKER + | ERROR_BROKER_CONVERSATIONS + | HONOR_BROKER_PRIORITY onOff + ; + +snapshotOption + : ALLOW_SNAPSHOT_ISOLATION onOff + | READ_COMMITTED_SNAPSHOT (ON | OFF) + | MEMORY_OPTIMIZED_ELEVATE_TO_SNAPSHOT = (ON | OFF) + ; + +sqlOption + : ANSI_NULL_DEFAULT onOff + | ANSI_NULLS onOff + | ANSI_PADDING onOff + | ANSI_WARNINGS onOff + | ARITHABORT onOff + | COMPATIBILITY_LEVEL EQ INT + | CONCAT_NULL_YIELDS_NULL onOff + | NUMERIC_ROUNDABORT onOff + | QUOTED_IDENTIFIER onOff + | RECURSIVE_TRIGGERS onOff + ; + +targetRecoveryTimeOption: TARGET_RECOVERY_TIME EQ INT (SECONDS | MINUTES) + ; + +termination: ROLLBACK AFTER INT | ROLLBACK IMMEDIATE | NO_WAIT + ; + +dropIndex + : DROP INDEX (IF EXISTS)? ( + dropRelationalOrXmlOrSpatialIndex (COMMA dropRelationalOrXmlOrSpatialIndex)* + | dropBackwardCompatibleIndex ( COMMA dropBackwardCompatibleIndex)* + ) SEMI? + ; + +dropRelationalOrXmlOrSpatialIndex: id ON tableName + ; + +dropBackwardCompatibleIndex: dotIdentifier + ; + +dropTrigger: dropDmlTrigger | dropDdlTrigger + ; + +dropDmlTrigger: DROP TRIGGER (IF EXISTS)? dotIdentifier (COMMA dotIdentifier)* SEMI? + ; + +dropDdlTrigger + : DROP TRIGGER (IF EXISTS)? dotIdentifier (COMMA dotIdentifier)* ON (DATABASE | ALL SERVER) SEMI? + ; + +dropFunction: DROP FUNCTION (IF EXISTS)? dotIdentifier ( COMMA dotIdentifier)* SEMI? + ; + +dropStatistics: DROP STATISTICS (COMMA? dotIdentifier)+ SEMI + ; + +dropTable: DROP TABLE (IF EXISTS)? tableName (COMMA tableName)* SEMI? + ; + +dropView: DROP VIEW (IF EXISTS)? dotIdentifier (COMMA dotIdentifier)* SEMI? + ; + +createType + : CREATE TYPE dotIdentifier (FROM dataType nullNotnull?)? ( + AS TABLE LPAREN columnDefTableConstraints RPAREN + )? + ; + +dropType: DROP TYPE (IF EXISTS)? 
dotIdentifier + ; + +rowsetFunctionLimited: openquery | opendatasource + ; + +openquery: OPENQUERY LPAREN id COMMA STRING RPAREN + ; + +opendatasource: OPENDATASOURCE LPAREN STRING COMMA STRING RPAREN dotIdentifier + ; + +// TODO: JI - Simplify me +declareStatement + : DECLARE LOCAL_ID AS? (dataType | tableTypeDefinition | tableName | xmlTypeDefinition) + | DECLARE declareLocal (COMMA declareLocal)* + | WITH xmlNamespaces + ; + +cursorStatement + : CLOSE GLOBAL? cursorName SEMI? + | DEALLOCATE GLOBAL? CURSOR? cursorName SEMI? + | declareCursor + | fetchCursor + | OPEN GLOBAL? cursorName SEMI? + ; + +backupDatabase + : BACKUP DATABASE id (READ_WRITE_FILEGROUPS (COMMA optionList)?)? optionList? (TO optionList) ( + MIRROR TO optionList + )? ( + WITH ( + ENCRYPTION LPAREN ALGORITHM EQ genericOption COMMA SERVER CERTIFICATE EQ genericOption RPAREN + | optionList + ) + )? SEMI? + ; + +backupLog: BACKUP id /* LOG */ id TO optionList (MIRROR TO optionList)? ( WITH optionList)? + ; + +backupCertificate + : BACKUP CERTIFICATE id TO FILE EQ STRING ( + WITH PRIVATE KEY LPAREN ( + COMMA? FILE EQ STRING + | COMMA? ENCRYPTION BY PASSWORD EQ STRING + | COMMA? DECRYPTION BY PASSWORD EQ STRING + )+ RPAREN + )? + ; + +backupMasterKey: BACKUP MASTER KEY TO FILE EQ STRING ENCRYPTION BY PASSWORD EQ STRING + ; + +backupServiceMasterKey + : BACKUP SERVICE MASTER KEY TO FILE EQ STRING ENCRYPTION BY PASSWORD EQ STRING + ; + +killStatement: KILL (killProcess | killQueryNotification | killStatsJob) + ; + +killProcess: (sessionId = (INT | STRING) | UOW) (WITH STATUSONLY)? + ; + +killQueryNotification: QUERY NOTIFICATION SUBSCRIPTION ( ALL | INT) + ; + +killStatsJob: STATS JOB INT + ; + +executeStatement: EXECUTE executeBody SEMI? + ; + +executeBodyBatch + : jinjaTemplate + | dotIdentifier (executeStatementArg (COMMA executeStatementArg)*)? SEMI? + ; + +executeBody + : (LOCAL_ID EQ)? (dotIdentifier | executeVarString) ( + executeStatementArg (COMMA executeStatementArg)* + )? + | LPAREN executeVarString (COMMA executeVarString)* RPAREN (AS (LOGIN | USER) EQ STRING)? ( + AT_KEYWORD id + )? + | AS ( (LOGIN | USER) EQ STRING | CALLER) + ; + +// In practice unnamed arguments must precede named arguments, but we assume the input is syntactically valid +// and accept them in any order to simplitfy the grammar +executeStatementArg: (LOCAL_ID EQ)? executeParameter + ; + +executeParameter: ( constant | LOCAL_ID (OUTPUT | OUT)? | id | DEFAULT | NULL) + ; + +executeVarString + : LOCAL_ID (OUTPUT | OUT)? (PLUS LOCAL_ID (PLUS executeVarString)?)? + | STRING (PLUS LOCAL_ID (PLUS executeVarString)?)? + ; + +securityStatement + : executeAs SEMI? + | GRANT (ALL PRIVILEGES? | grantPermission (LPAREN columnNameList RPAREN)?) ( + ON (classTypeForGrant COLON COLON)? tableName + )? TO toPrincipal += principalId (COMMA toPrincipal += principalId)* (WITH GRANT OPTION)? ( + AS principalId + )? SEMI? + | REVERT (WITH COOKIE EQ LOCAL_ID)? SEMI? + | openKey + | closeKey + | createKey + | createCertificate + ; + +principalId: id | PUBLIC + ; + +createCertificate + : CREATE CERTIFICATE id (AUTHORIZATION id)? (FROM existingKeys | generateNewKeys) ( + ACTIVE FOR BEGIN DIALOG EQ onOff + )? + ; + +existingKeys + : ASSEMBLY id + | EXECUTABLE? FILE EQ STRING (WITH PRIVATE KEY LPAREN privateKeyOptions RPAREN)? + ; + +privateKeyOptions + : (FILE | HEX) EQ STRING (COMMA (DECRYPTION | ENCRYPTION) BY PASSWORD EQ STRING)? + ; + +generateNewKeys: (ENCRYPTION BY PASSWORD EQ STRING)? 
WITH SUBJECT EQ STRING ( COMMA dateOptions)* + ; + +dateOptions: (START_DATE | EXPIRY_DATE) EQ STRING + ; + +openKey + : OPEN SYMMETRIC KEY id DECRYPTION BY decryptionMechanism + | OPEN MASTER KEY DECRYPTION BY PASSWORD EQ STRING + ; + +closeKey: CLOSE SYMMETRIC KEY id | CLOSE ALL SYMMETRIC KEYS | CLOSE MASTER KEY + ; + +createKey + : CREATE MASTER KEY ENCRYPTION BY PASSWORD EQ STRING + | CREATE SYMMETRIC KEY id (AUTHORIZATION id)? (FROM PROVIDER id)? WITH ( + (keyOptions | ENCRYPTION BY encryptionMechanism) COMMA? + )+ + ; + +keyOptions + : KEY_SOURCE EQ STRING + | ALGORITHM EQ algorithm + | IDENTITY_VALUE EQ STRING + | PROVIDER_KEY_NAME EQ STRING + | CREATION_DISPOSITION EQ (CREATE_NEW | OPEN_EXISTING) + ; + +algorithm + : DES + | TRIPLE_DES + | TRIPLE_DES_3KEY + | RC2 + | RC4 + | RC4_128 + | DESX + | AES_128 + | AES_192 + | AES_256 + ; + +encryptionMechanism: CERTIFICATE id | ASYMMETRIC KEY id | SYMMETRIC KEY id | PASSWORD EQ STRING + ; + +decryptionMechanism + : CERTIFICATE id (WITH PASSWORD EQ STRING)? + | ASYMMETRIC KEY id (WITH PASSWORD EQ STRING)? + | SYMMETRIC KEY id + | PASSWORD EQ STRING + ; + +grantPermission + : ADMINISTER genericOption + | ALTER ( ANY? genericOption)? + | AUTHENTICATE SERVER? + | BACKUP genericOption + | CHECKPOINT + | CONNECT genericOption? + | CONTROL SERVER? + | CREATE genericOption + | DELETE + | EXECUTE genericOption? + | EXTERNAL genericOption + | IMPERSONATE genericOption? + | INSERT + | KILL genericOption + | RECEIVE + | REFERENCES + | SELECT genericOption? + | SEND + | SHOWPLAN + | SHUTDOWN + | SUBSCRIBE QUERY NOTIFICATIONS + | TAKE OWNERSHIP + | UNMASK + | UNSAFE ASSEMBLY + | UPDATE + | VIEW ( ANY genericOption | genericOption) + ; + +setStatement + : SET LOCAL_ID (DOT id)? EQ expression + | SET LOCAL_ID assignmentOperator expression + | SET LOCAL_ID EQ CURSOR declareSetCursorCommon (FOR (READ ONLY | UPDATE (OF columnNameList)?))? + | setSpecial + ; + +transactionStatement + : BEGIN DISTRIBUTED (TRAN | TRANSACTION) (id | LOCAL_ID)? + | BEGIN (TRAN | TRANSACTION) ( (id | LOCAL_ID) (WITH MARK STRING)?)? + | COMMIT (TRAN | TRANSACTION) ( + (id | LOCAL_ID) (WITH LPAREN DELAYED_DURABILITY EQ (OFF | ON) RPAREN)? + )? + | COMMIT WORK? + | COMMIT id + | ROLLBACK id + | ROLLBACK (TRAN | TRANSACTION) (id | LOCAL_ID)? + | ROLLBACK WORK? + | SAVE (TRAN | TRANSACTION) (id | LOCAL_ID)? + ; + +goStatement: GO INT? SEMI? + ; + +useStatement: USE id + ; + +setuserStatement: SETUSER STRING? + ; + +reconfigureStatement: RECONFIGURE (WITH OVERRIDE)? + ; + +shutdownStatement: SHUTDOWN (WITH genericOption)? + ; + +checkpointStatement: CHECKPOINT (INT)? + ; + +dbccCheckallocOption: ALL_ERRORMSGS | NO_INFOMSGS | TABLOCK | ESTIMATEONLY + ; + +dbccCheckalloc + : CHECKALLOC ( + LPAREN (id | STRING | INT) ( + COMMA NOINDEX + | COMMA ( REPAIR_ALLOW_DATA_LOSS | REPAIR_FAST | REPAIR_REBUILD) + )? RPAREN (WITH dbccCheckallocOption (COMMA dbccCheckallocOption)*)? + )? + ; + +dbccCheckcatalog: CHECKCATALOG (LPAREN (id | STRING | INT) RPAREN)? ( WITH NO_INFOMSGS)? + ; + +dbccCheckconstraintsOption: ALL_CONSTRAINTS | ALL_ERRORMSGS | NO_INFOMSGS + ; + +dbccCheckconstraints + : CHECKCONSTRAINTS (LPAREN (id | STRING) RPAREN)? ( + WITH dbccCheckconstraintsOption ( COMMA dbccCheckconstraintsOption)* + )? + ; + +dbccCheckdbTableOption: genericOption + ; + +dbccCheckdb + : CHECKDB ( + LPAREN (id | STRING | INT) ( + COMMA (NOINDEX | REPAIR_ALLOW_DATA_LOSS | REPAIR_FAST | REPAIR_REBUILD) + )? RPAREN + )? (WITH dbccCheckdbTableOption ( COMMA dbccCheckdbTableOption)*)? 
+ ; + +dbccCheckfilegroupOption: genericOption + ; + +dbccCheckfilegroup + : CHECKFILEGROUP ( + LPAREN (INT | STRING) ( + COMMA (NOINDEX | REPAIR_ALLOW_DATA_LOSS | REPAIR_FAST | REPAIR_REBUILD) + )? RPAREN + )? (WITH dbccCheckfilegroupOption ( COMMA dbccCheckfilegroupOption)*)? + ; + +dbccChecktable + : CHECKTABLE LPAREN STRING ( + COMMA (NOINDEX | expression | REPAIR_ALLOW_DATA_LOSS | REPAIR_FAST | REPAIR_REBUILD) + )? RPAREN (WITH dbccCheckdbTableOption (COMMA dbccCheckdbTableOption)*)? + ; + +dbccCleantable + : CLEANTABLE LPAREN (id | STRING | INT) COMMA (id | STRING) (COMMA INT)? RPAREN ( + WITH NO_INFOMSGS + )? + ; + +dbccClonedatabaseOption + : NO_STATISTICS + | NO_QUERYSTORE + | SERVICEBROKER + | VERIFY_CLONEDB + | BACKUP_CLONEDB + ; + +dbccClonedatabase + : CLONEDATABASE LPAREN id COMMA id RPAREN ( + WITH dbccClonedatabaseOption (COMMA dbccClonedatabaseOption)* + )? + ; + +dbccPdwShowspaceused: PDW_SHOWSPACEUSED (LPAREN id RPAREN)? ( WITH IGNORE_REPLICATED_TABLE_CACHE)? + ; + +dbccProccache: PROCCACHE (WITH NO_INFOMSGS)? + ; + +dbccShowcontigOption: genericOption + ; + +dbccShowcontig + : SHOWCONTIG (LPAREN expression ( COMMA expression)? RPAREN)? ( + WITH dbccShowcontigOption (COMMA dbccShowcontigOption)* + )? + ; + +dbccShrinklog + : SHRINKLOG (LPAREN SIZE EQ ((INT ( MB | GB | TB)) | DEFAULT) RPAREN)? (WITH NO_INFOMSGS)? + ; + +dbccDbreindex + : DBREINDEX LPAREN idOrString (COMMA idOrString ( COMMA expression)?)? RPAREN ( + WITH NO_INFOMSGS + )? + ; + +dbccDllFree: id LPAREN FREE RPAREN ( WITH NO_INFOMSGS)? + ; + +dbccDropcleanbuffers: DROPCLEANBUFFERS (LPAREN COMPUTE | ALL RPAREN)? (WITH NO_INFOMSGS)? + ; + +dbccClause + : DBCC ( + dbccCheckalloc + | dbccCheckcatalog + | dbccCheckconstraints + | dbccCheckdb + | dbccCheckfilegroup + | dbccChecktable + | dbccCleantable + | dbccClonedatabase + | dbccDbreindex + | dbccDllFree + | dbccDropcleanbuffers + | dbccPdwShowspaceused + | dbccProccache + | dbccShowcontig + | dbccShrinklog + ) + ; + +executeAs: EXECUTE AS (CALLER | SELF | OWNER | STRING) + ; + +declareLocal: LOCAL_ID AS? dataType (EQ expression)? + ; + +tableTypeDefinition: TABLE LPAREN columnDefTableConstraints ( COMMA? tableTypeIndices)* RPAREN + ; + +tableTypeIndices + : (((PRIMARY KEY | INDEX id) (CLUSTERED | NONCLUSTERED)?) | UNIQUE) LPAREN columnNameListWithOrder RPAREN + | CHECK LPAREN searchCondition RPAREN + ; + +columnDefTableConstraints: columnDefTableConstraint (COMMA? columnDefTableConstraint)* + ; + +columnDefTableConstraint + : columnDefinition + | computedColumnDefinition + | tableConstraint + | tableIndices + ; + +computedColumnDefinition: id AS expression (PERSISTED (NOT NULL)?)? columnConstraint? + ; + +columnSetDefinition: id XML id FOR id + ; + +columnDefinition: id dataType columnDefinitionElement* columnIndex? + ; + +columnDefinitionElement + : MASKED WITH LPAREN FUNCTION EQ STRING RPAREN + | defaultValue + | identityColumn + | generatedAs + | ROWGUIDCOL + | ENCRYPTED WITH LPAREN COLUMN_ENCRYPTION_KEY EQ STRING COMMA ENCRYPTION_TYPE EQ ( + DETERMINISTIC + | RANDOMIZED + ) COMMA ALGORITHM EQ STRING RPAREN + | columnConstraint + | genericOption // TSQL column flags and options that we cannot support in Databricks + ; + +generatedAs + : GENERATED ALWAYS AS (ROW | TRANSACTION_ID | SEQUENCE_NUMBER) (START | END) HIDDEN_KEYWORD? + ; + +identityColumn: IDENTITY (LPAREN INT COMMA INT RPAREN)? + ; + +defaultValue: (CONSTRAINT id)? DEFAULT expression + ; + +columnConstraint + : (CONSTRAINT id)? ( + NOT? NULL + | ((PRIMARY KEY | UNIQUE) clustered? 
primaryKeyOptions) + | ( (FOREIGN KEY)? foreignKeyOptions) + | checkConstraint + ) + ; + +columnIndex + : INDEX id clustered? createTableIndexOptions? onPartitionOrFilegroup? ( + FILESTREAM_ON id // CHeck for quoted "NULL" + )? + ; + +onPartitionOrFilegroup: ON ( ( id LPAREN id RPAREN) | id | DEFAULT_DOUBLE_QUOTE) + ; + +tableConstraint + : (CONSTRAINT cid = id)? ( + ((PRIMARY KEY | UNIQUE) clustered? LPAREN columnNameListWithOrder RPAREN primaryKeyOptions) + | ( FOREIGN KEY LPAREN columnNameList RPAREN foreignKeyOptions) + | ( CONNECTION LPAREN connectionNode ( COMMA connectionNode)* RPAREN) + | ( DEFAULT expression FOR defid = id ( WITH VALUES)?) + | checkConstraint + ) + ; + +connectionNode: id TO id + ; + +primaryKeyOptions: (WITH FILLFACTOR EQ INT)? alterTableIndexOptions? onPartitionOrFilegroup? + ; + +foreignKeyOptions + : REFERENCES tableName (LPAREN columnNameList RPAREN)? onDelete? onUpdate? ( + NOT FOR REPLICATION + )? + ; + +checkConstraint: CHECK (NOT FOR REPLICATION)? LPAREN searchCondition RPAREN + ; + +onDelete: ON DELETE (NO ACTION | CASCADE | SET NULL | SET DEFAULT) + ; + +onUpdate: ON UPDATE (NO ACTION | CASCADE | SET NULL | SET DEFAULT) + ; + +alterTableIndexOptions: WITH LPAREN alterTableIndexOption ( COMMA alterTableIndexOption)* RPAREN + ; + +alterTableIndexOption + : DATA_COMPRESSION EQ (NONE | ROW | PAGE | COLUMNSTORE | COLUMNSTORE_ARCHIVE) onPartitions? + | XML_COMPRESSION EQ onOff onPartitions? + | DISTRIBUTION EQ HASH LPAREN id RPAREN + | CLUSTERED INDEX LPAREN id (ASC | DESC)? ( COMMA id (ASC | DESC)?)* RPAREN + | ONLINE EQ (ON (LPAREN lowPriorityLockWait RPAREN)? | OFF) + | genericOption + ; + +declareCursor + : DECLARE cursorName ( + CURSOR ( declareSetCursorCommon ( FOR UPDATE (OF columnNameList)?)?)? + | (SEMI_SENSITIVE | INSENSITIVE)? SCROLL? CURSOR FOR selectStatementStandalone ( + FOR (READ ONLY | UPDATE | (OF columnNameList)) + )? + ) SEMI? + ; + +declareSetCursorCommon: declareSetCursorCommonPartial* FOR selectStatementStandalone + ; + +declareSetCursorCommonPartial + : (LOCAL | GLOBAL) + | (FORWARD_ONLY | SCROLL) + | (STATIC | KEYSET | DYNAMIC | FAST_FORWARD) + | (READ_ONLY | SCROLL_LOCKS | OPTIMISTIC) + | TYPE_WARNING + ; + +fetchCursor + : FETCH (( NEXT | PRIOR | FIRST | LAST | (ABSOLUTE | RELATIVE) expression)? FROM)? GLOBAL? cursorName ( + INTO LOCAL_ID (COMMA LOCAL_ID)* + )? SEMI? + ; + +setSpecial + : SET id (id | constant_LOCAL_ID | onOff) SEMI? + | SET STATISTICS expression onOff + // TODO: Extract these keywords (IO | TIME | XML | PROFILE) onOff SEMI? + | SET ROWCOUNT (LOCAL_ID | INT) SEMI? + | SET TEXTSIZE INT SEMI? + | SET TRANSACTION ISOLATION LEVEL ( + READ UNCOMMITTED + | READ COMMITTED + | REPEATABLE READ + | SNAPSHOT + | SERIALIZABLE + | INT + ) SEMI? + | SET IDENTITY_INSERT tableName onOff SEMI? 
+ | SET specialList (COMMA specialList)* onOff + // TODO: Rework when it is time to implement SET modifyMethod + ; + +specialList + : ANSI_NULLS + | QUOTED_IDENTIFIER + | ANSI_PADDING + | ANSI_WARNINGS + | ANSI_DEFAULTS + | ANSI_NULL_DFLT_OFF + | ANSI_NULL_DFLT_ON + | ARITHABORT + | ARITHIGNORE + | CONCAT_NULL_YIELDS_NULL + | CURSOR_CLOSE_ON_COMMIT + | FMTONLY + | FORCEPLAN + | IMPLICIT_TRANSACTIONS + | NOCOUNT + | NOEXEC + | NUMERIC_ROUNDABORT + | PARSEONLY + | REMOTE_PROC_TRANSACTIONS + | SHOWPLAN_ALL + | SHOWPLAN_TEXT + | SHOWPLAN_XML + | XACT_ABORT + ; + +constant_LOCAL_ID: constant | LOCAL_ID + ; + +expression + : LPAREN expression RPAREN # exprPrecedence + | BIT_NOT expression # exprBitNot + | op = (PLUS | MINUS) expression # exprUnary + | expression op = (STAR | DIV | MOD) expression # exprOpPrec1 + | expression op = (PLUS | MINUS) expression # exprOpPrec2 + | expression op = (BIT_AND | BIT_XOR | BIT_OR) expression # exprOpPrec3 + | expression op = DOUBLE_BAR expression # exprOpPrec4 + | expression op = DOUBLE_COLON expression # exprOpPrec4 + | primitiveExpression # exprPrimitive + | functionCall # exprFunc + | functionValues # exprFuncVal + | expression COLLATE id # exprCollate + | caseExpression # exprCase + | expression timeZone # exprTz + | expression overClause # exprOver + | expression withinGroup # exprWithinGroup + | DOLLAR_ACTION # exprDollar + | expression DOT expression # exprDot + | LPAREN selectStatement RPAREN # exprSubquery + | ALL expression # exprAll + | DISTINCT expression # exprDistinct + | DOLLAR_ACTION # exprDollar + | STAR # exprStar + | id # exprId + | jinjaTemplate # exprJinja + ; + +// TODO: Implement this ? +parameter: PLACEHOLDER + ; + +timeZone + : AT_KEYWORD id ZONE expression // AT TIME ZONE + ; + +primitiveExpression: op = (DEFAULT | NULL | LOCAL_ID) | constant + ; + +caseExpression: CASE caseExpr = expression? switchSection+ ( ELSE elseExpr = expression)? END + ; + +withExpression: WITH xmlNamespaces? commonTableExpression ( COMMA commonTableExpression)* + ; + +commonTableExpression: id (LPAREN columnNameList RPAREN)? AS LPAREN selectStatement RPAREN + ; + +updateElem + : (l1 = LOCAL_ID EQ)? (fullColumnName | l2 = LOCAL_ID) op = ( + EQ + | PE + | ME + | SE + | DE + | MEA + | AND_ASSIGN + | XOR_ASSIGN + | OR_ASSIGN + ) expression # updateElemCol + | id DOT id LPAREN expressionList RPAREN # updateElemUdt + ; + +searchCondition + : LPAREN searchCondition RPAREN # scPrec + | NOT searchCondition # scNot + | searchCondition AND searchCondition # scAnd + | searchCondition OR searchCondition # scOr + | predicate # scPred + ; + +predicate + : EXISTS LPAREN selectStatement RPAREN # predExists + | freetextPredicate # predFreetext + | expression comparisonOperator expression # predBinop + | expression comparisonOperator (ALL | SOME | ANY) LPAREN selectStatement RPAREN # predASA + | expression NOT? BETWEEN expression AND expression # predBetween + | expression NOT? IN LPAREN (selectStatement | expressionList) RPAREN # predIn + | expression NOT? LIKE expression (ESCAPE expression)? # predLike + | expression IS NOT? NULL # predIsNull + | expression # predExpression + ; + +queryExpression + // INTERSECT has higher precedence than EXCEPT and UNION ALL. 
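+    // For example (an illustrative query added for clarity, not taken from the original sources):
+    //     SELECT a FROM t1 UNION SELECT a FROM t2 INTERSECT SELECT a FROM t3
+    // is evaluated as t1 UNION (t2 INTERSECT t3); the ordering of the labelled alternatives
+    // below encodes that precedence.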
+    // Reference: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/set-operators-except-and-intersect-transact-sql?view=sql-server-ver16#:~:text=following%20precedence
+    : LPAREN queryExpression RPAREN # queryInParenthesis
+    | queryExpression INTERSECT queryExpression # queryIntersect
+    | queryExpression (UNION ALL? | EXCEPT) queryExpression # queryUnion
+    | querySpecification # querySimple
+    ;
+
+querySpecification: SELECT (ALL | DISTINCT)? topClause? selectList selectOptionalClauses
+    ;
+
+selectOptionalClauses
+    // TODO: Fix ORDER BY; it needs to be outside the set operations instead of between.
+    : intoClause? fromClause? whereClause? groupByClause? havingClause? selectOrderByClause?
+    // Reference: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/set-operators-union-transact-sql?view=sql-server-ver16#c-using-union-of-two-select-statements-with-order-by
+    ;
+
+groupByClause
+    : GROUP BY (
+        (ALL? expression (COMMA expression)*) (WITH id)?
+        // Note that id should be checked for CUBE or ROLLUP
+        | GROUPING SETS LPAREN groupingSetsItem ( COMMA groupingSetsItem)* RPAREN
+    )
+    ;
+
+groupingSetsItem: LPAREN? expression (COMMA expression)* RPAREN? | LPAREN RPAREN
+    ;
+
+intoClause: INTO tableName
+    ;
+
+fromClause: FROM tableSources
+    ;
+
+whereClause: WHERE searchCondition
+    ;
+
+havingClause: HAVING searchCondition
+    ;
+
+topClause: TOP ( expression | LPAREN expression RPAREN) PERCENT? (WITH TIES)?
+    ;
+
+orderByClause: ORDER BY orderByExpression (COMMA orderByExpression)*
+    ;
+
+selectOrderByClause
+    : orderByClause (
+        OFFSET expression (ROW | ROWS) (FETCH (FIRST | NEXT) expression (ROW | ROWS) ONLY)?
+    )?
+    ;
+
+// Unsupported in Databricks SQL
+forClause
+    : FOR BROWSE
+    | FOR XML (RAW (LPAREN STRING RPAREN)? | AUTO) xmlCommonDirectives* (
+        COMMA (XMLDATA | XMLSCHEMA (LPAREN STRING RPAREN)?)
+    )? (COMMA ELEMENTS (XSINIL | ABSENT)?)?
+    | FOR XML EXPLICIT xmlCommonDirectives* (COMMA XMLDATA)?
+    | FOR XML PATH (LPAREN STRING RPAREN)? xmlCommonDirectives* (COMMA ELEMENTS (XSINIL | ABSENT)?)?
+    | FOR JSON (AUTO | PATH) (
+        COMMA (ROOT (LPAREN STRING RPAREN) | INCLUDE_NULL_VALUES | WITHOUT_ARRAY_WRAPPER)
+    )*
+    ;
+
+orderByExpression: expression (COLLATE expression)? (ASC | DESC)?
+    ;
+
+optionClause: OPTION lparenOptionList
+    ;
+
+// Note that the comma is optional to cater for complications added by Jinja template element
+// references, which may or may not supply the COMMA themselves, so consecutive select list
+// elements can appear without a COMMA between them, such as:
+//
+// select
+//     order_id,
+//     {%- for payment_method in payment_methods %}
+//     sum(case when payment_method = '{{payment_method}}' then amount end) as {{payment_method}}_amount
+//     {%- if not loop.last %},{% endif -%}
+//     {% endfor %}
+selectList: selectListElem selectElemTempl*
+    ;
+
+selectElemTempl: COMMA selectListElem | COMMA? jinjaTemplate selectListElem?
+    ;
+
+asterisk: (INSERTED | DELETED) DOT STAR | (tableName DOT)? STAR
+    ;
+
+expressionElem: columnAlias EQ expression | expression (AS? columnAlias)?
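+    // Covers both T-SQL aliasing styles, e.g. `total = SUM(amount)` and `SUM(amount) AS total`
+    // (illustrative column names; example added for clarity).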
+    ;
+
+selectListElem
+    : asterisk
+    | LOCAL_ID op = (PE | ME | SE | DE | MEA | AND_ASSIGN | XOR_ASSIGN | OR_ASSIGN | EQ) expression
+    | expressionElem
+    ;
+
+tableSources: source += tableSource (COMMA source += tableSource)*
+    ;
+
+tableSource: tableSourceItem joinPart*
+    ;
+
+// Almost all tableSource elements allow a table alias, and some allow a list of column aliases.
+// As this parser expects to see valid input anyway, we combine this into a single rule and
+// then visit each possible table source individually, applying the alias afterwards. This reduces
+// rule complexity and parser complexity substantially.
+tableSourceItem: tsiElement (asTableAlias columnAliasList?)? withTableHints?
+    ;
+
+tsiElement
+    : tableName # tsiNamedTable
+    | rowsetFunction # tsiRowsetFunction
+    | LPAREN derivedTable RPAREN # tsiDerivedTable
+    | changeTable # tsiChangeTable
+    | nodesMethod # tsiNodesMethod
+    | /* TODO (id DOT)? distinguish xml functions */ functionCall # tsiFunctionCall
+    | LOCAL_ID # tsiLocalId
+    | LOCAL_ID DOT functionCall # tsiLocalIdFunctionCall
+    | openXml # tsiOpenXml
+    | openJson # tsiOpenJson
+    | dotIdentifier? DOUBLE_COLON functionCall # tsiDoubleColonFunctionCall
+    | LPAREN tableSource RPAREN # tsiParenTableSource
+    | jinjaTemplate # tsiJinja
+    ;
+
+openJson
+    : OPENJSON LPAREN expression (COMMA expression)? RPAREN (WITH LPAREN jsonDeclaration RPAREN)? asTableAlias?
+    ;
+
+jsonDeclaration: jsonCol += jsonColumnDeclaration ( COMMA jsonCol += jsonColumnDeclaration)*
+    ;
+
+jsonColumnDeclaration: columnDeclaration (AS JSON)?
+    ;
+
+columnDeclaration: id dataType STRING?
+    ;
+
+changeTable: changeTableChanges | changeTableVersion
+    ;
+
+changeTableChanges
+    : CHANGETABLE LPAREN CHANGES tableName COMMA changesid = (NULL | INT | LOCAL_ID) RPAREN
+    ;
+
+changeTableVersion
+    : CHANGETABLE LPAREN VERSION tableName COMMA fullColumnNameList COMMA selectList RPAREN
+    ;
+
+joinPart: joinOn | crossJoin | apply | pivot | unpivot
+    ;
+
+outerJoin: (LEFT | RIGHT | FULL) OUTER?
+    ;
+
+joinType: INNER | outerJoin
+    ;
+
+joinOn
+    : joinType? (joinHint = (LOOP | HASH | MERGE | REMOTE))? JOIN source = tableSource ON cond = searchCondition
+    ;
+
+crossJoin: CROSS JOIN tableSourceItem
+    ;
+
+apply: (CROSS | OUTER) APPLY tableSourceItem
+    ;
+
+pivot: PIVOT pivotClause asTableAlias
+    ;
+
+unpivot: UNPIVOT unpivotClause asTableAlias
+    ;
+
+pivotClause: LPAREN expression FOR fullColumnName IN columnAliasList RPAREN
+    ;
+
+unpivotClause: LPAREN id FOR id IN LPAREN fullColumnNameList RPAREN RPAREN
+    ;
+
+fullColumnNameList: column += fullColumnName (COMMA column += fullColumnName)*
+    ;
+
+rowsetFunction
+    : (
+        OPENROWSET LPAREN STRING COMMA ((STRING SEMI STRING SEMI STRING) | STRING) (
+            COMMA (dotIdentifier | STRING)
+        ) RPAREN
+    )
+    | (OPENROWSET LPAREN BULK STRING COMMA ( id EQ STRING COMMA optionList? | id) RPAREN)
+    ;
+
+derivedTable: selectStatement | tableValueConstructor | LPAREN tableValueConstructor RPAREN
+    ;
+
+functionCall
+    : builtInFunctions
+    | standardFunction
+    | freetextFunction
+    | partitionFunction
+    | hierarchyidStaticMethod
+    ;
+
+// Things that are just special values and not really functions, but are documented as if they are functions
+functionValues: f = ( AAPSEUDO | SESSION_USER | SYSTEM_USER | USER)
+    ;
+
+// Standard functions that are built in but take standard syntax, or are
+// some user function etc
+standardFunction: funcId LPAREN (expression (COMMA expression)*)? RPAREN
+    ;
+
+funcId: id | FORMAT | LEFT | RIGHT | REPLACE | CONCAT
+    ;
+
+partitionFunction: (id DOT)?
DOLLAR_PARTITION DOT id LPAREN expression RPAREN + ; + +freetextFunction + : f = ( + SEMANTICSIMILARITYDETAILSTABLE + | SEMANTICSIMILARITYTABLE + | SEMANTICKEYPHRASETABLE + | CONTAINSTABLE + | FREETEXTTABLE + ) LPAREN expression COMMA (expression | LPAREN expressionList RPAREN | STAR) COMMA expression ( + COMMA LANGUAGE expression + )? (COMMA expression)? RPAREN + ; + +freetextPredicate + : CONTAINS LPAREN ( + fullColumnName + | LPAREN fullColumnName (COMMA fullColumnName)* RPAREN + | STAR + | PROPERTY LPAREN fullColumnName COMMA expression RPAREN + ) COMMA expression RPAREN + | FREETEXT LPAREN tableName COMMA ( + fullColumnName + | LPAREN fullColumnName (COMMA fullColumnName)* RPAREN + | STAR + ) COMMA expression (COMMA LANGUAGE expression)? RPAREN + ; + +builtInFunctions + : NEXT VALUE FOR tableName # nextValueFor + | (CAST | TRY_CAST) LPAREN expression AS dataType RPAREN # cast + | JSON_ARRAY LPAREN expressionList? jsonNullClause? RPAREN # jsonArray + | JSON_OBJECT LPAREN (jsonKeyValue (COMMA jsonKeyValue)*)? jsonNullClause? RPAREN # jsonObject + ; + +jsonKeyValue: expression COLON expression + ; + +jsonNullClause: (loseNulls = ABSENT | NULL) ON NULL + ; + +hierarchyidStaticMethod + : HIERARCHYID DOUBLE_COLON (GETROOT LPAREN RPAREN | PARSE LPAREN input = expression RPAREN) + ; + +nodesMethod + : (locId = LOCAL_ID | valueId = fullColumnName | LPAREN selectStatement RPAREN) DOT NODES LPAREN xquery = STRING RPAREN + ; + +switchSection: WHEN searchCondition THEN expression + ; + +asTableAlias: AS? (jinjaTemplate | id | DOUBLE_QUOTE_ID) + ; + +withTableHints: WITH LPAREN tableHint (COMMA? tableHint)* RPAREN + ; + +tableHint + : INDEX EQ? LPAREN expressionList RPAREN + | FORCESEEK ( LPAREN expression LPAREN columnNameList RPAREN RPAREN)? + | genericOption + ; + +columnAliasList: LPAREN columnAlias (COMMA columnAlias)* RPAREN + ; + +columnAlias: jinjaTemplate | id | STRING + ; + +tableValueConstructor: VALUES tableValueRow (COMMA tableValueRow)* + ; + +tableValueRow: LPAREN expressionList RPAREN + ; + +expressionList: exp += expression (COMMA exp += expression)* + ; + +withinGroup: WITHIN GROUP LPAREN orderByClause RPAREN + ; + +// The ((IGNORE | RESPECT) NULLS)? is strictly speaking, not part of the OVER clause +// but trails certain windowing functions such as LAG and LEAD. However, all such functions +// must use the OVER clause, so it is included here to make building the IR simpler. +overClause + : ((IGNORE | RESPECT) NULLS)? OVER LPAREN (PARTITION BY expression (COMMA expression)*)? orderByClause? rowOrRangeClause? RPAREN + ; + +rowOrRangeClause: (ROWS | RANGE) windowFrameExtent + ; + +windowFrameExtent: windowFrameBound | BETWEEN windowFrameBound AND windowFrameBound + ; + +windowFrameBound: UNBOUNDED (PRECEDING | FOLLOWING) | INT (PRECEDING | FOLLOWING) | CURRENT ROW + ; + +databaseFilestreamOption + : LPAREN ((NON_TRANSACTED_ACCESS EQ ( OFF | READ_ONLY | FULL)) | ( DIRECTORY_NAME EQ STRING)) RPAREN + ; + +databaseFileSpec: fileGroup | fileSpecification + ; + +fileGroup + : FILEGROUP id (CONTAINS FILESTREAM)? (DEFAULT)? (CONTAINS MEMORY_OPTIMIZED_DATA)? fileSpecification ( + COMMA fileSpecification + )* + ; + +fileSpecification + : LPAREN NAME EQ (id | STRING) COMMA? FILENAME EQ file = STRING COMMA? ( + SIZE EQ fileSize COMMA? + )? (MAXSIZE EQ (fileSize | UNLIMITED) COMMA?)? (FILEGROWTH EQ fileSize COMMA?)? RPAREN + ; + +tableName: (linkedServer = id DOT DOT)? 
ids += id (DOT ids += id)*
+    ;
+
+dotIdentifier: id (DOT id)*
+    ;
+
+ddlObject: tableName | rowsetFunctionLimited | LOCAL_ID
+    ;
+
+fullColumnName: ((DELETED | INSERTED | tableName) DOT)? ( id | (DOLLAR (IDENTITY | ROWGUID)))
+    ;
+
+columnNameListWithOrder: columnNameWithOrder (COMMA columnNameWithOrder)*
+    ;
+
+columnNameWithOrder: id (ASC | DESC)?
+    ;
+
+columnNameList: id (COMMA id)*
+    ;
+
+cursorName: id | LOCAL_ID
+    ;
+
+onOff: ON | OFF
+    ;
+
+clustered: CLUSTERED | NONCLUSTERED
+    ;
+
+nullNotnull: NOT? NULL
+    ;
+
+beginConversationTimer
+    : BEGIN CONVERSATION TIMER LPAREN LOCAL_ID RPAREN TIMEOUT EQ expression SEMI?
+    ;
+
+beginConversationDialog
+    : BEGIN DIALOG (CONVERSATION)? LOCAL_ID FROM SERVICE serviceName TO SERVICE serviceName (
+        COMMA STRING
+    )? ON CONTRACT contractName (
+        WITH ((RELATED_CONVERSATION | RELATED_CONVERSATION_GROUP) EQ LOCAL_ID COMMA?)? (
+            LIFETIME EQ (INT | LOCAL_ID) COMMA?
+        )? (ENCRYPTION EQ onOff)?
+    )? SEMI?
+    ;
+
+contractName: (id | expression)
+    ;
+
+serviceName: (id | expression)
+    ;
+
+endConversation
+    : END CONVERSATION LOCAL_ID SEMI? (
+        WITH (
+            ERROR EQ faliureCode = (LOCAL_ID | STRING) DESCRIPTION EQ failureText = (
+                LOCAL_ID
+                | STRING
+            )
+        )? CLEANUP?
+    )?
+    ;
+
+waitforConversation: WAITFOR? LPAREN getConversation RPAREN (COMMA? TIMEOUT expression)? SEMI?
+    ;
+
+getConversation
+    : GET CONVERSATION GROUP conversationGroupId = (STRING | LOCAL_ID) FROM dotIdentifier SEMI?
+    ;
+
+sendConversation
+    : SEND ON CONVERSATION conversationHandle = (STRING | LOCAL_ID) MESSAGE TYPE expression (
+        LPAREN messageBodyExpression = (STRING | LOCAL_ID) RPAREN
+    )? SEMI?
+    ;
+
+dataType
+    : jinjaTemplate
+    | dataTypeIdentity
+    | XML LPAREN id RPAREN
+    | id (LPAREN (INT | MAX) (COMMA INT)? RPAREN)?
+    ;
+
+dataTypeList: dataType (COMMA dataType)*
+    ;
+
+dataTypeIdentity: id IDENTITY (LPAREN INT COMMA INT RPAREN)?
+    ;
+
+constant: con = (STRING | HEX | INT | REAL | FLOAT | MONEY) | parameter
+    ;
+
+id: ID | TEMP_ID | DOUBLE_QUOTE_ID | SQUARE_BRACKET_ID | NODEID | keyword | RAW
+    ;
+
+simpleId: ID
+    ;
+
+idOrString: id | STRING
+    ;
+
+// Spaces are allowed for comparison operators.
+comparisonOperator
+    : EQ
+    | GT
+    | LT
+    | LT EQ
+    | GT EQ
+    | LT GT
+    | EQ
+    | BANG EQ
+    | GT
+    | BANG GT
+    | LT
+    | BANG LT
+    ;
+
+assignmentOperator: PE | ME | SE | DE | MEA | AND_ASSIGN | XOR_ASSIGN | OR_ASSIGN
+    ;
+
+fileSize: INT (KB | MB | GB | TB | MOD)?
+    ;
+
+/**
+ * The parenthesised option list is used in many places, so it is defined here.
+ */
+lparenOptionList: LPAREN optionList RPAREN
+    ;
+
+/**
+ * The generic option list is used in many places, so it is defined here.
+ */
+optionList: genericOption (COMMA genericOption)*
+    ;
+
+/**
+ * Generic options have a few different formats, but otherwise they can almost all be
+ * parsed generically rather than creating potentially hundreds of keywords and rules
+ * that obfuscate the grammar and make maintenance difficult as TSQL evolves. SQL is,
+ * or has become, a very verbose language with strange syntactical elements bolted in
+ * because they could not fit otherwise. So, as many options as possible are parsed
+ * here and the AST builders can decide what to do with them as they have context.
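+ *
+ * As an illustration (examples added for clarity; MAXDOP, ONLINE and OPTIMIZE FOR are just
+ * typical cases), option fragments such as MAXDOP 4, ONLINE = ON or OPTIMIZE FOR UNKNOWN are
+ * all matched by the single genericOption rule below rather than by dedicated keywords and rules.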
+ *
+ * Here are the various formats:
+ *
+ * KEYWORD                 - Just means the option is ON if it is present, OFF if NOT (but check semantics)
+ * KEYWORD ON|OFF          - The option is on or off (no consistency here)
+ * KEYWORD = VALUE         - The option is set to a value - we accept any expression and assume
+ *                           the AST builder will check the type and range, OR that we require valid
+ *                           TSQL in the first place.
+ * KEYWORD = VALUE KB      - Some sort of size value, where KB can be various things so is parsed as any id()
+ * KEYWORD (=)? DEFAULT    - A fairly redundant option, but sometimes people want to be explicit
+ * KEYWORD (=)? AUTO       - The option is set to AUTO, which occurs in a few places
+ * something FOR something - The option is a special case such as OPTIMIZE FOR UNKNOWN
+ * DEFAULT                 - The option is set to the default value but is not named
+ * ON                      - The option is on but is not named (will get just id)
+ * OFF                     - The option is off but is not named (will get just id)
+ * AUTO                    - The option is set to AUTO but is not named (will get just id)
+ * ALL                     - The option is set to ALL but is not named (will get just id)
+ */
+genericOption
+    : id EQ? (
+        DEFAULT          // Default value - don't resolve with expression
+        | ON             // Simple ON - don't resolve with expression
+        | OFF            // Simple OFF - don't resolve with expression
+        | AUTO           // Simple AUTO - don't resolve with expression
+        | STRING         // String value - don't resolve with expression
+        | FOR id         // FOR id - don't resolve with expression - special case
+        | expression id? // Catch all for less explicit options, sometimes with extra keywords
+    )?
+    ;
+
+// XML stuff
+
+dropXmlSchemaCollection
+    : DROP XML SCHEMA COLLECTION (relationalSchema = id DOT)? sqlIdentifier = id
+    ;
+
+schemaDeclaration: columnDeclaration (COMMA columnDeclaration)*
+    ;
+
+createXmlSchemaCollection
+    : CREATE XML SCHEMA COLLECTION (relationalSchema = id DOT)? sqlIdentifier = id AS (
+        STRING
+        | id
+        | LOCAL_ID
+    )
+    ;
+
+openXml
+    : OPENXML LPAREN expression COMMA expression (COMMA expression)? RPAREN (
+        WITH LPAREN schemaDeclaration RPAREN
+    )? asTableAlias?
+    ;
+
+xmlNamespaces: XMLNAMESPACES LPAREN xmlDeclaration (COMMA xmlDeclaration)* RPAREN
+    ;
+
+xmlDeclaration: STRING AS id | DEFAULT STRING
+    ;
+
+xmlTypeDefinition: XML LPAREN (CONTENT | DOCUMENT)? xmlSchemaCollection RPAREN
+    ;
+
+xmlSchemaCollection: ID DOT ID
+    ;
+
+createXmlIndex
+    : CREATE PRIMARY? XML INDEX id ON tableName LPAREN id RPAREN (
+        USING XML INDEX id (FOR (VALUE | PATH | PROPERTY)?)?
+    )? xmlIndexOptions? SEMI?
+    ;
+
+xmlIndexOptions: WITH LPAREN xmlIndexOption (COMMA xmlIndexOption)* RPAREN
+    ;
+
+xmlIndexOption: ONLINE EQ (ON (LPAREN lowPriorityLockWait RPAREN)? | OFF) | genericOption
+    ;
+
+xmlCommonDirectives: COMMA ( BINARY id | TYPE | ROOT (LPAREN STRING RPAREN)?)
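+    // Trailing directives of a FOR XML clause, e.g. ", ROOT('Orders')" and ", BINARY BASE64" in
+    // SELECT ... FOR XML AUTO, ROOT('Orders'), BINARY BASE64 (illustrative example added for clarity).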
+ ; diff --git a/core/src/main/resources/log4j2.properties b/core/src/main/resources/log4j2.properties new file mode 100644 index 0000000000..6aad832e90 --- /dev/null +++ b/core/src/main/resources/log4j2.properties @@ -0,0 +1,25 @@ +# Set to debug or trace if log4j initialization is failing +status = warn + +# Root logger level +rootLogger.level = WARN + +# Console appender configuration +appender.console.type = Console +appender.console.name = consoleLogger +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %level [%C{1.}] - %msg%n + +# Root logger referring to console appender +rootLogger.appenderRef.stdout.ref = consoleLogger + + +logger.spark.name = org.apache.spark +logger.spark.level = INFO + +logger.databricks.name = com.databricks +logger.databricks.level = WARN + +logger.remorph.name = com.databricks.labs.remorph +logger.remorph.level = DEBUG + diff --git a/core/src/main/scala/com/databricks/labs/remorph/ApplicationContext.scala b/core/src/main/scala/com/databricks/labs/remorph/ApplicationContext.scala new file mode 100644 index 0000000000..61ac79de6f --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/ApplicationContext.scala @@ -0,0 +1,66 @@ +package com.databricks.labs.remorph + +import com.databricks.labs.remorph.coverage.estimation.{EstimationAnalyzer, Estimator, JsonEstimationReporter, SummaryEstimationReporter} +import com.databricks.labs.remorph.coverage.runners.EnvGetter +import com.databricks.labs.remorph.coverage.{CoverageTest, EstimationReport} +import com.databricks.labs.remorph.discovery.{FileQueryHistory, QueryHistoryProvider} +import com.databricks.labs.remorph.generators.orchestration.FileSetGenerator +import com.databricks.labs.remorph.queries.ExampleDebugger +import com.databricks.labs.remorph.support.SupportContext +import com.databricks.labs.remorph.support.snowflake.SnowflakeContext +import com.databricks.labs.remorph.support.tsql.TSqlContext +import com.databricks.labs.remorph.transpilers._ +import com.databricks.sdk.WorkspaceClient +import com.databricks.sdk.core.DatabricksConfig + +import java.io.File +import java.time.Instant + +trait ApplicationContext { + def flags: Map[String, String] + + def envGetter: EnvGetter = new EnvGetter() + + private lazy val supportContext: SupportContext = flags.get("dialect") match { + case Some("snowflake") => new SnowflakeContext(envGetter) + case Some("tsql") => new TSqlContext(envGetter) + case Some(unknown) => throw new IllegalArgumentException(s"--dialect=$unknown is not supported") + case None => throw new IllegalArgumentException("--dialect is required") + } + + lazy val queryHistoryProvider: QueryHistoryProvider = flags.get("source-queries") match { + case Some(folder) => new FileQueryHistory(new File(folder).toPath) + case None => supportContext.remoteQueryHistory + } + + protected val now = Instant.now + + def connectConfig: DatabricksConfig = new DatabricksConfig() + + def workspaceClient: WorkspaceClient = new WorkspaceClient(connectConfig) + + def prettyPrinter[T](v: T): Unit = pprint.pprintln[T](v) + + def exampleDebugger: ExampleDebugger = + new ExampleDebugger(supportContext.planParser, prettyPrinter, supportContext.name) + + def coverageTest: CoverageTest = new CoverageTest + + def estimator: Estimator = new Estimator(queryHistoryProvider, supportContext.planParser, new EstimationAnalyzer()) + + def jsonEstimationReporter( + outputDir: os.Path, + preserveQueries: Boolean, + estimate: EstimationReport): JsonEstimationReporter = + new 
JsonEstimationReporter(outputDir, preserveQueries, estimate) + + def consoleEstimationReporter(outputDir: os.Path, estimate: EstimationReport): SummaryEstimationReporter = + new SummaryEstimationReporter(outputDir, estimate) + + private def sqlGenerator: SqlGenerator = new SqlGenerator + + private def pySparkGenerator: PySparkGenerator = new PySparkGenerator + + def fileSetGenerator: FileSetGenerator = + new FileSetGenerator(supportContext.planParser, sqlGenerator, pySparkGenerator) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/Main.scala b/core/src/main/scala/com/databricks/labs/remorph/Main.scala new file mode 100644 index 0000000000..2733c84fb2 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/Main.scala @@ -0,0 +1,60 @@ +package com.databricks.labs.remorph + +import com.databricks.labs.remorph.generators.orchestration.rules.history.RawMigration +import io.circe.{Decoder, jackson} +import io.circe.generic.semiauto._ + +import java.io.File + +case class Payload(command: String, flags: Map[String, String]) + +object Payload { + implicit val payloadDecoder: Decoder[Payload] = deriveDecoder +} + +object Main extends App with ApplicationContext { + // scalastyle:off println + route match { + case Payload("debug-script", args) => + exampleDebugger.debugExample(args("name")) + case Payload("debug-me", _) => + prettyPrinter(workspaceClient.currentUser().me()) + case Payload("debug-coverage", args) => + coverageTest.run(os.Path(args("src")), os.Path(args("dst")), args("extractor")) + case Payload("debug-estimate", args) => + val report = estimator.run() + jsonEstimationReporter( + os.Path(args("dst")) / s"${now.getEpochSecond}", + args.get("preserve-queries").exists(_.toBoolean), + report).report() + args("console-output") match { + case "true" => consoleEstimationReporter(os.Path(args("dst")) / s"${now.getEpochSecond}", report).report() + } + case Payload("debug-bundle", args) => + val dst = new File(args("dst")) + val queryHistory = queryHistoryProvider.history() + fileSetGenerator.generate(RawMigration(queryHistory)).runAndDiscardState(TranspilerState()) match { + case OkResult(output) => output.persist(dst) + case PartialResult(output, error) => + prettyPrinter(error) + output.persist(dst) + case nok: KoResult => + prettyPrinter(nok) + } + case Payload(command, _) => + println(s"Unknown command: $command") + } + + // make CLI flags available for ApplicationContext + def flags: Map[String, String] = cliFlags + + // placeholder for global CLI flags + private[this] var cliFlags: Map[String, String] = Map.empty + + // parse json from the last CLI argument + private def route: Payload = { + val payload = jackson.decode[Payload](args.last).right.get + cliFlags = payload.flags + payload + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/Phase.scala b/core/src/main/scala/com/databricks/labs/remorph/Phase.scala new file mode 100644 index 0000000000..844a3b2c99 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/Phase.scala @@ -0,0 +1,76 @@ +package com.databricks.labs.remorph + +import com.databricks.labs.remorph.generators.GeneratorContext +import com.databricks.labs.remorph.intermediate.{LogicalPlan, RemorphError, TreeNode} +import com.databricks.labs.remorph.preprocessors.jinja.TemplateManager +import org.antlr.v4.runtime.{CommonTokenStream, ParserRuleContext} + +case class TranspilerState(currentPhase: Phase = Init, templateManager: TemplateManager = new TemplateManager) { + def recordError(error: RemorphError): TranspilerState 
= copy(currentPhase = currentPhase.recordError(error)) +} + +sealed trait Phase { + def previousPhase: Option[Phase] + def recordError(error: RemorphError): Phase +} + +case object Init extends Phase { + override val previousPhase: Option[Phase] = None + + override def recordError(error: RemorphError): Init.type = this +} + +case class PreProcessing( + source: String, + filename: String = "-- test source --", + encounteredErrors: Seq[RemorphError] = Seq.empty, + tokenStream: Option[CommonTokenStream] = None, + preprocessedInputSoFar: String = "") + extends Phase { + override val previousPhase: Option[Phase] = Some(Init) + + override def recordError(error: RemorphError): PreProcessing = + copy(encounteredErrors = this.encounteredErrors :+ error) +} + +case class Parsing( + source: String, + filename: String = "-- test source --", + previousPhase: Option[PreProcessing] = None, + encounteredErrors: Seq[RemorphError] = Seq.empty) + extends Phase { + + override def recordError(error: RemorphError): Parsing = + copy(encounteredErrors = this.encounteredErrors :+ error) +} + +case class BuildingAst( + tree: ParserRuleContext, + previousPhase: Option[Parsing] = None, + encounteredErrors: Seq[RemorphError] = Seq.empty) + extends Phase { + override def recordError(error: RemorphError): BuildingAst = + copy(encounteredErrors = this.encounteredErrors :+ error) +} + +case class Optimizing( + unoptimizedPlan: LogicalPlan, + previousPhase: Option[BuildingAst] = None, + encounteredErrors: Seq[RemorphError] = Seq.empty) + extends Phase { + override def recordError(error: RemorphError): Optimizing = + this.copy(encounteredErrors = this.encounteredErrors :+ error) +} + +case class Generating( + optimizedPlan: LogicalPlan, + currentNode: TreeNode[_], + ctx: GeneratorContext, + totalStatements: Int = 0, + transpiledStatements: Int = 0, + previousPhase: Option[Optimizing] = None, + encounteredErrors: Seq[RemorphError] = Seq.empty) + extends Phase { + override def recordError(error: RemorphError): Generating = + copy(encounteredErrors = this.encounteredErrors :+ error) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/Result.scala b/core/src/main/scala/com/databricks/labs/remorph/Result.scala new file mode 100644 index 0000000000..ac5d90adfc --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/Result.scala @@ -0,0 +1,140 @@ +package com.databricks.labs.remorph + +import com.databricks.labs.remorph.intermediate.RemorphError +import com.databricks.labs.remorph.preprocessors.jinja.TemplateManager + +sealed trait WorkflowStage +object WorkflowStage { + case object PARSE extends WorkflowStage + case object PLAN extends WorkflowStage + case object OPTIMIZE extends WorkflowStage + case object GENERATE extends WorkflowStage + case object FORMAT extends WorkflowStage +} + +/** + * Represents a stateful computation that will eventually produce an output of type Out. It manages a state of + * type TranspilerState along the way. Moreover, by relying on the semantics of Result, it is also able to handle + * errors in a controlled way. + * + * It is important to note that nothing won't get evaluated until the run or runAndDiscardState are called. + * @param run + * The computation that will be carried out by this computation. It is basically a function that takes a TranspilerState as + * parameter and returns a Result containing the - possibly updated - State along with the Output. + * @tparam Output + * The type of the produced output. 
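+ * For example, a minimal sketch (assuming the constructors from TransformationConstructors below are in
+ * scope):
+ * {{{
+ *   val answer: Transformation[Int] = ok(41).map(_ + 1)
+ *   answer.runAndDiscardState(TranspilerState())   // OkResult(42)
+ * }}}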
+ */ +final class Transformation[+Output](val run: TranspilerState => Result[(TranspilerState, Output)]) { + + /** + * Modify the output of this transformation using the provided function, without changing the managed state. + * + * If this transformation results in a KoResult, the provided function won't be evaluated. + */ + def map[B](f: Output => B): Transformation[B] = new Transformation(run.andThen(_.map { case (s, a) => (s, f(a)) })) + + /** + * Chain this transformation with another one by passing this transformation's output to the provided function. + * + * If this transformation results in a KoResult, the provided function won't be evaluated. + */ + def flatMap[B](f: Output => Transformation[B]): Transformation[B] = new Transformation(run.andThen { + case OkResult((s, a)) => f(a).run(s) + case p @ PartialResult((s, a), err) => p.flatMap { _ => f(a).run(s.recordError(err)) } + case ko: KoResult => ko + }) + + /** + * Runs the computation using the provided initial state and return a Result containing the transformation's output, + * discarding the final state. + */ + def runAndDiscardState(initialState: TranspilerState): Result[Output] = run(initialState).map(_._2) +} + +trait TransformationConstructors { + + /** + * Wraps a value into a successful transformation that ignores its state. + */ + def ok[A](a: A): Transformation[A] = new Transformation(s => OkResult((s, a))) + + /** + * Wraps an error into a failed transformation. + */ + def ko(stage: WorkflowStage, err: RemorphError): Transformation[Nothing] = new Transformation(s => + KoResult(stage, err)) + + /** + * Wraps a Result into a transformation that ignores its state. + */ + def lift[X](res: Result[X]): Transformation[X] = new Transformation(s => res.map(x => (s, x))) + + /** + * A tranformation whose output is the current state. + */ + def getCurrentPhase: Transformation[Phase] = new Transformation(s => OkResult((s, s.currentPhase))) + + /** + * A transformation that replaces the current state with the provided one, and produces no meaningful output. + */ + def setPhase(newPhase: Phase): Transformation[Unit] = new Transformation(s => + OkResult((s.copy(currentPhase = newPhase), ()))) + + /** + * A transformation that updates the current state using the provided partial function, and produces no meaningful + * output. If the provided partial function cannot be applied to the current state, it remains unchanged. 
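+ * For example, an illustrative sketch of switching from parsing to AST building, using the Phase
+ * subtypes defined in Phase.scala:
+ * {{{
+ *   def startBuildingAst(tree: ParserRuleContext): Transformation[Unit] =
+ *     updatePhase { case p: Parsing => BuildingAst(tree, Some(p), p.encounteredErrors) }
+ * }}}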
+ */ + def updatePhase(f: PartialFunction[Phase, Phase]): Transformation[Unit] = new Transformation(state => { + val newState = state.copy(currentPhase = f.applyOrElse(state.currentPhase, identity[Phase])) + OkResult((newState, ())) + }) + + def getTemplateManager: Transformation[TemplateManager] = new Transformation(s => OkResult((s, s.templateManager))) + + def updateTemplateManager(updateFunc: TemplateManager => TemplateManager): Transformation[Unit] = + new Transformation(s => OkResult((s.copy(templateManager = updateFunc(s.templateManager)), ()))) +} + +sealed trait Result[+A] { + def map[B](f: A => B): Result[B] + def flatMap[B](f: A => Result[B]): Result[B] + def isSuccess: Boolean + def withNonBlockingError(error: RemorphError): Result[A] +} + +case class OkResult[A](output: A) extends Result[A] { + override def map[B](f: A => B): Result[B] = OkResult(f(output)) + + override def flatMap[B](f: A => Result[B]): Result[B] = f(output) + + override def isSuccess: Boolean = true + + override def withNonBlockingError(error: RemorphError): Result[A] = PartialResult(output, error) +} + +case class PartialResult[A](output: A, error: RemorphError) extends Result[A] { + + override def map[B](f: A => B): Result[B] = PartialResult(f(output), error) + + override def flatMap[B](f: A => Result[B]): Result[B] = f(output) match { + case OkResult(res) => PartialResult(res, error) + case PartialResult(res, err) => PartialResult(res, RemorphError.merge(error, err)) + case KoResult(stage, err) => KoResult(stage, RemorphError.merge(error, err)) + } + + override def isSuccess: Boolean = true + + override def withNonBlockingError(newError: RemorphError): Result[A] = + PartialResult(output, RemorphError.merge(error, newError)) +} + +case class KoResult(stage: WorkflowStage, error: RemorphError) extends Result[Nothing] { + override def map[B](f: Nothing => B): Result[B] = this + + override def flatMap[B](f: Nothing => Result[B]): Result[B] = this + + override def isSuccess: Boolean = false + + override def withNonBlockingError(newError: RemorphError): Result[Nothing] = + KoResult(stage, RemorphError.merge(error, newError)) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/AcceptanceTestRunner.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/AcceptanceTestRunner.scala new file mode 100644 index 0000000000..45aebd0bcc --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/AcceptanceTestRunner.scala @@ -0,0 +1,25 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.queries.{AcceptanceTest, ExampleSource, QueryExtractor} + +case class AcceptanceTestConfig( + testFileSource: ExampleSource, + queryExtractor: QueryExtractor, + queryRunner: QueryRunner, + ignoredTestNames: String => Boolean = Set.empty, + shouldFailParse: String => Boolean = Set.empty) + +class AcceptanceTestRunner(config: AcceptanceTestConfig) { + + def shouldFailParse: String => Boolean = config.shouldFailParse + + def runAcceptanceTest(acceptanceTest: AcceptanceTest): Option[ReportEntryReport] = { + if (config.ignoredTestNames(acceptanceTest.testName)) { + None + } else { + config.queryExtractor.extractQuery(acceptanceTest.inputFile).map(config.queryRunner.runQuery) + } + } + + def foreachTest(f: AcceptanceTest => Unit): Unit = config.testFileSource.listTests.foreach(f) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/CoverageTest.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/CoverageTest.scala new file mode 
100644 index 0000000000..52453a0130 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/CoverageTest.scala @@ -0,0 +1,86 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.queries.{CommentBasedQueryExtractor, NestedFiles, WholeFileQueryExtractor} + +import java.time.Instant + +import io.circe.generic.auto._ +import io.circe.syntax._ + +case class DialectCoverageTest(dialectName: String, queryRunner: QueryRunner) + +case class IndividualError(description: String, nbOccurrences: Int, example: String) + +case class ErrorsSummary(parseErrors: Seq[IndividualError], transpileErrors: Seq[IndividualError]) + +class CoverageTest extends ErrorEncoders { + + private[this] val dialectCoverageTests = Seq( + DialectCoverageTest("snowflake", new IsTranspiledFromSnowflakeQueryRunner), + DialectCoverageTest("tsql", new IsTranspiledFromTSqlQueryRunner)) + + private def getCurrentCommitHash: Option[String] = { + val gitRevParse = os.proc("/usr/bin/git", "rev-parse", "--short", "HEAD").call(os.pwd) + if (gitRevParse.exitCode == 0) { + Some(gitRevParse.out.trim()) + } else { + None + } + } + + private def timeToEpochNanos(instant: Instant) = { + val epoch = Instant.ofEpochMilli(0) + java.time.Duration.between(epoch, instant).toNanos + } + def run(sourceDir: os.Path, outputPath: os.Path, extractor: String): Unit = { + + val now = Instant.now + + val project = "remorph-core" + val commitHash = getCurrentCommitHash + + val outputFilePath = outputPath / s"$project-coverage-${timeToEpochNanos(now)}.json" + + os.makeDir.all(outputPath) + + val reportsByDialect: Seq[(String, Seq[ReportEntry])] = dialectCoverageTests.map { t => + val queryExtractor = extractor match { + case "comment" => new CommentBasedQueryExtractor(t.dialectName, "databricks") + case "full" => new WholeFileQueryExtractor + } + + t.dialectName -> new NestedFiles((sourceDir / t.dialectName).toNIO).listTests.flatMap { test => + queryExtractor + .extractQuery(test.inputFile) + .map { exampleQuery => + val runner = t.queryRunner + val header = ReportEntryHeader( + project = project, + commit_hash = commitHash, + version = "latest", + timestamp = now.toString, + source_dialect = t.dialectName, + target_dialect = "databricks", + file = os.Path(test.inputFile).relativeTo(sourceDir).toString) + val report = runner.runQuery(exampleQuery) + ReportEntry(header, report) + } + .toSeq + + } + } + + reportsByDialect.foreach { case (dialect, reports) => + reports.foreach(report => os.write.append(outputFilePath, report.asJson.noSpaces + "\n")) + val total = reports.size + val parsed = reports.map(_.report.parsed).sum + val transpiled = reports.map(_.report.transpiled).sum + // scalastyle:off + println( + s"remorph -> $dialect: ${100d * parsed / total}% parsed ($parsed / $total)," + + s" ${100d * transpiled / total}% transpiled ($transpiled / $total)") + // scalastyle:on + } + + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/EstimationReport.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/EstimationReport.scala new file mode 100644 index 0000000000..55316dcc9e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/EstimationReport.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.coverage.estimation.{EstimationStatistics, RuleScore, SqlComplexity} +import com.databricks.labs.remorph.discovery.Fingerprint +import com.databricks.labs.remorph.intermediate.RemorphError + +case class 
EstimationReport( + overallComplexity: EstimationStatistics, + dialect: String, // What dialect of SQL is this a report for? + sampleSize: Int, // How many records were used to estimate + uniqueSuccesses: Int, // How many unique queries were successfully estimated + parseFailures: Int, + transpileFailures: Int, + records: Seq[EstimationReportRecord] // The actual records - includes failures +) { + + def withRecords(newRecords: Seq[EstimationReportRecord]): EstimationReport = { + this.copy(records = newRecords) + } +} + +case class EstimationReportRecord( + transpilationReport: EstimationTranspilationReport, + analysisReport: EstimationAnalysisReport) { + def withQueries(newQuery: String, output: Option[String]): EstimationReportRecord = { + this.copy(transpilationReport = transpilationReport.withQueries(newQuery, output)) + } +} + +case class EstimationTranspilationReport( + query: Option[String] = None, + output: Option[String] = None, + parsed: Int = 0, // 1 for success, 0 for failure + statements: Int = 0, // number of statements parsed + parsing_error: Option[RemorphError] = None, + transpiled: Int = 0, // 1 for success, 0 for failure + transpiled_statements: Int = 0, // number of statements transpiled + transpilation_error: Option[RemorphError] = None) { + + def withQueries(newQuery: String, output: Option[String]): EstimationTranspilationReport = { + this.copy(query = Some(newQuery), output = output) + } +} + +case class EstimationAnalysisReport( + fingerprint: Option[Fingerprint] = None, + complexity: SqlComplexity = SqlComplexity.LOW, + score: RuleScore) diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/QueryRunner.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/QueryRunner.scala new file mode 100644 index 0000000000..f7d77e6cac --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/QueryRunner.scala @@ -0,0 +1,61 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.WorkflowStage.PARSE +import com.databricks.labs.remorph.intermediate.{RemorphError, UnexpectedOutput} +import com.databricks.labs.remorph.queries.ExampleQuery +import com.databricks.labs.remorph.transpilers._ +import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult, PreProcessing, TranspilerState} + +trait QueryRunner extends Formatter { + def runQuery(exampleQuery: ExampleQuery): ReportEntryReport + +} + +abstract class BaseQueryRunner(transpiler: Transpiler) extends QueryRunner { + + private def createReportEntryReport( + exampleQuery: ExampleQuery, + output: String, + error: Option[RemorphError] = None): ReportEntryReport = { + val expected = exampleQuery.expectedTranslation + val parsed = if (error.isEmpty) 1 else 0 + val formattedOutput = if (exampleQuery.shouldFormat) format(output) else output + val formattedExpected = expected.map(e => if (exampleQuery.shouldFormat) format(e) else e) + + formattedExpected match { + case Some(`formattedOutput`) | None => + ReportEntryReport( + parsed = parsed, + transpiled = parsed, + statements = 1, + transpiled_statements = 1, + failures = error) + case Some(expectedOutput) => + ReportEntryReport( + parsed = parsed, + statements = 1, + failures = Some(UnexpectedOutput(expectedOutput, formattedOutput))) + } + } + + override def runQuery(exampleQuery: ExampleQuery): ReportEntryReport = { + transpiler + .transpile(PreProcessing(exampleQuery.query)) + .runAndDiscardState(TranspilerState(PreProcessing(exampleQuery.query))) match { + case KoResult(PARSE, error) => 
ReportEntryReport(statements = 1, failures = Some(error)) + case KoResult(_, error) => + // If we got past the PARSE stage, then remember to record that we parsed it correctly + ReportEntryReport(parsed = 1, statements = 1, failures = Some(error)) + case PartialResult(output, error) => + // Even with parsing errors, we will attempt to transpile the query, and parsing errors will be recorded in + // entry report + createReportEntryReport(exampleQuery, output, Some(error)) + case OkResult(output) => + createReportEntryReport(exampleQuery, output) + } + } +} + +class IsTranspiledFromSnowflakeQueryRunner extends BaseQueryRunner(new SnowflakeToDatabricksTranspiler) + +class IsTranspiledFromTSqlQueryRunner extends BaseQueryRunner(new TSqlToDatabricksTranspiler) diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/ReportEntry.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/ReportEntry.scala new file mode 100644 index 0000000000..40f414adbf --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/ReportEntry.scala @@ -0,0 +1,42 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.intermediate.{MultipleErrors, ParsingError, RemorphError} +import io.circe.Encoder + +case class ReportEntryHeader( + project: String, + commit_hash: Option[String], + version: String, + timestamp: String, + source_dialect: String, + target_dialect: String, + file: String) +case class ReportEntryReport( + parsed: Int = 0, // 1 for success, 0 for failure + statements: Int = 0, // number of statements parsed + transpiled: Int = 0, // 1 for success, 0 for failure + transpiled_statements: Int = 0, // number of statements transpiled + failures: Option[RemorphError] = None) { + def isSuccess: Boolean = failures.isEmpty + def failedParseOnly: Boolean = failures.forall { + case _: ParsingError => true + case m: MultipleErrors => m.errors.forall(_.isInstanceOf[ParsingError]) + case _ => true + } + + // Transpilation error takes precedence over parsing error as parsing errors will be + // shown in the output. If there is a transpilation error, we should therefore show that instead. 
+ def errorMessage: Option[String] = failures.map(_.msg) +} + +case class ReportEntry(header: ReportEntryHeader, report: ReportEntryReport) + +object ReportEntry extends ErrorEncoders { + import io.circe.generic.auto._ + import io.circe.syntax._ + implicit val reportEntryEncoder: Encoder[ReportEntry] = Encoder.instance { entry => + val header = entry.header.asJson + val report = entry.report.asJson + header.deepMerge(report) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/encoders.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/encoders.scala new file mode 100644 index 0000000000..5888b873bc --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/encoders.scala @@ -0,0 +1,49 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.coverage.estimation.{Rule => EstRule, _} +import com.databricks.labs.remorph.discovery.{Fingerprint, QueryType, WorkloadType} +import com.databricks.labs.remorph.intermediate._ +import io.circe._ +import io.circe.generic.extras.semiauto._ +import io.circe.generic.extras.Configuration +import io.circe.syntax._ + +import java.sql.Timestamp +import java.time.Duration + +trait ErrorEncoders { + + implicit val codecConfiguration: Configuration = + Configuration.default.withSnakeCaseMemberNames + + implicit val singleErrorEncoder: Encoder[SingleError] = Encoder.instance { err => + Json.obj("error_code" -> err.getClass.getSimpleName.asJson, "error_message" -> err.msg.asJson) + } + implicit val remorphErrorEncoder: Encoder[RemorphError] = Encoder.instance { + case s: SingleError => Json.arr(s.asJson) + case m: MultipleErrors => Json.arr(m.errors.map(_.asJson): _*) + } + +} + +trait EstimationReportEncoders extends ErrorEncoders { + + implicit val sqlComplexityEncoder: Encoder[SqlComplexity] = deriveConfiguredEncoder + implicit val parseFailStatsEncoder: Encoder[ParseFailStats] = deriveConfiguredEncoder + implicit val estimationStatisticsEntryEncoder: Encoder[EstimationStatisticsEntry] = deriveConfiguredEncoder + implicit val estimationStatisticsEncoder: Encoder[EstimationStatistics] = deriveConfiguredEncoder + + implicit val timestampEncoder: Encoder[Timestamp] = Encoder.instance(t => t.getTime.asJson) + implicit val durationEncoder: Encoder[Duration] = Encoder.instance(d => d.toMillis.asJson) + implicit val workloadTypeEncoder: Encoder[WorkloadType.WorkloadType] = Encoder.instance(wt => wt.toString.asJson) + implicit val queryTypeEncoder: Encoder[QueryType.QueryType] = Encoder.instance(qt => qt.toString.asJson) + + implicit val ruleEncoder: Encoder[EstRule] = deriveConfiguredEncoder + implicit val ruleScoreEncoder: Encoder[RuleScore] = Encoder.instance(score => Json.obj("rule" -> score.rule.asJson)) + + implicit val fingerPrintEncoder: Encoder[Fingerprint] = deriveConfiguredEncoder + implicit val estimationAnalysisReportEncoder: Encoder[EstimationAnalysisReport] = deriveConfiguredEncoder + implicit val estimationTranspilationReportEncoder: Encoder[EstimationTranspilationReport] = deriveConfiguredEncoder + implicit val estimationReportRecordEncoder: Encoder[EstimationReportRecord] = deriveConfiguredEncoder + implicit val estimationReportEncoder: Encoder[EstimationReport] = deriveConfiguredEncoder +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/EstimationAnalyzer.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/EstimationAnalyzer.scala new file mode 100644 index 0000000000..0a1bb81dbb --- /dev/null +++ 
b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/EstimationAnalyzer.scala @@ -0,0 +1,295 @@ +package com.databricks.labs.remorph.coverage.estimation + +import com.databricks.labs.remorph.coverage.EstimationReportRecord +import com.databricks.labs.remorph.intermediate.{ParsingErrors} +import com.databricks.labs.remorph.{intermediate => ir} +import com.typesafe.scalalogging.LazyLogging + +import scala.util.control.NonFatal + +sealed trait SqlComplexity +object SqlComplexity { + case object LOW extends SqlComplexity + case object MEDIUM extends SqlComplexity + case object COMPLEX extends SqlComplexity + case object VERY_COMPLEX extends SqlComplexity + + // TODO: Define the scores for each complexity level + def fromScore(score: Double): SqlComplexity = score match { + case s if s < 10 => LOW + case s if s < 60 => MEDIUM + case s if s < 120 => COMPLEX + case _ => VERY_COMPLEX + } +} + +case class SourceTextComplexity(lineCount: Int, textLength: Int) + +case class ParseFailStats(ruleNameCounts: Map[String, Int], tokenNameCounts: Map[String, Int]) + +case class EstimationStatistics( + allStats: EstimationStatisticsEntry, + successStats: EstimationStatisticsEntry, + pfStats: ParseFailStats) + +case class EstimationStatisticsEntry( + medianScore: Int, + meanScore: Double, + modeScore: Int, + stdDeviation: Double, + percentile25: Double, + percentile50: Double, + percentile75: Double, + geometricMeanScore: Double, + complexity: SqlComplexity) + +class EstimationAnalyzer extends LazyLogging { + + def evaluateTree(node: ir.TreeNode[_]): RuleScore = { + evaluateTree(node, logicalPlanEvaluator, expressionEvaluator) + } + + def evaluateTree( + node: ir.TreeNode[_], + logicalPlanVisitor: PartialFunction[ir.LogicalPlan, RuleScore], + expressionVisitor: PartialFunction[ir.Expression, RuleScore]): RuleScore = { + + node match { + + case lp: ir.LogicalPlan => + val currentRuleScore = + logicalPlanVisitor.applyOrElse(lp, (_: ir.LogicalPlan) => RuleScore(IrErrorRule(), Seq.empty)) + + val childrenRuleScores = + lp.children.map(child => evaluateTree(child, logicalPlanVisitor, expressionVisitor)) + val expressionRuleScores = + lp.expressions.map(expr => evaluateTree(expr, logicalPlanVisitor, expressionVisitor)) + + val childrenValue = childrenRuleScores.map(_.rule.score).sum + val expressionsValue = expressionRuleScores.map(_.rule.score).sum + + RuleScore( + currentRuleScore.rule.plusScore(childrenValue + expressionsValue), + childrenRuleScores ++ expressionRuleScores) + + case expr: ir.Expression => + val currentRuleScore = + expressionVisitor.applyOrElse(expr, (_: ir.Expression) => RuleScore(IrErrorRule(), Seq.empty)) + val childrenRuleScores = + expr.children.map(child => evaluateTree(child, logicalPlanVisitor, expressionVisitor)) + val childrenValue = childrenRuleScores.map(_.rule.score).sum + + // All expressions have a base cost, plus the cost of the ir.Expression itself and its children + RuleScore(currentRuleScore.rule.plusScore(childrenValue), childrenRuleScores) + + case _ => + throw new IllegalArgumentException(s"Unsupported node type: ${node.getClass.getSimpleName}") + } + } + + /** + *

+ * Given the raw query text, produce some statistics that are purely derived from the text, rather than
+ * from a parsed plan or translation.
+ *
+ * Text complexity is just one component of the overall score for a query, but it can be a good indicator
+ * of how complex the query is in terms of a human translating it. For example, a query with many lines
+ * and a lot of text is likely to take some time to manually translate, even if there are no complex
+ * expressions, UDFs or subqueries. Text length is of little consequence to the transpiler once parsing
+ * succeeds, but it is a reasonable proxy for the manual effort needed to review and verify the result.
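+ * For example (an illustrative value, not a test fixture):
+ * {{{
+ *   sourceTextComplexity("SELECT 1\nFROM t") == SourceTextComplexity(lineCount = 2, textLength = 15)
+ * }}}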

+ * + * @param query the raw text of the query + * @return a set of statistics about the query text + */ + def sourceTextComplexity(query: String): SourceTextComplexity = { + SourceTextComplexity(query.split("\n").length, query.length) + } + + private def logicalPlanEvaluator: PartialFunction[ir.LogicalPlan, RuleScore] = { case lp: ir.LogicalPlan => + try { + lp match { + case ir.UnresolvedCommand(_, _, _, _) => + RuleScore(UnsupportedCommandRule(), Seq.empty) + + // TODO: Add scores for other logical plans that add more complexity then a simple statement + case _ => + RuleScore(StatementRule(), Seq.empty) // Default case for other logical plans + } + + } catch { + case NonFatal(_) => RuleScore(IrErrorRule(), Seq.empty) + } + } + + private def expressionEvaluator: PartialFunction[ir.Expression, RuleScore] = { case expr: ir.Expression => + try { + expr match { + case ir.ScalarSubquery(relation) => + // ScalarSubqueries are a bit more complex than a simple ir.Expression and their score + // is calculated by an addition for the subquery being present, and the sub-query itself + val subqueryRelationScore = evaluateTree(relation) + RuleScore(SubqueryRule().plusScore(subqueryRelationScore.rule.score), Seq(subqueryRelationScore)) + + case uf: ir.UnresolvedFunction => + // Unsupported functions are a bit more complex than a simple ir.Expression and their score + // is calculated by an addition for the function being present, and the function itself + assessFunction(uf) + + // TODO: Add specific rules for things that are more complicated than simple expressions such as + // UDFs or CASE statements - also cater for all the different Unresolved[type] classes + case _ => + RuleScore(ExpressionRule(), Seq.empty) // Default case for straightforward expressions + } + } catch { + case NonFatal(_) => RuleScore(IrErrorRule(), Seq.empty) + } + } + + /** + * Assess the complexity of an unsupported function conversion based on our internal knowledge of how + * the function is used. Some functions indicate data processing that is not supported in Databricks SQL + * and some will indicate a well-known conversion pattern that is known to be successful. + * @param func the function definition to analyze + * @return the conversion complexity score for the function + */ + private def assessFunction(func: ir.UnresolvedFunction): RuleScore = { + func match { + // For instance XML functions are not supported in Databricks SQL and will require manual conversion, + // which will be a significant amount of work. + case af: ir.UnresolvedFunction => + RuleScore(UnsupportedFunctionRule(funcName = af.function_name).resolve(), Seq.empty) + } + } + + def summarizeComplexity(reportEntries: Seq[EstimationReportRecord]): EstimationStatistics = { + + // We produce a list of all scores and a list of all successful transpile scores, which allows to produce + // statistics on ALL transpilation attempts and at the same time on only successful transpilations. Use case + // will vary on who is consuming the final reports. 
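+    // For example (illustrative numbers only): the scores Seq(1, 3, 5, 200) have a median of 4, a mean of
+    // 52.25 and a geometric mean of roughly 7.4; the geometric mean damps the single very complex query,
+    // which is why it is the value fed into SqlComplexity.fromScore below.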
+ val scores = reportEntries.map(_.analysisReport.score.rule.score).sorted + val successScores = reportEntries + .filter(_.transpilationReport.transpiled_statements > 0) + .map(_.analysisReport.score.rule.score) + .sorted + + val medianScore = median(scores) + val meanScore = scores.sum.toDouble / scores.size + val modeScore = scores.groupBy(identity).maxBy(_._2.size)._1 + val variance = scores.map(score => math.pow(score - meanScore, 2)).sum / scores.size + val stdDeviation = math.sqrt(variance) + val percentile25 = percentile(scores, 0.25) + val percentile50 = percentile(scores, 0.50) // Same as median + val percentile75 = percentile(scores, 0.75) + val geometricMeanScore = geometricMean(scores) + + val allStats = EstimationStatisticsEntry( + medianScore, + meanScore, + modeScore, + stdDeviation, + percentile25, + percentile50, + percentile75, + geometricMeanScore, + SqlComplexity.fromScore(geometricMeanScore)) + + val successStats = if (successScores.isEmpty) { + EstimationStatisticsEntry(0, 0, 0, 0, 0, 0, 0, 0, SqlComplexity.LOW) + } else { + EstimationStatisticsEntry( + median(successScores), + successScores.sum.toDouble / successScores.size, + successScores.groupBy(identity).maxBy(_._2.size)._1, + math.sqrt( + successScores + .map(score => math.pow(score - (successScores.sum.toDouble / successScores.size), 2)) + .sum / successScores.size), + percentile(successScores, 0.25), + percentile(successScores, 0.50), // Same as median + percentile(successScores, 0.75), + geometricMean(successScores), + SqlComplexity.fromScore(geometricMean(successScores))) + } + + EstimationStatistics(allStats, successStats, assessParsingFailures(reportEntries)) + } + + private def percentile(scores: Seq[Int], p: Double): Double = { + if (scores.isEmpty) { + 0 + } else { + val k = (p * (scores.size - 1)).toInt + scores(k) + } + } + + private def geometricMean(scores: Seq[Int]): Double = { + val nonZeroScores = scores.filter(_ != 0) + if (nonZeroScores.nonEmpty) { + val logSum = nonZeroScores.map(score => math.log(score.toDouble)).sum + math.exp(logSum / nonZeroScores.size) + } else { + 0.0 + } + } + + def median(scores: Seq[Int]): Int = { + if (scores.isEmpty) { + 0 + } else if (scores.size % 2 == 1) { + scores(scores.size / 2) + } else { + val (up, down) = scores.splitAt(scores.size / 2) + (up.last + down.head) / 2 + } + } + + /** + * Assigns a conversion complexity score based on how much text is in the query, which is a basic + * indicator of how much work will be required to manually inspect a query. + * @param sourceTextComplexity the complexity of the source text + * @return the score for the complexity of the query + */ + def assessText(sourceTextComplexity: SourceTextComplexity): Int = + // TODO: These values are arbitrary and need to be verified in some way + sourceTextComplexity.lineCount + sourceTextComplexity.textLength match { + case l if l < 100 => 1 + case l if l < 500 => 5 + case l if l < 1000 => 10 + case l if l < 5000 => 25 + case _ => 50 + } + + /** + * Find all the report entries where parsing_error is not null, and accumulate the number of times + * each ruleName and tokenName appears in the errors. This will give us an idea of which rules and + * tokens, if implemented correctly would have the most impact on increasing the success rate of the + * parser for the given sample of queries. 
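+ * For example (illustrative rule and token names, not real measurements), a run might produce
+ * {{{
+ *   ParseFailStats(
+ *     ruleNameCounts = Map("selectStatement" -> 42, "expression" -> 17),
+ *     tokenNameCounts = Map("GO" -> 23, "MERGE" -> 11))
+ * }}}
+ * suggesting that work on the selectStatement rule and on handling of the GO token would unlock the most
+ * queries in the sample.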
+ * + * @param reportEntries the list of all report records + */ + def assessParsingFailures(reportEntries: Seq[EstimationReportRecord]): ParseFailStats = { + val ruleNameCounts = scala.collection.mutable.Map[String, Int]().withDefaultValue(0) + val tokenNameCounts = scala.collection.mutable.Map[String, Int]().withDefaultValue(0) + + reportEntries.foreach(err => + err.transpilationReport.parsing_error match { + case Some(e: ParsingErrors) => + e.errors.foreach(e => { + val ruleName = e.ruleName + val tokenName = e.offendingTokenName + ruleNameCounts(ruleName) += 1 + tokenNameCounts(tokenName) += 1 + }) + case _ => // No errors + }) + + val topRuleNames = ruleNameCounts.toSeq.sortBy(-_._2).take(10).toMap + val topTokenNames = tokenNameCounts.toSeq.sortBy(-_._2).take(10).toMap + + ParseFailStats(topRuleNames, topTokenNames) + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/EstimationReporter.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/EstimationReporter.scala new file mode 100644 index 0000000000..877841aefc --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/EstimationReporter.scala @@ -0,0 +1,164 @@ +package com.databricks.labs.remorph.coverage.estimation + +import com.databricks.labs.remorph.coverage.{EstimationReport, EstimationReportEncoders} +import io.circe.syntax._ + +trait EstimationReporter { + def report(): Unit +} + +class SummaryEstimationReporter(outputDir: os.Path, estimate: EstimationReport) extends EstimationReporter { + override def report(): Unit = { + val ruleFrequency = estimate.overallComplexity.pfStats.ruleNameCounts.toSeq + .sortBy(-_._2) // Sort by count in ascending order + .map { case (ruleName, count) => s"| $ruleName | $count |\n" } + .mkString + val tokenFrequency = estimate.overallComplexity.pfStats.tokenNameCounts.toSeq + .sortBy(-_._2) // Sort by count in ascending order + .map { case (tokenName, count) => s"| $tokenName | $count |\n" } + .mkString + + val output = + s""" + |# Conversion Complexity Estimation Report + |## Report explanation: + | + |### Sample size + |The number of records used to provide estimate. In other words the total + |number of records in the input dataset. + |### Output record count + |The number of records generated to track output. This includes + |both successful and failed transpilations, but the analysis tries + |to avoid duplicates in both successful and failed transpilations. + |### Unique successful transpiles + |The number of unique queries that were successfully transpiled + |from source dialect to Databricks SQL. This count does not include + |queries that were duplicates of previously seen queries with just + |simple parameter changes. + |### Unique Parse failures + |The number of unique queries that failed to parse and therefore + |could not produce IR or be transpiled. + |### Transpile failures + |The number of unique queries that were parsed, produced an IR Plan, + |but then failed to transpile to Databricks SQL. Note that this means + |that there is a bug in either the IR generation or the source generation. + |### Overall complexity + |The overall complexity of the queries in the dataset. This is a rough + |estimate of how difficult it will be to complete a port of system + |if the supplied queries are representative of the source system. 
+ | + |This can be somewhat subjective in that a query that is very complex in terms + |of the number of statements or the number of joins may not be as complex as + |it appears to be. However, it is a good starting point for understanding the + |scope. + | + |### Statistics used to calculate overall complexity + |While complexity is presented as one of four categories, being, *LOW*, *MEDIUM*, + |*COMPLEX*, and *VERY_COMPLEX*, the actual judgement is based on a numeric score + |calculated from the presence of various elements in the query. A low score + |results in a complexity of *LOW* and a high score results in a complexity of *VERY_COMPLEX*. + | + |The individual scores for each query are then used to calculate the mean, median, and other + |statistics that may be used to determine the overall complexity. The raw values are + |contained in the report so that different interpretations can be made than the ones + |provided by the current version of the estimate command. + | + |## Metrics + | | Metric | Value | + | |:----------------------------|--------------------------------:| + | | Sample size | ${f"${estimate.sampleSize}%,d"}| + | | Output record count | ${f"${estimate.records.size}%,d"}| + | | Unique successful transpiles| ${estimate.uniqueSuccesses} | + | | Unique Parse failures | ${estimate.parseFailures} | + | | Transpile failures | ${estimate.transpileFailures} | + | | Overall complexity (ALL) | ${estimate.overallComplexity.allStats.complexity} | + | | Overall complexity (SUCCESS)| ${estimate.overallComplexity.successStats.complexity} | + | + |## Failing Parser Rule and Failed Token Frequencies + | This table shows the top N ANTLR grammar rules where parsing errors occurred and therefore + | where spent in improving the parser will have the most impact. It should be used as a starting + | point as these counts may include many instances of the same error. So fixing one parsing problem + | may rid you of a large number of failing queries. + | + | | Rule Name | Frequency | + | |:----------------------------|--------------------------------:| + | $ruleFrequency + | + | This table is less useful than the rule table but it may be useful to see if there might be + | a missing token definition or a token that is a keyword but not bing allowed as an identifier etc. 
+ | + | | Token Name | Frequency | + | |:----------------------------|--------------------------------:| + | $tokenFrequency + | + |## Statistics used to calculate overall complexity (ALL results) + | + | | Metric | Value | + | |:----------------------------|--------------------------------:| + | | Mean score | ${f"${estimate.overallComplexity.allStats.meanScore}%,.2f"}| + | | Standard deviation | ${f"${estimate.overallComplexity.allStats.stdDeviation}%,.2f"}| + | | Mode score | ${estimate.overallComplexity.allStats.modeScore}| + | | Median score | ${estimate.overallComplexity.allStats.medianScore}| + | | Percentile 25 | ${f"${estimate.overallComplexity.allStats.percentile25}%,.2f"}| + | | Percentile 50 | ${f"${estimate.overallComplexity.allStats.percentile50}%,.2f"}| + | | Percentile 75 | ${f"${estimate.overallComplexity.allStats.percentile75}%,.2f"}| + | | Geometric mean score | ${f"${estimate.overallComplexity.allStats.geometricMeanScore}%,.2f"}| + | + |## Statistics used to calculate overall complexity (Successful results only) + | + | | Metric | Value | + | |:----------------------------|--------------------------------:| + | | Mean score | ${f"${estimate.overallComplexity.successStats.meanScore}%,.2f"}| + | | Standard deviation | ${f"${estimate.overallComplexity.successStats.stdDeviation}%,.2f"}| + | | Mode score | ${estimate.overallComplexity.successStats.modeScore}| + | | Median score | ${estimate.overallComplexity.successStats.medianScore}| + | | Percentile 25 | ${f"${estimate.overallComplexity.successStats.percentile25}%,.2f"}| + | | Percentile 50 | ${f"${estimate.overallComplexity.successStats.percentile50}%,.2f"}| + | | Percentile 75 | ${f"${estimate.overallComplexity.successStats.percentile75}%,.2f"}| + | | Geometric mean score | ${f"${estimate.overallComplexity.successStats.geometricMeanScore}%,.2f"}| + |""".stripMargin + + val summaryFilePath = outputDir / "summary.md" + os.write(summaryFilePath, output) + // scalastyle:off println + println(s"Summary report written to ${summaryFilePath}") + // scalastyle:on println + } +} + +class JsonEstimationReporter(outputDir: os.Path, preserveQueries: Boolean, estimate: EstimationReport) + extends EstimationReporter + with EstimationReportEncoders { + override def report(): Unit = { + val queriesDir = outputDir / "queries" + os.makeDir.all(queriesDir) + val resultPath = outputDir / s"${estimate.dialect}.json" + + // Iterate over the records and modify the transpilationReport.query field + var count = 0 + val newRecords = estimate.records.map { record => + if (preserveQueries) { + val (queryFilePath, outputFilepath) = record.analysisReport.fingerprint match { + case Some(fingerprint) => + (queriesDir / s"${fingerprint.fingerprint}.sql", queriesDir / s"${fingerprint.fingerprint}_transpiled.sql") + case None => + count += 1 + (queriesDir / s"parse_fail_${count}.sql", null) // No output file for failed transpiles + } + os.write(queryFilePath, record.transpilationReport.query) + record.transpilationReport.output match { + case Some(output) => + os.write(outputFilepath, output) + record.withQueries(queryFilePath.toString, Some(outputFilepath.toString)) + case None => + record.withQueries(queryFilePath.toString, None) + } + } else { + record.withQueries("", None) + } + } + val newEstimate = estimate.withRecords(newRecords) + val jsonReport: String = newEstimate.asJson.spaces4 + os.write(resultPath, jsonReport) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/Estimator.scala 
b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/Estimator.scala new file mode 100644 index 0000000000..8f2f3bdf19 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/Estimator.scala @@ -0,0 +1,158 @@ +package com.databricks.labs.remorph.coverage.estimation + +import com.databricks.labs.remorph.WorkflowStage.{PARSE, PLAN} +import com.databricks.labs.remorph.coverage._ +import com.databricks.labs.remorph.discovery.{Anonymizer, ExecutedQuery, QueryHistoryProvider} +import com.databricks.labs.remorph.intermediate.{LogicalPlan, ParsingError, TranspileFailure} +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.{KoResult, OkResult, Optimizing, Parsing, TranspilerState} +import com.databricks.labs.remorph.transpilers.SqlGenerator +import com.typesafe.scalalogging.LazyLogging + +class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], analyzer: EstimationAnalyzer) + extends LazyLogging { + + def run(): EstimationReport = { + val history = queryHistory.history() + val anonymizer = new Anonymizer(planParser) + // Hashes of either query strings or plans that we have seen before. + val parsedSet = scala.collection.mutable.Set[String]() + val reportEntries = history.queries.flatMap(processQuery(_, anonymizer, parsedSet)) + val (uniqueSuccesses, parseFailures, transpileFailures) = countReportEntries(reportEntries) + + EstimationReport( + dialect = planParser.dialect, + sampleSize = history.queries.size, + uniqueSuccesses = uniqueSuccesses, + parseFailures = parseFailures, + transpileFailures = transpileFailures, + records = reportEntries, + overallComplexity = analyzer.summarizeComplexity(reportEntries)) + } + + private def processQuery( + query: ExecutedQuery, + anonymizer: Anonymizer, + parsedSet: scala.collection.mutable.Set[String]): Option[EstimationReportRecord] = { + + val initialState = TranspilerState(Parsing(query.source)) + + // Skip entries that have already been seen as text but for which we were unable to parse or + // produce a plan for + val fingerprint = anonymizer(query.source) + if (!parsedSet.contains(fingerprint)) { + parsedSet += fingerprint + planParser.parse + .flatMap(planParser.visit) + .run(initialState) match { + case KoResult(PARSE, error) => + Some( + EstimationReportRecord( + EstimationTranspilationReport(Some(query.source), statements = 1, parsing_error = Some(error)), + EstimationAnalysisReport( + score = RuleScore(ParseFailureRule(), Seq.empty), + complexity = SqlComplexity.VERY_COMPLEX))) + + case KoResult(PLAN, error) => + Some( + EstimationReportRecord( + EstimationTranspilationReport(Some(query.source), statements = 1, transpilation_error = Some(error)), + EstimationAnalysisReport( + score = RuleScore(PlanFailureRule(), Seq.empty), + complexity = SqlComplexity.VERY_COMPLEX))) + + case OkResult(plan) => + val queryHash = anonymizer(plan._2) + val score = analyzer.evaluateTree(plan._2) + // Note that the plan hash will generally be more accurate than the query hash, hence we check here + // as well as against the plain text + if (!parsedSet.contains(queryHash)) { + parsedSet += queryHash + Some(generateReportRecord(query, plan._2, score, anonymizer)) + } else { + None + } + + case _ => + Some( + EstimationReportRecord( + EstimationTranspilationReport( + query = Some(query.source), + statements = 1, + parsing_error = Some(ParsingError(0, 0, "Unexpected result from parse phase", 0, "", "", ""))), + EstimationAnalysisReport( + score = 
RuleScore(UnexpectedResultRule(), Seq.empty), + complexity = SqlComplexity.VERY_COMPLEX))) + } + } else { + None + } + } + + private def generateReportRecord( + query: ExecutedQuery, + plan: LogicalPlan, + ruleScore: RuleScore, + anonymizer: Anonymizer): EstimationReportRecord = { + val generator = new SqlGenerator + val initialState = TranspilerState(Optimizing(plan, None)) + planParser.optimize(plan).flatMap(generator.generate).run(initialState) match { + case KoResult(_, error) => + // KoResult to transpile means that we need to increase the ruleScore as it will take some + // time to manually investigate and fix the issue + val tfr = RuleScore(TranspileFailureRule().plusScore(ruleScore.rule.score), Seq(ruleScore)) + EstimationReportRecord( + EstimationTranspilationReport( + query = Some(query.source), + statements = 1, + parsed = 1, + transpilation_error = Some(error)), + EstimationAnalysisReport( + fingerprint = Some(anonymizer(query, plan)), + score = tfr, + complexity = SqlComplexity.fromScore(tfr.rule.score))) + + case OkResult((_, output: String)) => + val newScore = + RuleScore(SuccessfulTranspileRule().plusScore(ruleScore.rule.score), Seq(ruleScore)) + EstimationReportRecord( + EstimationTranspilationReport( + query = Some(query.source), + output = Some(output), + statements = 1, + transpiled = 1, + transpiled_statements = 1, + parsed = 1), + EstimationAnalysisReport( + fingerprint = Some(anonymizer(query, plan)), + score = newScore, + complexity = SqlComplexity.fromScore(newScore.rule.score))) + + case _ => + EstimationReportRecord( + EstimationTranspilationReport( + query = Some(query.source), + statements = 1, + parsed = 1, + transpilation_error = + Some(TranspileFailure(new RuntimeException("Unexpected result from transpile phase")))), + EstimationAnalysisReport( + fingerprint = Some(anonymizer(query, plan)), + score = RuleScore(UnexpectedResultRule().plusScore(ruleScore.rule.score), Seq(ruleScore)), + complexity = SqlComplexity.VERY_COMPLEX)) + } + } + + private def countReportEntries(reportEntries: Seq[EstimationReportRecord]): (Int, Int, Int) = { + reportEntries.foldLeft((0, 0, 0)) { case ((uniqueSuccesses, parseFailures, transpileFailures), entry) => + val newUniqueSuccesses = + if (entry.transpilationReport.parsed == 1 && entry.transpilationReport.transpiled == 1) uniqueSuccesses + 1 + else uniqueSuccesses + val newParseFailures = + if (entry.transpilationReport.parsing_error.isDefined) parseFailures + 1 else parseFailures + val newTranspileFailures = + if (entry.transpilationReport.transpilation_error.isDefined) transpileFailures + 1 else transpileFailures + (newUniqueSuccesses, newParseFailures, newTranspileFailures) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/RuleDefinitions.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/RuleDefinitions.scala new file mode 100644 index 0000000000..f9ae47841d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/estimation/RuleDefinitions.scala @@ -0,0 +1,127 @@ +package com.databricks.labs.remorph.coverage.estimation + +/** + *

Defines the rules and their related scores for conversion complexity estimation.
+ *
+ * The rules are defined as a map of rule name to score, and their descriptions are expected to be
+ * stored somewhere more relevant to the dashboard reporting system (where they can also be subject
+ * to i18n/l10n).
+ *
+ * Rules that are matched by the analyzer are used to calculate the complexity of the query in terms
+ * of how much effort it is to convert it to Databricks SQL, and not necessarily how complex the query is
+ * in terms of, say, execution time or resource requirements. While there are rules that score the inability
+ * to parse, generate IR or transpile, these essentially capture work for the core team rather than the
+ * user/porting team. Such scores can optionally be excluded from conversion complexity calculations, but
+ * they remain useful for assessing the work required from the core Remorph team.

+ */ +sealed trait Rule { + def score: Int + def plusScore(newScore: Int): Rule // Adds the given score to the current score +} + +/** + * Transpilation was successful so we can reduce the score, but it is not zero because there will be some + * effort required to verify the translation. + */ +case class SuccessfulTranspileRule(override val score: Int = 5) extends Rule { + override def plusScore(newScore: Int): SuccessfulTranspileRule = this.copy(score = newScore + this.score) +} + +/** + * We were unable to parse the query at all. This adds a significant amount of work to the conversion, but it is + * work for the core team, not the user, so are able to filter these out of calculations if desired. + */ +case class ParseFailureRule(score: Int = 100) extends Rule { + override def plusScore(newScore: Int): ParseFailureRule = this.copy(score = newScore + this.score) +} + +/** + * We were able to parse this query, but the logical plan was not generated. This is possibly significant work + * required from the core team, but it is not necessarily work for the user, so we can filter out these scores + * from the conversion complexity calculations if desired. + */ +case class PlanFailureRule(score: Int = 100) extends Rule { + override def plusScore(newScore: Int): PlanFailureRule = this.copy(score = newScore + this.score) +} + +/** + * Either the optimizer or the generator failed to produce a This is possibly a significant amount of + * work for the core team, but it is not necessarily work for the user, so we can filter out these scores. + */ +case class TranspileFailureRule(override val score: Int = 100) extends Rule { + override def plusScore(newScore: Int): TranspileFailureRule = this.copy(score = newScore + this.score) +} + +/** + * In theory this cannot happen, but it means the toolchain returned some status that we do not understand + */ +case class UnexpectedResultRule(override val score: Int = 100) extends Rule { + override def plusScore(newScore: Int): UnexpectedResultRule = this.copy(score = newScore + this.score) +} + +/** + * An IR error is only flagged when there is something wrong with the IR generation we received. This generally + * indicates that there is a missing visitor and that the results of visiting some node were null. This is actually + * a bug in the Remorph code and should be fixed by the core team. This is not work for the user, so we can filter. + */ +case class IrErrorRule(override val score: Int = 100) extends Rule { + override def plusScore(newScore: Int): IrErrorRule = this.copy(score = newScore + this.score) +} + +/** + * Each individual statement in a query is a separate unit of work. This is a low level of work, but it is + * counted as it will need to be verified in some way. + */ +case class StatementRule(override val score: Int = 1) extends Rule { + override def plusScore(newScore: Int): StatementRule = this.copy(score = newScore + this.score) +} + +/** + * Any expression in the query is a unit of work. This is also a low level of work, but it is counted as it will + * need to be verified in some way. + */ +case class ExpressionRule(override val score: Int = 1) extends Rule { + override def plusScore(newScore: Int): ExpressionRule = this.copy(score = newScore + this.score) +} + +/** + * Subqueries will tend to add more complexity in human analysis of any query, though their existence does not + * necessarily mean that it is complex to convert to Databricks SQL. The final score for a sub query is also + * a function of its component parts. 
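+ * For example, an illustrative sketch of how scores compose in EstimationAnalyzer: a scalar subquery whose
+ * relation contributes a score of 3 is recorded as
+ * {{{
+ *   SubqueryRule().plusScore(3)   // SubqueryRule(score = 8), i.e. the base of 5 plus 3
+ * }}}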
+ */ +case class SubqueryRule(override val score: Int = 5) extends Rule { + override def plusScore(newScore: Int): SubqueryRule = this.copy(score = newScore + this.score) +} + +// Unsupported statements and functions etc + +/** + * When we see a function that we do not already support, it either means that this is either a UDF, + * a function that we have not yet been implemented in the transpiler, or a function that is not + * supported by Databricks SQL at all. + * This is potentially a significant amount of work to convert, but in some case we will identify the + * individual functions that we cannot support automatically at all and provide a higher score for them. + * For instance XML functions in TSQL. + */ +case class UnsupportedFunctionRule(override val score: Int = 10, funcName: String) extends Rule { + override def plusScore(newScore: Int): UnsupportedFunctionRule = this.copy(score = newScore + this.score) + def resolve(): UnsupportedFunctionRule = this.copy(score = funcName match { + + // TODO: Add scores for the various unresolved functions that we know will be extra complicated to convert + case "OPENXML" => 25 + case _ => 10 + }) +} + +/** + * When we see a command that we do not support, it either means that this is a command that we have not yet + * implemented or that we can never implement it, and it is going to add a lot of complexity to the conversion. + */ +case class UnsupportedCommandRule(override val score: Int = 10) extends Rule { + override def plusScore(newScore: Int): UnsupportedCommandRule = this.copy(score = newScore + this.score) +} + +case class RuleScore(rule: Rule, from: Seq[RuleScore]) diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/CsvDumper.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/CsvDumper.scala new file mode 100644 index 0000000000..985bfbb637 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/CsvDumper.scala @@ -0,0 +1,36 @@ +package com.databricks.labs.remorph.coverage.runners + +import com.github.tototoshi.csv.CSVWriter + +import java.io.StringWriter +import java.sql.{Connection, ResultSet} + +class CsvDumper(connection: Connection) { + def queryToCSV(query: String): String = { + val statement = connection.createStatement() + val resultSet = statement.executeQuery(query) + val csv = resultSetToCSV(resultSet) + resultSet.close() + statement.close() + csv + } + + private def resultSetToCSV(resultSet: ResultSet): String = { + val writer = new StringWriter() + val csvWriter = new CSVWriter(writer) + + // write the header + val metaData = resultSet.getMetaData + val columnCount = metaData.getColumnCount + val header = (1 to columnCount).map(metaData.getColumnName) + csvWriter.writeRow(header) + + // write the data + while (resultSet.next()) { + val row = (1 to columnCount).map(resultSet.getString) + csvWriter.writeRow(row) + } + + writer.toString() + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/DatabricksSQL.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/DatabricksSQL.scala new file mode 100644 index 0000000000..0bfaf83387 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/DatabricksSQL.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.coverage.runners + +import com.databricks.connect.DatabricksSession +import com.databricks.sdk.core.DatabricksConfig +import com.databricks.sdk.WorkspaceClient +import com.typesafe.scalalogging.LazyLogging + +class 
DatabricksSQL(env: EnvGetter) extends LazyLogging { + val config = new DatabricksConfig() + .setHost(env.get("DATABRICKS_HOST")) + // TODO: fix envs to use DATABRICKS_CLUSTER_ID + .setClusterId(env.get("TEST_USER_ISOLATION_CLUSTER_ID")) + + val w = new WorkspaceClient(config) + logger.info("Ensuring cluster is running") + w.clusters().ensureClusterIsRunning(config.getClusterId) + + val spark = DatabricksSession.builder().sdkConfig(config).getOrCreate() + val res = spark.sql("SELECT * FROM samples.tpch.customer LIMIT 10").collect() + logger.info(s"Tables: ${res.mkString(", ")}") +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/EnvGetter.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/EnvGetter.scala new file mode 100644 index 0000000000..62d3bdd449 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/EnvGetter.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.remorph.coverage.runners + +import com.databricks.labs.remorph.utils.Strings +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.fasterxml.jackson.module.scala.ClassTagExtensions +import com.typesafe.scalalogging.LazyLogging + +import java.io.{File, FileNotFoundException} + +case class DebugEnv(ucws: Map[String, String]) + +class EnvGetter extends LazyLogging { + private[this] val env = getDebugEnv() + + def get(key: String): String = env.getOrElse(key, throw new RuntimeException(s"not in env: $key")) + + private def getDebugEnv(): Map[String, String] = { + try { + val debugEnvFile = String.format("%s/.databricks/debug-env.json", System.getProperty("user.home")) + val contents = Strings.fileToString(new File(debugEnvFile)) + logger.debug(s"Found debug env file: $debugEnvFile") + val mapper = new ObjectMapper() with ClassTagExtensions + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + mapper.registerModule(DefaultScalaModule) + val envs = mapper.readValue[DebugEnv](contents) + envs.ucws + } catch { + case _: FileNotFoundException => sys.env + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/SnowflakeRunner.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/SnowflakeRunner.scala new file mode 100644 index 0000000000..c9a3b63285 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/SnowflakeRunner.scala @@ -0,0 +1,44 @@ +package com.databricks.labs.remorph.coverage.runners + +import net.snowflake.client.jdbc.internal.org.bouncycastle.jce.provider.BouncyCastleProvider + +import java.security.spec.PKCS8EncodedKeySpec +import java.security.{KeyFactory, PrivateKey, Security} +import java.sql.DriverManager +import java.util.{Base64, Properties} + +class SnowflakeRunner(env: EnvGetter) { + // scalastyle:off + Class.forName("net.snowflake.client.jdbc.SnowflakeDriver") + // scalastyle:on + + private[this] val url = env.get("TEST_SNOWFLAKE_JDBC") + private[this] val privateKeyPEM = env.get("TEST_SNOWFLAKE_PRIVATE_KEY") + + private def privateKey: PrivateKey = { + Security.addProvider(new BouncyCastleProvider()) + val keySpecPKCS8 = new PKCS8EncodedKeySpec( + Base64.getDecoder.decode( + privateKeyPEM + .split("\n") + .drop(1) + .dropRight(1) + .mkString)) + val kf = KeyFactory.getInstance("RSA") + kf.generatePrivate(keySpecPKCS8) + } + + private[this] val props = { + val p = new Properties() + p.put("privateKey", privateKey) + p + } + 
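+  // Key-pair authentication: the decoded java.security.PrivateKey built above is passed to the Snowflake
+  // JDBC driver through the "privateKey" entry of the connection Properties.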
private[this] val connection = DriverManager.getConnection(url, props) + private[this] val dumper = new CsvDumper(connection) + + def queryToCSV(query: String): String = dumper.queryToCSV(query) + + def close() { + connection.close() + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/TSqlRunner.scala b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/TSqlRunner.scala new file mode 100644 index 0000000000..d45274d79e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/coverage/runners/TSqlRunner.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.coverage.runners + +import java.sql.DriverManager + +class TSqlRunner(env: EnvGetter) { + // scalastyle:off + Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver") + // scalastyle:on + + private[this] val url = env.get("TEST_TSQL_JDBC") + private[this] val user = env.get("TEST_TSQL_USER") + private[this] val pass = env.get("TEST_TSQL_PASS") + private[this] val connection = DriverManager.getConnection(url, user, pass) + private[this] val dumper = new CsvDumper(connection) + + def queryToCSV(query: String): String = dumper.queryToCSV(query) + + def close() { + connection.close() + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/discovery/Anonymizer.scala b/core/src/main/scala/com/databricks/labs/remorph/discovery/Anonymizer.scala new file mode 100644 index 0000000000..db526b67aa --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/discovery/Anonymizer.scala @@ -0,0 +1,206 @@ +package com.databricks.labs.remorph.discovery + +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.{KoResult, OkResult, Parsing, PartialResult, TranspilerState, WorkflowStage} +import com.typesafe.scalalogging.LazyLogging + +import java.security.MessageDigest +import java.sql.Timestamp +import java.time.Duration + +object WorkloadType extends Enumeration { + type WorkloadType = Value + val ETL, SQL_SERVING, OTHER = Value +} + +object QueryType extends Enumeration { + type QueryType = Value + val DDL, DML, PROC, OTHER = Value +} + +/** + * A fingerprint is a hash of a query plan that can be used to recognize duplicate queries + * + * @param dbQueryHash The hash or id of the query as stored in the database, which can be used to identify the query + * when the queries cannot be stored offsite because of customer data restrictions + * @param timestamp The timestamp of when this query was executed + * @param fingerprint The hash of the query plan, rather than the query itself - can be null if we + * cannot parse the query into a digestible plan + * @param duration how long this query took to execute, which may or may not be an indication of complexity + * @param user The user who executed the query against the database + * @param workloadType The type of workload this query represents (e.g. ETL, SQL_SERVING, OTHER) + * @param queryType The type of query this is (e.g. 
DDL, DML, PROC, OTHER) + */ +case class Fingerprint( + dbQueryHash: String, + timestamp: Timestamp, + fingerprint: String, + duration: Duration, + user: String, + workloadType: WorkloadType.WorkloadType, + queryType: QueryType.QueryType) {} + +case class Fingerprints(fingerprints: Seq[Fingerprint]) { + def uniqueQueries: Int = fingerprints.map(_.fingerprint).distinct.size +} + +class Anonymizer(parser: PlanParser[_]) extends LazyLogging { + private[this] val placeholder = Literal("?", UnresolvedType) + + def apply(history: QueryHistory): Fingerprints = Fingerprints(history.queries.map(fingerprint)) + def apply(query: ExecutedQuery, plan: LogicalPlan): Fingerprint = fingerprint(query, plan) + def apply(query: ExecutedQuery): Fingerprint = fingerprint(query) + def apply(plan: LogicalPlan): String = fingerprint(plan) + def apply(query: String): String = fingerprint(query) + + private[discovery] def fingerprint(query: ExecutedQuery): Fingerprint = { + parser.parse.flatMap(parser.visit).run(TranspilerState(Parsing(query.source))) match { + case KoResult(WorkflowStage.PARSE, error) => + logger.warn(s"Failed to parse query: ${query.source} ${error.msg}") + Fingerprint( + query.id, + query.timestamp, + fingerprint(query.source), + query.duration, + query.user.getOrElse("unknown"), + WorkloadType.OTHER, + QueryType.OTHER) + case KoResult(_, error) => + logger.warn(s"Failed to produce plan from query: ${query.source} ${error.msg}") + Fingerprint( + query.id, + query.timestamp, + fingerprint(query.source), + query.duration, + query.user.getOrElse("unknown"), + WorkloadType.OTHER, + QueryType.OTHER) + case PartialResult((_, plan), error) => + logger.warn(s"Errors occurred while producing plan from query: ${query.source} ${error.msg}") + Fingerprint( + query.id, + query.timestamp, + fingerprint(plan), + query.duration, + query.user.getOrElse("unknown"), + workloadType(plan), + queryType(plan)) + case OkResult((_, plan)) => + Fingerprint( + query.id, + query.timestamp, + fingerprint(plan), + query.duration, + query.user.getOrElse("unknown"), + workloadType(plan), + queryType(plan)) + } + } + + /** + * Create a fingerprint for a query and its plan, when the plan is already produced + * @param query The executed query + * @param plan The logical plan + * @return A fingerprint representing the query plan + */ + private[discovery] def fingerprint(query: ExecutedQuery, plan: LogicalPlan): Fingerprint = { + Fingerprint( + query.id, + query.timestamp, + fingerprint(plan), + query.duration, + query.user.getOrElse("unknown"), + workloadType(plan), + queryType(plan)) + } + + /** + *

+ * Provide a generic hash for the given plan.
+ *
+ * Before hashing the plan, we replace all literals with a placeholder. This way we can hash the plan
+ * without worrying about the actual values and will generate the same hash code for queries that only
+ * differ by literal values.
+ *
+ * This is a very simple anonymization technique, but it's good enough for our purposes.
+ * e.g. ... "LIMIT 500 OFFSET 0" and "LIMIT 100 OFFSET 20" will have the same fingerprint.
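+ *
+ * For example (illustrative), `WHERE amount > 1000 AND region = 'EMEA'` and `WHERE amount > 5 AND
+ * region = 'APAC'` reduce to the same plan once their literals are replaced by the `?` placeholder,
+ * and therefore hash to the same fingerprint.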
+ * + * @param plan The plan we want a hash code for + * @return The hash string for the query with literals replaced by placeholders + */ + private def fingerprint(plan: LogicalPlan): String = { + + val erasedLiterals = plan transformAllExpressions { case _: Literal => + placeholder + } + val code = erasedLiterals.asCode + val digest = MessageDigest.getInstance("SHA-1") + digest.update(code.getBytes) + digest.digest().map("%02x".format(_)).mkString + } + + private def workloadType(plan: LogicalPlan): WorkloadType.WorkloadType = { + plan match { + case Batch(Seq(_: Project)) => WorkloadType.SQL_SERVING + case Batch(Seq(_: CreateTableCommand)) => WorkloadType.ETL + case Batch(Seq(_: UpdateTable)) => WorkloadType.ETL + case Batch(Seq(_: DeleteFromTable)) => WorkloadType.ETL + case Batch(Seq(_: MergeIntoTable)) => WorkloadType.ETL + + case _ => WorkloadType.OTHER + } + } + + private def queryType(plan: LogicalPlan): QueryType.QueryType = { + plan match { + case Batch(Seq(_: CreateTableCommand)) => QueryType.DDL + case Batch(Seq(_: AlterTableCommand)) => QueryType.DDL + case Batch(Seq(_: DropTempView)) => QueryType.DDL + case Batch(Seq(_: DropGlobalTempView)) => QueryType.DDL + case Batch(Seq(_: Drop)) => QueryType.DDL + + case Batch(Seq(_: Project)) => QueryType.DML + case Batch(Seq(_: InsertIntoTable)) => QueryType.DML + case Batch(Seq(_: UpdateTable)) => QueryType.DML + case Batch(Seq(_: DeleteFromTable)) => QueryType.DML + case Batch(Seq(_: MergeIntoTable)) => QueryType.DML + + case _ => QueryType.OTHER + } + } + + /** + * Create a fingerprint for a query, when the plan is not yet, or cannot be produced. + *

+ * This is a fallback method for when we cannot parse the query into a plan. It makes a crude attempt
+ * to hash the query text itself, masking out literals and numerics. This gives us a very basic way to
+ * identify duplicate queries and avoid reporting them as unparsable when they differ from a previous
+ * query only by a few values. In turn, this gives us a better idea of how many unique unparsable
+ * queries we have yet to deal with, rather than reporting them all as unparsable and making the core
+ * work seem bigger than it actually is.
+ *
+ * We could improve this hash by removing comments and normalizing whitespace, but whether we would
+ * gain much from that is debatable.
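+ *
+ * For example (illustrative), "SELECT * FROM t WHERE name = 'bob' AND id = 17" and
+ * "SELECT * FROM t WHERE name = 'alice' AND id = 99" both mask to
+ * "SELECT * FROM t WHERE name = ? AND id = 42" before hashing, so they share a fingerprint.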
+ * + * @param query The text of the query to parse + * @return A fingerprint representing the query text + */ + private def fingerprint(query: String): String = { + val masked = query + .replaceAll("\\b\\d+\\b", "42") + .replaceAll("'[^']*'", "?") + .replaceAll("\"[^\"]*\"", "?") + .replaceAll("`[^`]*`", "?") + val digest = MessageDigest.getInstance("SHA-1") + digest.update(masked.getBytes) + digest.digest().map("%02x".format(_)).mkString + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/discovery/FileQueryHistory.scala b/core/src/main/scala/com/databricks/labs/remorph/discovery/FileQueryHistory.scala new file mode 100644 index 0000000000..0fe6385d78 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/discovery/FileQueryHistory.scala @@ -0,0 +1,31 @@ +package com.databricks.labs.remorph.discovery + +import java.io.File +import java.nio.file.{Files, Path} +import scala.collection.JavaConverters._ +import scala.io.Source + +class FileQueryHistory(path: Path) extends QueryHistoryProvider { + private def extractQueriesFromFile(file: File): ExecutedQuery = { + val fileContent = Source.fromFile(file) + ExecutedQuery(file.getName, fileContent.mkString, filename = Some(file.getName)) + } + + private def extractQueriesFromFolder(folder: Path): Seq[ExecutedQuery] = { + val files = + Files + .walk(folder) + .iterator() + .asScala + .filter(f => Files.isRegularFile(f)) + .toSeq + .filter(_.getFileName.toString.endsWith(".sql")) + + files.map(file => extractQueriesFromFile(file.toFile)) + } + + override def history(): QueryHistory = { + val queries = extractQueriesFromFolder(path) + QueryHistory(queries) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/discovery/SnowflakeQueryHistory.scala b/core/src/main/scala/com/databricks/labs/remorph/discovery/SnowflakeQueryHistory.scala new file mode 100644 index 0000000000..ca90f764b7 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/discovery/SnowflakeQueryHistory.scala @@ -0,0 +1,51 @@ +package com.databricks.labs.remorph.discovery + +import java.sql.Connection +import java.time.Duration +import scala.collection.mutable.ListBuffer + +class SnowflakeQueryHistory(conn: Connection) extends QueryHistoryProvider { + def history(): QueryHistory = { + val stmt = conn.createStatement() + try { + val rs = stmt.executeQuery(s"""SELECT + | QUERY_HASH, + | QUERY_TEXT, + | USER_NAME, + | WAREHOUSE_NAME, + | START_TIME, + | TOTAL_ELAPSED_TIME + |FROM + | SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY + |WHERE + | START_TIME > CURRENT_DATE - 30 + | AND + | QUERY_TEXT != '' -- Many system queries are empty + | AND + | QUERY_TEXT != '' -- Certain queries are completely redacted + | AND + | QUERY_TEXT IS NOT NULL + |ORDER BY + | START_TIME + |""".stripMargin) + try { + val queries = new ListBuffer[ExecutedQuery]() + while (rs.next()) { + queries.append( + ExecutedQuery( + id = rs.getString("QUERY_HASH"), + source = rs.getString("QUERY_TEXT"), + timestamp = rs.getTimestamp("START_TIME"), + duration = Duration.ofMillis(rs.getLong("TOTAL_ELAPSED_TIME")), + user = Some(rs.getString("USER_NAME")), + filename = None)) + } + QueryHistory(queries) + } finally { + rs.close() + } + } finally { + stmt.close() + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/discovery/SnowflakeTableDefinitions.scala b/core/src/main/scala/com/databricks/labs/remorph/discovery/SnowflakeTableDefinitions.scala new file mode 100644 index 0000000000..db6a0bd77e --- /dev/null +++ 
b/core/src/main/scala/com/databricks/labs/remorph/discovery/SnowflakeTableDefinitions.scala @@ -0,0 +1,173 @@ +package com.databricks.labs.remorph.discovery + +import com.databricks.labs.remorph.intermediate.{DataType, Metadata, StructField} +import com.databricks.labs.remorph.parsers.snowflake.{SnowflakeLexer, SnowflakeParser, SnowflakeTypeBuilder} +import org.antlr.v4.runtime.{CharStreams, CommonTokenStream} + +import java.sql.Connection +import scala.collection.mutable + +class SnowflakeTableDefinitions(conn: Connection) { + + /** + * Parses a data type string and returns the corresponding DataType object. + * + * @param dataTypeString The string representation of the data type. + * @return The DataType object corresponding to the input string. + */ + private def getDataType(dataTypeString: String): DataType = { + val inputString = CharStreams.fromString(dataTypeString) + val lexer = new SnowflakeLexer(inputString) + val tokenStream = new CommonTokenStream(lexer) + val parser = new SnowflakeParser(tokenStream) + val ctx = parser.dataType() + val dataTypeBuilder = new SnowflakeTypeBuilder + dataTypeBuilder.buildDataType(ctx) + } + + private def getTableDefinitionQuery(catalogName: String): String = { + s"""WITH column_info AS ( + | SELECT + | TABLE_CATALOG, + | TABLE_SCHEMA, + | TABLE_NAME, + | LISTAGG( + | column_name || '§' || CASE + | WHEN numeric_precision IS NOT NULL + | AND numeric_scale IS NOT NULL THEN CONCAT(data_type, '(', numeric_precision, ',', numeric_scale, ')') + | WHEN LOWER(data_type) = 'text' THEN CONCAT('varchar', '(', CHARACTER_MAXIMUM_LENGTH, ')') + | ELSE data_type + | END || '§' || TO_BOOLEAN( + | CASE + | WHEN IS_NULLABLE = 'YES' THEN 'true' + | ELSE 'false' + | END + | ) || '§' || COALESCE(COMMENT, ''), + | '‡' + | ) WITHIN GROUP ( + | ORDER BY + | ordinal_position + | ) AS Schema + | FROM + | ${catalogName}.INFORMATION_SCHEMA.COLUMNS + | GROUP BY + | TABLE_CATALOG, + | TABLE_SCHEMA, + | TABLE_NAME + |) + |SELECT + | sft.TABLE_CATALOG, + | sft.TABLE_SCHEMA, + | sft.TABLE_NAME, + | sft.comment, + | sfe.location, + | sfe.file_format_name, + | sfv.view_definition, + | column_info.Schema AS DERIVED_SCHEMA, + | FLOOR(sft.BYTES / (1024 * 1024 * 1024)) AS SIZE_GB + |FROM + | column_info + | JOIN ${catalogName}.INFORMATION_SCHEMA.TABLES sft ON column_info.TABLE_CATALOG = sft.TABLE_CATALOG + | AND column_info.TABLE_SCHEMA = sft.TABLE_SCHEMA + | AND column_info.TABLE_NAME = sft.TABLE_NAME + | LEFT JOIN ${catalogName}.INFORMATION_SCHEMA.VIEWS sfv ON column_info.TABLE_CATALOG = sfv.TABLE_CATALOG + | AND column_info.TABLE_SCHEMA = sfv.TABLE_SCHEMA + | AND column_info.TABLE_NAME = sfv.TABLE_NAME + | LEFT JOIN ${catalogName}.INFORMATION_SCHEMA.EXTERNAL_TABLES sfe ON column_info.TABLE_CATALOG = sfe.TABLE_CATALOG + | AND column_info.TABLE_SCHEMA = sfe.TABLE_SCHEMA + | AND column_info.TABLE_NAME = sfe.TABLE_NAME + |ORDER BY + | sft.TABLE_CATALOG, + | sft.TABLE_SCHEMA, + | sft.TABLE_NAME; + |""".stripMargin + } + + /** + * Retrieves the definitions of all tables in the Snowflake database. + * + * @return A sequence of TableDefinition objects representing the tables in the database. 
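+ *
+ * The query built by getTableDefinitionQuery packs each column as `name§type§nullable§comment` and joins
+ * columns with `‡`, so a table's whole schema arrives as a single string that this method splits back
+ * into StructFields.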
+ */ + private def getTableDefinitions(catalogName: String): Seq[TableDefinition] = { + val stmt = conn.createStatement() + try { + val tableDefinitionList = new mutable.ListBuffer[TableDefinition]() + val rs = stmt.executeQuery(getTableDefinitionQuery(catalogName)) + try { + while (rs.next()) { + val tableCatalog = rs.getString("TABLE_CATALOG") + val tableSchema = rs.getString("TABLE_SCHEMA") + val tableName = rs.getString("TABLE_NAME") + val columns = rs + .getString("DERIVED_SCHEMA") + .split("‡") + .map(x => { + val data = x.split("§") + val name = data(0) + val dataType = getDataType(data(1)) + val nullable = data(2).toBoolean + val comment = if (data.length > 3) Option(data(3)) else None + StructField(name, dataType, nullable, Option(Metadata(comment))) + }) + tableDefinitionList.append( + TableDefinition( + tableCatalog, + tableSchema, + tableName, + Option(rs.getString("LOCATION")), + Option(rs.getString("FILE_FORMAT_NAME")), + Option(rs.getString("VIEW_DEFINITION")), + columns, + rs.getInt("SIZE_GB"), + Option(rs.getString("COMMENT")))) + } + tableDefinitionList + } finally { + rs.close() + } + } finally { + stmt.close() + } + } + + def getAllTableDefinitions: mutable.Seq[TableDefinition] = { + getAllCatalogs.flatMap(getTableDefinitions) + } + + def getAllSchemas(catalogName: String): mutable.ListBuffer[String] = { + val stmt = conn.createStatement() + try { + val rs = stmt.executeQuery(s"SHOW SCHEMAS IN $catalogName") + try { + val schemaList = new mutable.ListBuffer[String]() + while (rs.next()) { + schemaList.append(rs.getString("name")) + } + schemaList + } finally { + rs.close() + } + } finally { + stmt.close() + } + } + + def getAllCatalogs: mutable.ListBuffer[String] = { + val stmt = conn.createStatement() + try { + val rs = stmt.executeQuery("SHOW DATABASES") + try { + val catalogList = new mutable.ListBuffer[String]() + while (rs.next()) { + catalogList.append(rs.getString("name")) + } + catalogList + } finally { + rs.close() + } + } finally { + stmt.close() + } + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/discovery/TSqlTableDefinitions.scala b/core/src/main/scala/com/databricks/labs/remorph/discovery/TSqlTableDefinitions.scala new file mode 100644 index 0000000000..bb2f363a58 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/discovery/TSqlTableDefinitions.scala @@ -0,0 +1,253 @@ +package com.databricks.labs.remorph.discovery + +import com.databricks.labs.remorph.intermediate.{DataType, Metadata, StructField} +import com.databricks.labs.remorph.parsers.tsql.{DataTypeBuilder, TSqlLexer, TSqlParser} +import org.antlr.v4.runtime.{CharStreams, CommonTokenStream} + +import java.sql.Connection +import scala.collection.mutable + +class TSqlTableDefinitions(conn: Connection) { + + /** + * Parses a data type string and returns the corresponding DataType object. + * + * @param dataTypeString The string representation of the data type. + * @return The DataType object corresponding to the input string. 
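+ *
+ * For example (illustrative), `getDataType("decimal(10,2)")` feeds the string through the TSQL lexer and
+ * the parser's `dataType` rule, returning the equivalent Remorph IR `DataType`.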
+ */ + private def getDataType(dataTypeString: String): DataType = { + val inputString = CharStreams.fromString(dataTypeString) + val lexer = new TSqlLexer(inputString) + val tokenStream = new CommonTokenStream(lexer) + val parser = new TSqlParser(tokenStream) + val ctx = parser.dataType() + val dataTypeBuilder = new DataTypeBuilder + dataTypeBuilder.build(ctx) + } + + private def getTableDefinitionQuery(catalogName: String): String = { + s"""WITH column_info AS ( + |SELECT + | TABLE_CATALOG, + | TABLE_SCHEMA, + | TABLE_NAME, + | STRING_AGG( + | CONCAT( + | column_name, + | '§', + | CASE + | WHEN numeric_precision IS NOT NULL + | AND numeric_scale IS NOT NULL THEN CONCAT(data_type, '(', numeric_precision, ',', numeric_scale, ')') + | WHEN LOWER(data_type) = 'text' THEN CONCAT('varchar', '(', CHARACTER_MAXIMUM_LENGTH, ')') + | ELSE data_type + | END, + | '§', + | CASE + | WHEN cis.IS_NULLABLE = 'YES' THEN 'true' + | ELSE 'false' + | END, + | '§', + | ISNULL(CAST(ep_col.value AS NVARCHAR(MAX)), '') + | ), + | '‡' + | ) WITHIN GROUP ( + | ORDER BY + | ordinal_position + | ) AS DERIVED_SCHEMA + |FROM + | ${catalogName}.sys.tables t + | INNER JOIN ${catalogName}.sys.columns c ON t.object_id = c.object_id + | INNER JOIN ${catalogName}.INFORMATION_SCHEMA.COLUMNS cis ON t.name = cis.TABLE_NAME + | AND c.name = cis.COLUMN_NAME + | OUTER APPLY ( + | SELECT + | TOP 1 value + | FROM + | ${catalogName}.sys.extended_properties + | WHERE + | major_id = t.object_id + | AND minor_id = 0 + | ORDER BY + | name DESC + | ) ep_tbl + | OUTER APPLY ( + | SELECT + | TOP 1 value + | FROM + | ${catalogName}.sys.extended_properties + | WHERE + | major_id = c.object_id + | AND minor_id = c.column_id + | ORDER BY + | name DESC + | ) ep_col + |GROUP BY + | TABLE_CATALOG, + | TABLE_SCHEMA, + | TABLE_NAME + |), + |table_file_info AS ( + | SELECT + | s.name AS TABLE_SCHEMA, + | t.name AS TABLE_NAME, + | f.physical_name AS location, + | f.type_desc AS TABLE_FORMAT, + | CAST(ROUND(SUM(a.used_pages) * 8.0 / 1024, 2) AS DECIMAL(18, 2)) AS SIZE_GB + | FROM + | ${catalogName}.sys.tables t + | INNER JOIN ${catalogName}.sys.indexes i ON t.object_id = i.object_id + | INNER JOIN ${catalogName}.sys.partitions p ON i.object_id = p.object_id + | AND i.index_id = p.index_id + | INNER JOIN ${catalogName}.sys.allocation_units a ON p.partition_id = a.container_id + | INNER JOIN ${catalogName}.sys.schemas s ON t.schema_id = s.schema_id + | INNER JOIN ${catalogName}.sys.database_files f ON a.data_space_id = f.data_space_id + | LEFT JOIN ${catalogName}.sys.extended_properties ep ON ep.major_id = t.object_id + | AND ep.minor_id = 0 + | GROUP BY + | s.name, + | t.name, + | f.name, + | f.physical_name, + | f.type_desc + |), + |table_comment_info AS ( + | SELECT + | s.name AS TABLE_SCHEMA, + | t.name AS TABLE_NAME, + | CAST(ep.value AS NVARCHAR(MAX)) AS TABLE_COMMENT + | FROM + | ${catalogName}.sys.tables t + | INNER JOIN ${catalogName}.sys.schemas s ON t.schema_id = s.schema_id + | OUTER APPLY ( + | SELECT + | TOP 1 value + | FROM + | ${catalogName}.sys.extended_properties + | WHERE + | major_id = t.object_id + | AND minor_id = 0 + | ORDER BY + | name DESC + | ) ep + |) + |SELECT + | sft.TABLE_CATALOG, + | sft.TABLE_SCHEMA, + | sft.TABLE_NAME, + | tfi.location, + | tfi.TABLE_FORMAT, + | sfv.view_definition, + | column_info.DERIVED_SCHEMA, + | tfi.SIZE_GB, + | tci.TABLE_COMMENT + |FROM + | column_info + | JOIN ${catalogName}.INFORMATION_SCHEMA.TABLES sft ON column_info.TABLE_CATALOG = sft.TABLE_CATALOG + | AND column_info.TABLE_SCHEMA = 
sft.TABLE_SCHEMA + | AND column_info.TABLE_NAME = sft.TABLE_NAME + | LEFT JOIN ${catalogName}.INFORMATION_SCHEMA.VIEWS sfv ON column_info.TABLE_CATALOG = sfv.TABLE_CATALOG + | AND column_info.TABLE_SCHEMA = sfv.TABLE_SCHEMA + | AND column_info.TABLE_NAME = sfv.TABLE_NAME + | LEFT JOIN table_file_info tfi ON column_info.TABLE_SCHEMA = tfi.TABLE_SCHEMA + | AND column_info.TABLE_NAME = tfi.TABLE_NAME + | LEFT JOIN table_comment_info tci ON column_info.TABLE_SCHEMA = tci.TABLE_SCHEMA + | AND column_info.TABLE_NAME = tci.TABLE_NAME + |ORDER BY + | sft.TABLE_CATALOG, + | sft.TABLE_SCHEMA, + | sft.TABLE_NAME; + |""".stripMargin + } + + /** + * Retrieves the definitions of all tables in the Snowflake database. + * + * @return A sequence of TableDefinition objects representing the tables in the database. + */ + private def getTableDefinitions(catalogName: String): Seq[TableDefinition] = { + val stmt = conn.createStatement() + try { + val tableDefinitionList = new mutable.ListBuffer[TableDefinition]() + val rs = stmt.executeQuery(getTableDefinitionQuery(catalogName)) + try { + while (rs.next()) { + val tableSchema = rs.getString("TABLE_SCHEMA") + val tableName = rs.getString("TABLE_NAME") + val tableCatalog = rs.getString("TABLE_CATALOG") + val columns = rs + .getString("DERIVED_SCHEMA") + .split("‡") + .map(x => { + val data = x.split("§") + val name = data(0) + val dataType = getDataType(data(1)) + val nullable = data(2).toBoolean + val comment = if (data.length > 3) Option(data(3)) else None + StructField(name, dataType, nullable, Some(Metadata(comment))) + }) + tableDefinitionList.append( + TableDefinition( + tableCatalog, + tableSchema, + tableName, + Option(rs.getString("LOCATION")), + Option(rs.getString("TABLE_FORMAT")), + Option(rs.getString("VIEW_DEFINITION")), + columns, + rs.getInt("SIZE_GB"), + Option(rs.getString("TABLE_COMMENT")))) + } + tableDefinitionList + } finally { + rs.close() + } + } finally { + stmt.close() + } + } + + def getAllTableDefinitions: mutable.Seq[TableDefinition] = { + getAllCatalogs.flatMap(getTableDefinitions) + } + + def getAllSchemas(catalogName: String): mutable.ListBuffer[String] = { + val stmt = conn.createStatement() + try { + val rs = stmt.executeQuery(s"""select SCHEMA_NAME from ${catalogName}.INFORMATION_SCHEMA.SCHEMATA""") + try { + val schemaList = new mutable.ListBuffer[String]() + while (rs.next()) { + schemaList.append(rs.getString("SCHEMA_NAME")) + } + schemaList + } catch { + case e: Exception => + e.printStackTrace() + throw e + } finally { + rs.close() + } + } finally { + stmt.close() + } + } + + def getAllCatalogs: mutable.ListBuffer[String] = { + val stmt = conn.createStatement() + try { + val rs = stmt.executeQuery("SELECT NAME FROM sys.databases WHERE NAME != 'MASTER'") + try { + val catalogList = new mutable.ListBuffer[String]() + while (rs.next()) { + catalogList.append(rs.getString("name")) + } + catalogList + } finally { + rs.close() + } + } finally { + stmt.close() + } + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/discovery/queries.scala b/core/src/main/scala/com/databricks/labs/remorph/discovery/queries.scala new file mode 100644 index 0000000000..5f7d462a0c --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/discovery/queries.scala @@ -0,0 +1,45 @@ +package com.databricks.labs.remorph.discovery + +import com.databricks.labs.remorph.intermediate.StructField + +import java.sql.Timestamp +import java.time.Duration + +case class ExecutedQuery( + id: String, + source: String, + timestamp: Timestamp 
= new Timestamp(System.currentTimeMillis()), + duration: Duration = Duration.ofMillis(0), + user: Option[String] = None, + filename: Option[String] = None) + +case class QueryHistory(queries: Seq[ExecutedQuery]) + +case class UnparsedQuery(timestamp: Timestamp, source: String) + +case class TableDefinition( + catalog: String, + schema: String, + table: String, + location: Option[String] = None, + tableFormat: Option[String] = None, + viewText: Option[String] = None, + columns: Seq[StructField] = Seq.empty, + sizeGb: Int = 0, + comment: Option[String] = None) + +case class Grant(objectType: String, objectKey: String, principal: String, action: String) + +case class ComputeCapacity( + startTs: Timestamp, + endTs: Timestamp, + name: String, + availableCPUs: Int, + availableMemoryGb: Int, + usedCPUs: Int, + usedMemoryGb: Int, + listPrice: Double) + +trait QueryHistoryProvider { + def history(): QueryHistory +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/Generator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/Generator.scala new file mode 100644 index 0000000000..5c06eef457 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/Generator.scala @@ -0,0 +1,88 @@ +package com.databricks.labs.remorph.generators + +import com.databricks.labs.remorph.{Generating, KoResult, OkResult, Transformation, TransformationConstructors, TranspilerState, WorkflowStage} +import com.databricks.labs.remorph.intermediate.{IncoherentState, TreeNode, UnexpectedNode} + +trait Generator[In <: TreeNode[In], Out] extends TransformationConstructors { + def generate(tree: In): Transformation[Out] + def unknown(tree: In): Transformation[Nothing] = + ko(WorkflowStage.GENERATE, UnexpectedNode(tree.getClass.getSimpleName)) +} + +trait CodeGenerator[In <: TreeNode[In]] extends Generator[In, String] { + + private def generateAndJoin(trees: Seq[In], separator: String): Transformation[String] = { + trees.map(generate).sequence.map(_.mkString(separator)) + } + + /** + * Apply the generator to the input nodes and join the results with commas. + */ + def commas(trees: Seq[In]): Transformation[String] = generateAndJoin(trees, ", ") + + /** + * Apply the generator to the input nodes and join the results with whitespaces. + */ + def spaces(trees: Seq[In]): Transformation[String] = generateAndJoin(trees, " ") + + /** + * When the current Phase is Generating, update its GeneratorContext with the provided function. + * @param f + * A function for updating a GeneratorContext. + * @return + * A transformation that: + * - updates the state according to f and produces no meaningful output when the current Phase is Generating. + * - fails if the current Phase is different from Generating. + */ + def updateGenCtx(f: GeneratorContext => GeneratorContext): Transformation[Unit] = new Transformation({ + case TranspilerState(g: Generating, tm) => OkResult((TranspilerState(g.copy(ctx = f(g.ctx)), tm), ())) + case s => KoResult(WorkflowStage.GENERATE, IncoherentState(s.currentPhase, classOf[Generating])) + }) + + /** + * When the current Phase is Generating, update the GeneratorContext by incrementing the indentation level. + * @return + * A tranformation that increases the indentation level when the current Phase is Generating and fails otherwise. + */ + def nest: Transformation[Unit] = updateGenCtx(_.nest) + + /** + * When the current Phase is Generating, update the GeneratorContext by decrementing the indentation level. 
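+ *
+ * nest and unnest are typically used as a bracketing pair; see withIndentedBlock below, which wraps a
+ * body transformation between them and restores the original indentation level afterwards.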
+ * @return + * A tranformation that decreases the indentation level when the current Phase is Generating and fails otherwise. + */ + def unnest: Transformation[Unit] = updateGenCtx(_.unnest) + + /** + * When the current Phase is Generating, produce a block of code where the provided body is nested under the header. + * @param header + * A transformation that produces the header, which will remain unindented. Could be a function signature, + * a class definition, etc. + * @param body + * A transformation that produces the body, which will be indented one level under the header. + * @return + * A transformation that produces the header followed by the indented body (separated by a newline) and restores + * the indentation level to its original value. Said transformation will fail if the current Phase isn't + * Generating. + */ + def withIndentedBlock(header: Transformation[String], body: Transformation[String]): Transformation[String] = + for { + h <- header + _ <- nest + b <- body + _ <- unnest + } yield h + "\n" + b + + /** + * When the current Phase is Generating, allows for building transformations that use the current GeneratorContext. + * @param transfoUsingCtx + * A function that will receive the current GeneratorContext and produce a Transformation. + * @return + * A transformation that uses the current GeneratorContext if the current Phase is Generating and fails otherwise. + */ + def withGenCtx(transfoUsingCtx: GeneratorContext => Transformation[String]): Transformation[String] = + new Transformation[GeneratorContext]({ + case TranspilerState(g: Generating, tm) => OkResult((TranspilerState(g, tm), g.ctx)) + case s => KoResult(WorkflowStage.GENERATE, IncoherentState(s.currentPhase, classOf[Generating])) + }).flatMap(transfoUsingCtx) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/GeneratorContext.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/GeneratorContext.scala new file mode 100644 index 0000000000..150040e3fe --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/GeneratorContext.scala @@ -0,0 +1,41 @@ +package com.databricks.labs.remorph.generators + +import com.databricks.labs.remorph.{intermediate => ir} + +case class GeneratorContext( + // needed for sql"EXISTS (${ctx.logical.generate(ctx, subquery)})" in SQLGenerator + logical: Generator[ir.LogicalPlan, String], + maxLineWidth: Int = 120, + private val indent: Int = 0, + private val layer: Int = 0, + private val joins: Int = 0, + wrapLiteral: Boolean = true) { + def nest: GeneratorContext = + GeneratorContext(logical, maxLineWidth = maxLineWidth, joins = joins, layer = layer, indent = indent + 1) + + def unnest: GeneratorContext = + GeneratorContext( + logical, + maxLineWidth = maxLineWidth, + joins = joins, + layer = layer, + indent = Math.max(0, indent - 1)) + + def ws: String = " " * indent + + def subQuery: GeneratorContext = + GeneratorContext(logical, maxLineWidth = maxLineWidth, joins = joins, layer = layer + 1, indent = indent + 1) + + def layerName: String = s"layer_$layer" + + def withRawLiteral: GeneratorContext = + GeneratorContext( + logical, + maxLineWidth = maxLineWidth, + joins = joins, + indent = indent, + layer = layer, + wrapLiteral = false) + + def hasJoins: Boolean = joins > 0 +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/FileSet.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/FileSet.scala new file mode 100644 index 0000000000..dfc56e7262 --- /dev/null +++ 
b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/FileSet.scala @@ -0,0 +1,35 @@ +package com.databricks.labs.remorph.generators.orchestration + +import java.io.File +import java.nio.file.Files +import scala.collection.mutable + +class FileSet { + private[this] val files = new mutable.HashMap[String, String]() + + def withFile(name: String, content: String): FileSet = { + files(name) = content + this + } + + def getFile(name: String): Option[String] = { + files.get(name) + } + + def removeFile(name: String): Unit = { + files.remove(name) + } + + def persist(path: File): Unit = { + files.foreach { case (name, content) => + val file = new File(path, name) + if (!file.getParentFile.exists()) { + file.getParentFile.mkdirs() + } + if (!file.exists()) { + file.createNewFile() + } + Files.write(file.toPath, content.getBytes) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/FileSetGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/FileSetGenerator.scala new file mode 100644 index 0000000000..b2fceb265f --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/FileSetGenerator.scala @@ -0,0 +1,37 @@ +package com.databricks.labs.remorph.generators.orchestration + +import com.databricks.labs.remorph.Transformation +import com.databricks.labs.remorph.generators.Generator +import com.databricks.labs.remorph.generators.orchestration.rules.converted.CreatedFile +import com.databricks.labs.remorph.generators.orchestration.rules._ +import com.databricks.labs.remorph.intermediate.Rules +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.transpilers.{PySparkGenerator, SqlGenerator} + +class FileSetGenerator( + private[this] val parser: PlanParser[_], + private[this] val sqlGen: SqlGenerator, + private[this] val pyGen: PySparkGenerator) + extends Generator[JobNode, FileSet] { + private[this] val rules = Rules( + new QueryHistoryToQueryNodes(parser), + new DefineSchemas(), + new ExtractVariables(), + new TryGenerateSQL(sqlGen), + new TryGeneratePythonNotebook(pyGen), + new TrySummarizeFailures(), + new ReformatCode(), + new DefineJob(), + new GenerateBundleFile()) + + override def generate(tree: JobNode): Transformation[FileSet] = { + val fileSet = new FileSet() + rules(tree) foreachUp { + case CreatedFile(name, code) => + fileSet.withFile(name, code) + case _ => // noop + } + ok(fileSet) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/DefineJob.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/DefineJob.scala new file mode 100644 index 0000000000..9ab4c7e00e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/DefineJob.scala @@ -0,0 +1,40 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.generators.orchestration.rules.bundles.Schema +import com.databricks.labs.remorph.generators.orchestration.rules.converted.{CreatedFile, NeedsVariables, PythonNotebookTask, SqlNotebookTask, SuccessPy, SuccessSQL, ToTask} +import com.databricks.labs.remorph.generators.orchestration.rules.history.Migration +import com.databricks.labs.remorph.intermediate.Rule +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import 
com.databricks.labs.remorph.intermediate.workflows.jobs.JobSettings +import com.databricks.labs.remorph.intermediate.workflows.tasks.Task + +class DefineJob extends Rule[JobNode] { + override def apply(tree: JobNode): JobNode = tree transformUp { + case NeedsVariables(SuccessPy(name, code), variables) => + PythonNotebookTask(CreatedFile(s"notebooks/$name.py", code), variables.map(_ -> "FILL_ME").toMap) + case SuccessPy(name, code) => + PythonNotebookTask(CreatedFile(s"notebooks/$name.py", code)) + case NeedsVariables(SuccessSQL(name, code), variables) => + SqlNotebookTask(CreatedFile(s"notebooks/$name.sql", code), variables.map(_ -> "FILL_ME").toMap) + case SuccessSQL(name, code) => + SqlNotebookTask(CreatedFile(s"notebooks/$name.sql", code)) + case m: Migration => + // TODO: create multiple jobs, once we realise we need that + // TODO: add task dependencies via com.databricks.labs.remorph.graph.TableGraph + var tasks = Seq[Task]() + var other = Seq[JobNode]() + m foreachUp { + case toTask: ToTask => + tasks +:= toTask.toTask + case task: Task => + tasks +:= task + case file: CreatedFile => + other +:= file + case schema: Schema => + other +:= schema + case _ => // noop + } + val job = JobSettings("Migrated via Remorph", tasks.sortBy(_.taskKey), tags = Map("generator" -> "remorph")) + Migration(other :+ job) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/DefineSchemas.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/DefineSchemas.scala new file mode 100644 index 0000000000..b37f9453e5 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/DefineSchemas.scala @@ -0,0 +1,28 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.generators.orchestration.rules.bundles.Schema +import com.databricks.labs.remorph.generators.orchestration.rules.history.{Migration, QueryPlan} +import com.databricks.labs.remorph.intermediate.{NamedTable, Rule} +import com.databricks.labs.remorph.intermediate.workflows.JobNode + +class DefineSchemas extends Rule[JobNode] { + override def apply(tree: JobNode): JobNode = tree transformUp { case Migration(children) => + var schemas = Seq[Schema]() + children foreach { + case QueryPlan(plan, _) => + plan foreach { + case NamedTable(unparsedName, _, _) => + val parts = unparsedName.split("\\.") + if (parts.size == 1) { + schemas +:= Schema("main", "default") + } else { + schemas +:= Schema("main", parts(0)) + } + case _ => // noop + } + case _ => // noop + } + schemas = schemas.distinct.sortBy(_.name) + Migration(schemas ++ children) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/ExtractVariables.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/ExtractVariables.scala new file mode 100644 index 0000000000..c035c8d6d4 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/ExtractVariables.scala @@ -0,0 +1,19 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.generators.orchestration.rules.converted.NeedsVariables +import com.databricks.labs.remorph.generators.orchestration.rules.history.QueryPlan +import com.databricks.labs.remorph.intermediate.{Rule, Variable} +import com.databricks.labs.remorph.intermediate.workflows.JobNode + +class ExtractVariables extends Rule[JobNode] { + override def 
apply(tree: JobNode): JobNode = tree transformUp { case q: QueryPlan => + val variables = q.plan.expressions + .filter(_.isInstanceOf[Variable]) + .map { case Variable(name) => name } + if (variables.nonEmpty) { + NeedsVariables(q, variables) + } else { + q + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/GenerateBundleFile.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/GenerateBundleFile.scala new file mode 100644 index 0000000000..5af21423e2 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/GenerateBundleFile.scala @@ -0,0 +1,47 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.generators.orchestration.rules.bundles._ +import com.databricks.labs.remorph.generators.orchestration.rules.converted.CreatedFile +import com.databricks.labs.remorph.generators.orchestration.rules.history.Migration +import com.databricks.labs.remorph.intermediate.Rule +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.jobs.JobSettings +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +// see https://docs.databricks.com/en/dev-tools/bundles/settings.html +class GenerateBundleFile extends Rule[JobNode] { + private[this] val mapper = + new ObjectMapper(new YAMLFactory()) + .setSerializationInclusion(Include.NON_DEFAULT) + .registerModule(DefaultScalaModule) + + override def apply(tree: JobNode): JobNode = tree transform { case Migration(children) => + val resources = findResources(children) + Migration(children ++ Seq(bundleDefinition(resources))) + } + + private def findResources(children: Seq[JobNode]): Resources = { + var resources = Resources() + children foreach { + case schema: Schema => + resources = resources.withSchema(schema) + case job: JobSettings => + resources = resources.withJob(job) + case _ => // noop + } + resources + } + + private def bundleDefinition(resources: Resources): CreatedFile = { + val bundle = BundleFile( + resources = Some(resources), + bundle = Some(Bundle("remorphed")), + targets = + Map("dev" -> Target(mode = Some("development"), default = true), "prod" -> Target(mode = Some("production")))) + val yml = mapper.writeValueAsString(bundle) + CreatedFile("databricks.yml", yml) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/QueryHistoryToQueryNodes.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/QueryHistoryToQueryNodes.scala new file mode 100644 index 0000000000..a58faecac5 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/QueryHistoryToQueryNodes.scala @@ -0,0 +1,26 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.{KoResult, OkResult, Parsing, PartialResult, TranspilerState} +import com.databricks.labs.remorph.discovery.{ExecutedQuery, QueryHistory} +import com.databricks.labs.remorph.generators.orchestration.rules.history.{FailedQuery, Migration, PartialQuery, QueryPlan, RawMigration} +import com.databricks.labs.remorph.intermediate.Rule +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import 
com.databricks.labs.remorph.parsers.PlanParser + +class QueryHistoryToQueryNodes(val parser: PlanParser[_]) extends Rule[JobNode] { + override def apply(plan: JobNode): JobNode = plan match { + case RawMigration(QueryHistory(queries)) => Migration(queries.par.map(executedQuery).seq) + } + + private def executedQuery(query: ExecutedQuery): JobNode = { + val state = TranspilerState(Parsing(query.source, query.id)) + parser.parse + .flatMap(parser.visit) + .flatMap(parser.optimize) + .run(state) match { + case OkResult((_, plan)) => QueryPlan(plan, query) + case PartialResult((_, plan), error) => PartialQuery(query, error.msg, QueryPlan(plan, query)) + case KoResult(stage, error) => FailedQuery(query, error.msg, stage) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/ReformatCode.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/ReformatCode.scala new file mode 100644 index 0000000000..4495df1524 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/ReformatCode.scala @@ -0,0 +1,24 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult} +import com.databricks.labs.remorph.generators.orchestration.rules.converted.{SuccessPy, SuccessSQL} +import com.databricks.labs.remorph.generators.py.RuffFormatter +import com.databricks.labs.remorph.intermediate.Rule +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.github.vertical_blank.sqlformatter.SqlFormatter +import com.github.vertical_blank.sqlformatter.languages.Dialect + +class ReformatCode extends Rule[JobNode] { + private[this] val ruff = new RuffFormatter() + private[this] val sqlf = SqlFormatter.of(Dialect.SparkSql) + + override def apply(tree: JobNode): JobNode = tree transformUp { + case SuccessSQL(name, code) => SuccessSQL(name, sqlf.format(code)) + case SuccessPy(name, code) => + ruff.format(code) match { + case OkResult(output) => SuccessPy(name, output) + case PartialResult(output, _) => SuccessPy(name, output) + case KoResult(_, _) => SuccessPy(name, code) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TryGeneratePythonNotebook.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TryGeneratePythonNotebook.scala new file mode 100644 index 0000000000..3a711a2f8d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TryGeneratePythonNotebook.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.{Generating, KoResult, OkResult, PartialResult, TranspilerState} +import com.databricks.labs.remorph.generators.GeneratorContext +import com.databricks.labs.remorph.generators.orchestration.rules.converted.SuccessPy +import com.databricks.labs.remorph.generators.orchestration.rules.history.{FailedQuery, PartialQuery, QueryPlan} +import com.databricks.labs.remorph.generators.py.LogicalPlanGenerator +import com.databricks.labs.remorph.intermediate.Rule +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.transpilers.PySparkGenerator + +class TryGeneratePythonNotebook(generator: PySparkGenerator) extends Rule[JobNode] { + override def apply(tree: JobNode): JobNode = tree transformDown { case n @ QueryPlan(plan, query) => + val state = 
TranspilerState(Generating(plan, n, GeneratorContext(new LogicalPlanGenerator))) + generator.generate(plan).run(state) match { + case OkResult((_, sql)) => SuccessPy(query.id, sql) + case PartialResult((_, sql), error) => PartialQuery(query, error.msg, SuccessPy(query.id, sql)) + case KoResult(stage, error) => FailedQuery(query, error.msg, stage) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TryGenerateSQL.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TryGenerateSQL.scala new file mode 100644 index 0000000000..a69ad680f6 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TryGenerateSQL.scala @@ -0,0 +1,19 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.generators.orchestration.rules.converted.SuccessSQL +import com.databricks.labs.remorph.{Generating, KoResult, OkResult, PartialResult, TranspilerState} +import com.databricks.labs.remorph.generators.orchestration.rules.history.{FailedQuery, PartialQuery, QueryPlan} +import com.databricks.labs.remorph.intermediate.Rule +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.transpilers.SqlGenerator + +class TryGenerateSQL(generator: SqlGenerator) extends Rule[JobNode] { + override def apply(tree: JobNode): JobNode = tree transformDown { case n @ QueryPlan(plan, query) => + val state = TranspilerState(Generating(plan, n, generator.initialGeneratorContext)) + generator.generate(plan).run(state) match { + case OkResult((_, sql)) => SuccessSQL(query.id, sql) + case PartialResult((_, sql), error) => PartialQuery(query, error.msg, SuccessSQL(query.id, sql)) + case KoResult(stage, error) => FailedQuery(query, error.msg, stage) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TrySummarizeFailures.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TrySummarizeFailures.scala new file mode 100644 index 0000000000..a153c7bc09 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/TrySummarizeFailures.scala @@ -0,0 +1,59 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.discovery.ExecutedQuery +import com.databricks.labs.remorph.generators.orchestration.rules.converted.CreatedFile +import com.databricks.labs.remorph.generators.orchestration.rules.history.{FailedQuery, Migration, PartialQuery} +import com.databricks.labs.remorph.intermediate.Rule +import com.databricks.labs.remorph.intermediate.workflows.JobNode + +class TrySummarizeFailures extends Rule[JobNode] { + + override def apply(tree: JobNode): JobNode = { + var partials = Seq.empty[(ExecutedQuery, String)] + val removedPartials = tree transformUp { case PartialQuery(executed, message, query) => + partials = partials :+ ((executed, message)) + query + } + removedPartials transformUp { case Migration(queries) => + var children = queries.filterNot(_.isInstanceOf[FailedQuery]) + val failedQueries = queries.filter(_.isInstanceOf[FailedQuery]).map(_.asInstanceOf[FailedQuery]) + if (failedQueries.nonEmpty) { + children = children :+ failedQueryInfo(failedQueries.sortBy(_.query.id)) + } + if (partials.nonEmpty) { + children = children :+ partialQueryInfo(partials.distinct.sortBy(_._1.id)) + } + Migration(children) + } + } + + private def 
failedQueryInfo(failedQueries: Seq[FailedQuery]): CreatedFile = { + CreatedFile( + "failed_queries.md", + failedQueries map { case FailedQuery(query, message, stage) => + s""" + |# query: `${query.id}` + |$stage: $message + | + |```sql + |${query.source} + |``` + |""".stripMargin + } mkString "\n") + } + + private def partialQueryInfo(partials: Seq[(ExecutedQuery, String)]): CreatedFile = { + CreatedFile( + "partial_failures.md", + partials map { case (query, message) => + s""" + |# query: `${query.id}` + |$message + | + |```sql + |${query.source} + |``` + |""".stripMargin + } mkString "\n") + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Bundle.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Bundle.scala new file mode 100644 index 0000000000..0805ffd247 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Bundle.scala @@ -0,0 +1,6 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.bundles + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.fasterxml.jackson.annotation.JsonProperty + +case class Bundle(@JsonProperty name: String) extends LeafJobNode diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/BundleFile.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/BundleFile.scala new file mode 100644 index 0000000000..86b3ccecec --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/BundleFile.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.bundles + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.fasterxml.jackson.annotation.JsonProperty + +case class BundleFile( + @JsonProperty bundle: Option[Bundle] = None, + @JsonProperty targets: Map[String, Target] = Map.empty, + @JsonProperty resources: Option[Resources] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ bundle ++ resources ++ targets.values +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/JobReference.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/JobReference.scala new file mode 100644 index 0000000000..9f7fa79049 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/JobReference.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.bundles + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode + +case class JobReference(name: String) extends LeafJobNode diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Resources.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Resources.scala new file mode 100644 index 0000000000..8f2bcab1c2 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Resources.scala @@ -0,0 +1,18 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.bundles + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.jobs.JobSettings +import com.databricks.sdk.service.jobs.CreateJob +import com.fasterxml.jackson.annotation.JsonProperty + 
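+// Mirrors the `resources` block of a Databricks Asset Bundle: jobs and schemas keyed by resource name.
+// Note that `withJob` prefixes the job name with the literal `${bundle.target}` placeholder (via the `$$`
+// escape below), so the concrete deployment target is substituted by the bundle tooling at deploy time.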
+case class Resources( + @JsonProperty jobs: Map[String, CreateJob] = Map.empty, + @JsonProperty schemas: Map[String, Schema] = Map.empty) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ schemas.values + def withSchema(schema: Schema): Resources = copy(schemas = schemas + (schema.name -> schema)) + def withJob(job: JobSettings): Resources = { + val withTarget = job.copy(name = s"[$${bundle.target}] ${job.name}") + copy(jobs = jobs + (job.resourceName -> withTarget.toCreate)) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Schema.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Schema.scala new file mode 100644 index 0000000000..855f62c621 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Schema.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.bundles + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.fasterxml.jackson.annotation.JsonProperty + +case class Schema( + @JsonProperty("catalog_name") catalogName: String, + @JsonProperty name: String, + @JsonProperty comment: Option[String] = None) + extends LeafJobNode diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Target.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Target.scala new file mode 100644 index 0000000000..1c50ba2f38 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/bundles/Target.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.bundles + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.fasterxml.jackson.annotation.JsonProperty + +case class Target( + @JsonProperty mode: Option[String] = None, + @JsonProperty default: Boolean = false, + @JsonProperty resources: Option[Resources] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ resources +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/CreatedFile.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/CreatedFile.scala new file mode 100644 index 0000000000..eca1e3305e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/CreatedFile.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode + +import java.util.Locale + +case class CreatedFile(name: String, text: String) extends LeafJobNode { + def resourceName: String = { + val pathParts = name.split("/") + val baseNameParts = pathParts.last.split("\\.") + val lowerCaseName = baseNameParts.head.toLowerCase(Locale.ROOT) + lowerCaseName.replaceAll("[^A-Za-z0-9]", "_") + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/NeedsVariables.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/NeedsVariables.scala new file mode 100644 index 0000000000..5bd43a8a03 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/NeedsVariables.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted 
+ +import com.databricks.labs.remorph.intermediate.workflows.JobNode + +case class NeedsVariables(child: JobNode, variables: Seq[String]) extends JobNode { + override def children: Seq[JobNode] = Seq(child) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/PythonNotebookTask.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/PythonNotebookTask.scala new file mode 100644 index 0000000000..5e7503ddeb --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/PythonNotebookTask.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.tasks.NotebookTask + +case class PythonNotebookTask(file: CreatedFile, baseParameters: Map[String, String] = Map.empty) + extends JobNode + with ToNotebookTask { + override def children: Seq[JobNode] = Seq(file) + override def resourceName: String = file.resourceName + override def toNotebookTask: NotebookTask = NotebookTask(file.name, Some(baseParameters), None) + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/RunNotebookJobTask.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/RunNotebookJobTask.scala new file mode 100644 index 0000000000..6cc56defc4 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/RunNotebookJobTask.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.jobs.JobSettings +import com.databricks.labs.remorph.intermediate.workflows.tasks.RunJobTask + +// TODO: if we have this node, then add new rule to inject DynamicJobs(m: Migration, deploying: Map[Int,String]) +// and replace the in-memory integers with `"${resources.jobs.STRING.id}"` in low-level YAML rewrite +case class RunNotebookJobTask(job: JobSettings, params: Map[String, String] = Map.empty) extends JobNode { + override def children: Seq[JobNode] = Seq(job) + def toRunJobTask(id: Long): RunJobTask = RunJobTask(id, notebookParams = params) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SqlNotebookTask.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SqlNotebookTask.scala new file mode 100644 index 0000000000..6fe1e20b3c --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SqlNotebookTask.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.tasks.{NeedsWarehouse, NotebookTask} + +case class SqlNotebookTask(file: CreatedFile, baseParameters: Map[String, String] = Map.empty) + extends JobNode + with ToNotebookTask + with NeedsWarehouse { + override def children: Seq[JobNode] = Seq(file) + override def resourceName: String = file.resourceName + override def toNotebookTask: NotebookTask = NotebookTask(file.name, Some(baseParameters), Some(DEFAULT_WAREHOUSE_ID)) + +} diff --git 
a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SqlWorkspaceFileTask.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SqlWorkspaceFileTask.scala new file mode 100644 index 0000000000..7962d8e02d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SqlWorkspaceFileTask.scala @@ -0,0 +1,16 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.sql.SqlTaskFile +import com.databricks.labs.remorph.intermediate.workflows.tasks.{NeedsWarehouse, SqlTask, Task} + +case class SqlWorkspaceFileTask(file: CreatedFile, parameters: Map[String, String] = Map.empty) + extends JobNode + with ToTask + with NeedsWarehouse { + override def children: Seq[JobNode] = Seq(file) + override def toTask: Task = Task( + taskKey = file.resourceName, + sqlTask = Some( + SqlTask(warehouseId = DEFAULT_WAREHOUSE_ID, parameters = Some(parameters), file = Some(SqlTaskFile(file.name))))) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SuccessPy.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SuccessPy.scala new file mode 100644 index 0000000000..8cc2f5de09 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SuccessPy.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode + +case class SuccessPy(id: String, code: String) extends LeafJobNode diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SuccessSQL.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SuccessSQL.scala new file mode 100644 index 0000000000..c2b62d8952 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/SuccessSQL.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode + +case class SuccessSQL(id: String, query: String) extends LeafJobNode diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/ToNotebookTask.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/ToNotebookTask.scala new file mode 100644 index 0000000000..937c15220c --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/ToNotebookTask.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.tasks.{NotebookTask, Task} + +trait ToNotebookTask extends ToTask { + def resourceName: String + def toNotebookTask: NotebookTask + def toTask: Task = Task(taskKey = resourceName, notebookTask = Some(toNotebookTask)) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/ToTask.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/ToTask.scala new file mode 100644 index 0000000000..eb3fc5b127 --- /dev/null +++ 
b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/converted/ToTask.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.converted + +import com.databricks.labs.remorph.intermediate.workflows.tasks.Task + +trait ToTask { + def toTask: Task +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/FailedQuery.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/FailedQuery.scala new file mode 100644 index 0000000000..a4d714ff09 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/FailedQuery.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.history + +import com.databricks.labs.remorph.WorkflowStage +import com.databricks.labs.remorph.discovery.ExecutedQuery +import com.databricks.labs.remorph.intermediate.workflows.JobNode + +case class FailedQuery(query: ExecutedQuery, message: String, stage: WorkflowStage) extends JobNode { + override def children: Seq[JobNode] = Seq() +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/Migration.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/Migration.scala new file mode 100644 index 0000000000..9e2b0f1978 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/Migration.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.history + +import com.databricks.labs.remorph.intermediate.workflows.JobNode + +case class Migration(children: Seq[JobNode]) extends JobNode diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/PartialQuery.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/PartialQuery.scala new file mode 100644 index 0000000000..7fb229b18d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/PartialQuery.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.history + +import com.databricks.labs.remorph.discovery.ExecutedQuery +import com.databricks.labs.remorph.intermediate.workflows.JobNode + +case class PartialQuery(executed: ExecutedQuery, message: String, query: JobNode) extends JobNode { + override def children: Seq[JobNode] = Seq(query) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/QueryPlan.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/QueryPlan.scala new file mode 100644 index 0000000000..2d81576969 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/QueryPlan.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.generators.orchestration.rules.history + +import com.databricks.labs.remorph.discovery.ExecutedQuery +import com.databricks.labs.remorph.intermediate.LogicalPlan +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode + +case class QueryPlan(plan: LogicalPlan, query: ExecutedQuery) extends LeafJobNode diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/RawMigration.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/RawMigration.scala new file mode 100644 index 
0000000000..9fdf5f2142
--- /dev/null
+++ b/core/src/main/scala/com/databricks/labs/remorph/generators/orchestration/rules/history/RawMigration.scala
@@ -0,0 +1,6 @@
+package com.databricks.labs.remorph.generators.orchestration.rules.history
+
+import com.databricks.labs.remorph.discovery.QueryHistory
+import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode
+
+case class RawMigration(queryHistory: QueryHistory) extends LeafJobNode
diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/package.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/package.scala
new file mode 100644
index 0000000000..efbb3e3c7b
--- /dev/null
+++ b/core/src/main/scala/com/databricks/labs/remorph/generators/package.scala
@@ -0,0 +1,74 @@
+package com.databricks.labs.remorph
+
+import com.databricks.labs.remorph.intermediate.UncaughtException
+
+import scala.util.control.NonFatal
+
+package object generators {
+
+  implicit class CodeInterpolator(sc: StringContext) extends TransformationConstructors {
+    def code(args: Any*): Transformation[String] = {
+
+      args
+        .map {
+          case tba: Transformation[_] => tba.asInstanceOf[Transformation[String]]
+          case x => ok(x.toString)
+        }
+        .sequence
+        .flatMap { a =>
+          val stringParts = sc.parts.iterator
+          val arguments = a.iterator
+          var failureOpt: Option[Transformation[String]] = None
+          val sb = new StringBuilder()
+          try {
+            sb.append(StringContext.treatEscapes(stringParts.next()))
+          } catch {
+            case NonFatal(e) =>
+              failureOpt = Some(lift(KoResult(WorkflowStage.GENERATE, UncaughtException(e))))
+          }
+          while (failureOpt.isEmpty && arguments.hasNext) {
+            try {
+              sb.append(arguments.next())
+              sb.append(StringContext.treatEscapes(stringParts.next()))
+            } catch {
+              case NonFatal(e) =>
+                failureOpt = Some(lift(KoResult(WorkflowStage.GENERATE, UncaughtException(e))))
+            }
+          }
+          failureOpt.getOrElse(ok(sb.toString()))
+        }
+    }
+  }
+
+  implicit class TBAOps(sql: Transformation[String]) {
+    def nonEmpty: Transformation[Boolean] = sql.map(_.nonEmpty)
+    def isEmpty: Transformation[Boolean] = sql.map(_.isEmpty)
+  }
+
+  implicit class TBASeqOps(tbas: Seq[Transformation[String]]) extends TransformationConstructors {
+
+    def mkCode: Transformation[String] = mkCode("", "", "")
+
+    def mkCode(sep: String): Transformation[String] = mkCode("", sep, "")
+
+    def mkCode(start: String, sep: String, end: String): Transformation[String] = {
+      tbas.sequence.map(_.mkString(start, sep, end))
+    }
+
+    /**
+     * Combine multiple Transformation[String] into a Transformation[ Seq[String] ].
+     * The resulting Transformation will run each individual Transformation in sequence, accumulating all the effects
+     * along the way.
+     *
+     * For example, when a Transformation in the input Seq modifies the state, TBAs that come after it in the input
+     * Seq will see the modified state.
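+     *
+     * A minimal usage sketch (illustrative values; `ok` comes from TransformationConstructors and lifts a
+     * plain value into a successful Transformation):
+     * {{{
+     *   val parts: Seq[Transformation[String]] = Seq(ok("SELECT 1"), ok("SELECT 2"))
+     *   val combined: Transformation[Seq[String]] = parts.sequence
+     * }}}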
+ */ + def sequence: Transformation[Seq[String]] = + tbas.foldLeft(ok(Seq.empty[String])) { case (agg, item) => + for { + aggSeq <- agg + i <- item + } yield aggSeq :+ i + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/BasePythonGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/BasePythonGenerator.scala new file mode 100644 index 0000000000..290b35fc7e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/BasePythonGenerator.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.PartialResult +import com.databricks.labs.remorph.generators._ +import com.databricks.labs.remorph.intermediate.{RemorphError, TreeNode, UnexpectedNode} + +abstract class BasePythonGenerator[In <: TreeNode[In]] extends CodeGenerator[In] { + + def partialResult(tree: In): Python = partialResult(tree, UnexpectedNode(tree.toString)) + def partialResult(trees: Seq[Any], err: RemorphError): Python = + lift(PartialResult(s"# FIXME: ${trees.mkString(" | ")} !!!", err)) + def partialResult(tree: Any, err: RemorphError): Python = lift(PartialResult(s"# FIXME: $tree !!!", err)) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/ExpressionGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/ExpressionGenerator.scala new file mode 100644 index 0000000000..0e1ad83ebb --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/ExpressionGenerator.scala @@ -0,0 +1,129 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.generators._ +import com.databricks.labs.remorph.intermediate.Expression +import com.databricks.labs.remorph.{intermediate => ir} + +class ExpressionGenerator extends BasePythonGenerator[ir.Expression] { + override def generate(tree: Expression): Python = expression(tree) + + private def expression(expr: ir.Expression): Python = expr match { + case ir.Name(name) => code"$name" + case _: ir.Arithmetic => arithmetic(expr) + case _: ir.Predicate => predicate(expr) + case l: ir.Literal => literal(l) + case c: Call => call(c) + case d: Dict => dict(d) + case s: Slice => slice(s) + case Comment(child, text) => withGenCtx(ctx => code"# $text\n${ctx.ws}${expression(child)}") + case Comprehension(target, iter, ifs) => comprehension(target, iter, ifs) + case GeneratorExp(elt, gens) => code"${expression(elt)} ${spaces(gens)}" + case ListComp(elt, gens) => code"[${expression(elt)} ${spaces(gens)}]" + case SetComp(elt, gens) => code"{${expression(elt)} ${spaces(gens)}}" + case DictComp(key, value, gens) => code"{${generate(key)}: ${generate(value)} ${spaces(gens)}}" + case IfExp(test, body, orElse) => ifExpr(test, body, orElse) + case Lambda(args, body) => code"lambda ${arguments(args)}: ${expression(body)}" + case Tuple(elements, _) => code"(${commas(elements)},)" + case Attribute(value, attr, _) => code"${expression(value)}.${expression(attr)}" + case Subscript(value, index, _) => code"${expression(value)}[${expression(index)}]" + case List(elements, _) => code"[${commas(elements)}]" + case Set(Nil) => code"set()" + case Set(elements) => code"{${commas(elements)}}" + case _ => partialResult(expr) + } + + private def comprehension(target: Expression, iter: Expression, ifs: Seq[Expression]): Python = { + val ifsExpr = ifs.map(expression(_)).mkCode(" and ") + val base = code"for ${expression(target)} in ${expression(iter)}" + ifsExpr.isEmpty.flatMap { isEmpty => + 
if (isEmpty) { + base + } else { + code"$base if $ifsExpr" + } + } + } + + private def ifExpr(test: Expression, body: Expression, orelse: Expression): Python = { + code"${expression(body)} if ${expression(test)} else ${expression(orelse)}" + } + + def arguments(arguments: Arguments): Python = { + // TODO: add support for defaults + val positional = arguments.args match { + case Nil => None + case some => Some(commas(some)) + } + val args = arguments.vararg map { case ir.Name(name) => code"*$name" } + val kwargs = arguments.kwargs map { case ir.Name(name) => code"**$name" } + val argumentLists = Seq(positional, args, kwargs).filter(_.nonEmpty).map(_.get) + argumentLists.mkCode(", ") + } + + private def slice(s: Slice): Python = s match { + case Slice(None, None, None) => code":" + case Slice(Some(lower), None, None) => code"${expression(lower)}:" + case Slice(None, Some(upper), None) => code":${expression(upper)}" + case Slice(Some(lower), Some(upper), None) => code"${expression(lower)}:${expression(upper)}" + case Slice(None, None, Some(step)) => code"::${expression(step)}" + case Slice(Some(lower), None, Some(step)) => code"${expression(lower)}::${expression(step)}" + case Slice(None, Some(upper), Some(step)) => code":${expression(upper)}:${expression(step)}" + case Slice(Some(lower), Some(upper), Some(step)) => + code"${expression(lower)}:${expression(upper)}:${expression(step)}" + } + + private def dict(d: Dict): Python = { + d.keys.zip(d.values).map { case (k, v) => + code"${expression(k)}: ${expression(v)}" + } mkCode ("{", ", ", "}") + } + + private def call(c: Call): Python = { + val args = c.args.map(expression(_)) + val kwargs = c.keywords.map { case Keyword(k, v) => code"${expression(k)}=${expression(v)}" } + code"${expression(c.func)}(${(args ++ kwargs).mkCode(", ")})" + } + + private def arithmetic(expr: ir.Expression): Python = expr match { + case ir.UMinus(child) => code"-${expression(child)}" + case ir.UPlus(child) => code"+${expression(child)}" + case ir.Multiply(left, right) => code"${expression(left)} * ${expression(right)}" + case ir.Divide(left, right) => code"${expression(left)} / ${expression(right)}" + case ir.Mod(left, right) => code"${expression(left)} % ${expression(right)}" + case ir.Add(left, right) => code"${expression(left)} + ${expression(right)}" + case ir.Subtract(left, right) => code"${expression(left)} - ${expression(right)}" + } + + // see com.databricks.labs.remorph.generators.py.rules.AndOrToBitwise + private def predicate(expr: ir.Expression): Python = expr match { + case ir.BitwiseOr(left, right) => code"${expression(left)} | ${expression(right)}" + case ir.BitwiseAnd(left, right) => code"${expression(left)} & ${expression(right)}" + case ir.And(left, right) => code"${expression(left)} and ${expression(right)}" + case ir.Or(left, right) => code"${expression(left)} or ${expression(right)}" + case ir.Not(child) => code"~(${expression(child)})" + case ir.Equals(left, right) => code"${expression(left)} == ${expression(right)}" + case ir.NotEquals(left, right) => code"${expression(left)} != ${expression(right)}" + case ir.LessThan(left, right) => code"${expression(left)} < ${expression(right)}" + case ir.LessThanOrEqual(left, right) => code"${expression(left)} <= ${expression(right)}" + case ir.GreaterThan(left, right) => code"${expression(left)} > ${expression(right)}" + case ir.GreaterThanOrEqual(left, right) => code"${expression(left)} >= ${expression(right)}" + case _ => partialResult(expr) + } + + private def literal(l: ir.Literal): Python = l match 
{ + case ir.Literal(_, ir.NullType) => code"None" + case ir.Literal(bytes: Array[Byte], ir.BinaryType) => ok(bytes.map("%02X" format _).mkString) + case ir.Literal(true, ir.BooleanType) => code"True" + case ir.Literal(false, ir.BooleanType) => code"False" + case ir.Literal(value, ir.ShortType) => ok(value.toString) + case ir.IntLiteral(value) => ok(value.toString) + case ir.Literal(value, ir.LongType) => ok(value.toString) + case ir.FloatLiteral(value) => ok(value.toString) + case ir.DoubleLiteral(value) => ok(value.toString) + case ir.DecimalLiteral(value) => ok(value.toString) + case ir.Literal(value: String, ir.StringType) => singleQuote(value) + case _ => partialResult(l, ir.UnsupportedDataType(l.dataType.toString)) + } + + private def singleQuote(s: String): Python = code"'$s'" +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/LogicalPlanGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/LogicalPlanGenerator.scala new file mode 100644 index 0000000000..bcf0ebe73e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/LogicalPlanGenerator.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.generators.CodeInterpolator +import com.databricks.labs.remorph.{intermediate => ir} + +class LogicalPlanGenerator extends BasePythonGenerator[ir.LogicalPlan] { + // TODO: see if com.databricks.labs.remorph.generators.GeneratorContext.logical is still needed + override def generate(tree: ir.LogicalPlan): Python = code"..." +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/RuffFormatter.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/RuffFormatter.scala new file mode 100644 index 0000000000..b53baa8f39 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/RuffFormatter.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.utils.StandardInputPythonSubprocess +import com.databricks.labs.remorph.Result + +class RuffFormatter { + private[this] val ruffFmt = new StandardInputPythonSubprocess("ruff format -") + def format(input: String): Result[String] = ruffFmt(input) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/StatementGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/StatementGenerator.scala new file mode 100644 index 0000000000..76743ebe65 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/StatementGenerator.scala @@ -0,0 +1,97 @@ +package com.databricks.labs.remorph.generators.py +import com.databricks.labs.remorph.generators.{CodeInterpolator, TBASeqOps} +import com.databricks.labs.remorph.intermediate.Expression + +class StatementGenerator(private[this] val exprs: ExpressionGenerator) extends BasePythonGenerator[Statement] { + override def generate(tree: Statement): Python = { + withGenCtx { ctx => + code"${ctx.ws}${statement(tree)}" + } + } + + private def statement(tree: Statement): Python = tree match { + case Module(children) => lines(children) + case Alias(name, None) => e(name) + case ExprStatement(expr) => e(expr) + case FunctionDef(name, args, children, decorators) => + withIndentedBlock(code"${decorate(decorators)}def ${name.name}(${exprs.arguments(args)}):", lines(children)) + case ClassDef(name, bases, children, decorators) => + withIndentedBlock(code"${decorate(decorators)}class ${name.name}${parents(bases)}:", 
lines(children)) + case Alias(name, Some(alias)) => + code"${e(name)} as ${e(alias)}" + case Import(names) => + code"import ${commas(names)}" + case ImportFrom(Some(module), names, _) => + code"from ${e(module)} import ${commas(names)}" + case Assign(targets, value) => + code"${exprs.commas(targets)} = ${e(value)}" + case Decorator(expr) => + code"@${e(expr)}" + case For(target, iter, body, orElse) => + Seq(withIndentedBlock(code"for ${e(target)} in ${e(iter)}:", lines(body)), elseB(orElse)).mkCode + case While(test, body, orElse) => + Seq(withIndentedBlock(code"while ${e(test)}:", lines(body)), elseB(orElse)).mkCode + case If(test, body, orElse) => + Seq(withIndentedBlock(code"if ${e(test)}:", lines(body)), elseB(orElse)).mkCode + case With(context, body) => + withIndentedBlock(code"with ${commas(context)}:", lines(body)) + case Raise(None, None) => + code"raise" + case Raise(Some(exc), None) => + code"raise ${e(exc)}" + case Raise(Some(exc), Some(cause)) => + code"raise ${e(exc)} from ${e(cause)}" + case Try(body, handlers, orElse, orFinally) => + Seq( + withIndentedBlock(code"try:", code"${lines(body)}"), + lines(handlers), + elseB(orElse), + elseB(orFinally, "finally")).mkCode + case Except(None, children) => + withIndentedBlock(code"except:", lines(children, finish = "")) + case Except(Some(alias), children) => + withIndentedBlock(code"except ${generate(alias)}:", lines(children, finish = "")) + case Assert(test, None) => + code"assert ${e(test)}" + case Assert(test, Some(msg)) => + code"assert ${e(test)}, ${e(msg)}" + case Return(None) => code"return" + case Return(Some(value)) => code"return ${e(value)}" + case Delete(targets) => code"del ${exprs.commas(targets)}" + case Pass => code"pass" + case Break => code"break" + case Continue => code"continue" + case _ => partialResult(tree) + } + + private def e(expr: Expression): Python = exprs.generate(expr) + + private def lines(statements: Seq[Statement], finish: String = "\n"): Python = { + val body = statements.map(generate) + val separatedItems = body.reduceLeftOption[Python] { case (agg, item) => code"$agg\n$item" } + separatedItems.map(items => code"$items$finish").getOrElse(code"") + } + + // decorators need their leading whitespace trimmed and get followed by a trailing whitespace + private def decorate(decorators: Seq[Decorator]): Python = { + withGenCtx { ctx => + lines(decorators).map { + case "" => "" + case some => s"${some.trim}\n${ctx.ws}" + } + } + } + + private def elseB(orElse: Seq[Statement], branch: String = "else"): Python = orElse match { + case Nil => code"" + case some => + withGenCtx { ctx => + withIndentedBlock(code"${ctx.ws}$branch:", lines(some)) + } + } + + private def parents(names: Seq[Expression]): Python = names match { + case Nil => code"" + case some => code"(${exprs.commas(some)})" + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/ast.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/ast.scala new file mode 100644 index 0000000000..4fe29ca5c1 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/ast.scala @@ -0,0 +1,229 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.intermediate.{Binary, DataType, Expression, Name, Plan, StringType, UnresolvedType, Attribute => IRAttribute} + +// this is a subset of https://docs.python.org/3/library/ast.html + +abstract class Statement extends Plan[Statement] { + override def output: Seq[IRAttribute] = Nil +} + +abstract class LeafStatement 
extends Statement { + override final def children: Seq[Statement] = Nil +} + +case class Comment(child: Expression, text: String) extends Expression { + override def children: Seq[Expression] = Seq(child) + override def dataType: DataType = child.dataType +} + +// see https://docs.python.org/3/library/ast.html#ast.Module +case class Module(children: Seq[Statement]) extends Statement + +case class Arguments( + args: Seq[Expression] = Seq.empty, + vararg: Option[Name] = None, + kwargs: Option[Name] = None, + defaults: Seq[Expression] = Seq.empty) { + def expression: Seq[Expression] = args ++ vararg ++ kwargs ++ defaults +} + +// keyword arguments supplied to call +case class Keyword(arg: Name, value: Expression) extends Binary(arg, value) { + override def dataType: DataType = UnresolvedType +} + +case class Decorator(expr: Expression) extends LeafStatement + +// see https://docs.python.org/3/library/ast.html#ast.FunctionDef +case class FunctionDef(name: Name, args: Arguments, children: Seq[Statement], decorators: Seq[Decorator] = Seq.empty) + extends Statement + +// see https://docs.python.org/3/library/ast.html#ast.ClassDef +case class ClassDef( + name: Name, + bases: Seq[Expression] = Seq.empty, + children: Seq[Statement] = Seq.empty, + decorators: Seq[Decorator] = Seq.empty) + extends Statement + +case class Return(value: Option[Expression] = None) extends LeafStatement +case class Delete(targets: Seq[Expression]) extends LeafStatement +case class Assign(targets: Seq[Expression], value: Expression) extends LeafStatement + +// see https://docs.python.org/3/library/ast.html#ast.For +case class For(target: Expression, iter: Expression, body: Seq[Statement], orElse: Seq[Statement] = Seq.empty) + extends Statement { + override def children: Seq[Statement] = body ++ orElse +} + +// see https://docs.python.org/3/library/ast.html#ast.While +case class While(test: Expression, body: Seq[Statement], orElse: Seq[Statement] = Seq.empty) extends Statement { + override def children: Seq[Statement] = body ++ orElse +} + +// see https://docs.python.org/3/library/ast.html#ast.If +case class If(test: Expression, body: Seq[Statement], orElse: Seq[Statement] = Seq.empty) extends Statement { + override def children: Seq[Statement] = body ++ orElse +} + +// see https://docs.python.org/3/library/ast.html#ast.With +case class With(context: Seq[Alias], body: Seq[Statement]) extends Statement { + override def children: Seq[Statement] = context ++ body +} + +// see https://docs.python.org/3/library/ast.html#ast.Raise +case class Raise(exc: Option[Expression] = None, cause: Option[Expression] = None) extends LeafStatement + +// see https://docs.python.org/3/library/ast.html#ast.Try +case class Try( + body: Seq[Statement], + handlers: Seq[Except] = Seq.empty, + orElse: Seq[Statement] = Seq.empty, + orFinally: Seq[Statement] = Seq.empty) + extends Statement { + override def children: Seq[Statement] = body ++ handlers ++ orElse ++ orFinally +} + +// see https://docs.python.org/3/library/ast.html#ast.ExceptHandler +case class Except(exception: Option[Alias] = None, children: Seq[Statement]) extends Statement + +// see https://docs.python.org/3/library/ast.html#ast.Assert +case class Assert(test: Expression, msg: Option[Expression] = None) extends LeafStatement + +case class Alias(name: Expression, alias: Option[Name] = None) extends LeafStatement + +// see https://docs.python.org/3/library/ast.html#ast.Import +case class Import(names: Seq[Alias]) extends LeafStatement + +// see 
https://docs.python.org/3/library/ast.html#ast.ImportFrom +case class ImportFrom(module: Option[Name], names: Seq[Alias] = Seq.empty, level: Option[Int] = None) + extends LeafStatement + +// see https://docs.python.org/3/library/ast.html#ast.Global +case class Global(names: Seq[Name]) extends LeafStatement + +case object Pass extends LeafStatement +case object Break extends LeafStatement +case object Continue extends LeafStatement + +case class ExprStatement(expr: Expression) extends LeafStatement + +// see https://docs.python.org/3/library/ast.html#ast.Call +case class Call(func: Expression, args: Seq[Expression] = Seq.empty, keywords: Seq[Keyword] = Seq.empty) + extends Expression { + override def children: Seq[Expression] = Seq(func) ++ args ++ keywords + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.NamedExpr +case class NamedExpr(target: Expression, value: Expression) extends Binary(target, value) { + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.Lambda +case class Lambda(args: Arguments, body: Expression) extends Expression { + override def children: Seq[Expression] = args.expression ++ Seq(body) + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.IfExp +case class IfExp(test: Expression, body: Expression, orElse: Expression) extends Expression { + override def children: Seq[Expression] = Seq(test, body, orElse) + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.Dict +case class Dict(keys: Seq[Expression], values: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = keys ++ values + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.Set +case class Set(elts: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = elts + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.comprehension +case class Comprehension(target: Expression, iter: Expression, ifs: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = target +: iter +: ifs + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.ListComp +case class ListComp(elt: Expression, generators: Seq[Comprehension]) extends Expression { + override def children: Seq[Expression] = elt +: generators + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.SetComp +case class SetComp(elt: Expression, generators: Seq[Comprehension]) extends Expression { + override def children: Seq[Expression] = elt +: generators + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.DictComp +case class DictComp(key: Expression, value: Expression, generators: Seq[Comprehension]) extends Expression { + override def children: Seq[Expression] = key +: value +: generators + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.GeneratorExp +case class GeneratorExp(elt: Expression, generators: Seq[Comprehension]) extends Expression { + override def children: Seq[Expression] = elt +: generators + override def dataType: DataType = UnresolvedType +} + +// see https://docs.python.org/3/library/ast.html#ast.FormattedValue 
+case class FormattedValue(value: Expression, conversion: Int, formatSpec: Option[Expression] = None)
+  extends Expression {
+  override def children: Seq[Expression] = value +: formatSpec.toList
+  override def dataType: DataType = UnresolvedType
+}
+
+// see https://docs.python.org/3/library/ast.html#ast.JoinedStr
+case class JoinedStr(children: Seq[Expression]) extends Expression {
+  override def dataType: DataType = StringType
+}
+
+// see https://docs.python.org/3/library/ast.html#ast.Attribute
+case class Attribute(value: Expression, attr: Expression, ctx: ExprContext = Load) extends Expression {
+  def this(value: Expression, name: String) = this(value, Name(name), Load)
+  override def children: Seq[Expression] = Seq(value, attr)
+  override def dataType: DataType = UnresolvedType
+}
+
+// see https://docs.python.org/3/library/ast.html#subscripting
+case class Subscript(value: Expression, slice: Expression, ctx: ExprContext = Load) extends Expression {
+  override def children: Seq[Expression] = Seq(value, slice)
+  override def dataType: DataType = UnresolvedType
+}
+
+// see https://docs.python.org/3/library/ast.html#subscripting
+case class Slice(lower: Option[Expression] = None, upper: Option[Expression] = None, step: Option[Expression] = None)
+  extends Expression {
+  override def children: Seq[Expression] = Nil ++ lower ++ upper ++ step
+  override def dataType: DataType = UnresolvedType
+}
+
+// see https://docs.python.org/3/library/ast.html#ast.Starred
+case class Starred(value: Expression, ctx: ExprContext = Store) extends Expression {
+  override def children: Seq[Expression] = Seq(value)
+  override def dataType: DataType = UnresolvedType
+}
+
+// see https://docs.python.org/3/library/ast.html#ast.List
+case class List(children: Seq[Expression], ctx: ExprContext = Load) extends Expression {
+  override def dataType: DataType = UnresolvedType
+}
+
+// see https://docs.python.org/3/library/ast.html#ast.Tuple
+case class Tuple(children: Seq[Expression], ctx: ExprContext = Load) extends Expression {
+  override def dataType: DataType = UnresolvedType
+}
+
+sealed trait ExprContext
+case object Load extends ExprContext
+case object Store extends ExprContext
+// `Del` mirrors Python's ast.Del and avoids clashing with the `Delete` statement defined above
+case object Del extends ExprContext
diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/py.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/py.scala
new file mode 100644
index 0000000000..483da8520b
--- /dev/null
+++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/py.scala
@@ -0,0 +1,8 @@
+package com.databricks.labs.remorph.generators
+
+import com.databricks.labs.remorph.Transformation
+
+package object py {
+
+  type Python = Transformation[String]
+}
diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/AndOrToBitwise.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/AndOrToBitwise.scala
new file mode 100644
index 0000000000..b48b3e6bfa
--- /dev/null
+++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/AndOrToBitwise.scala
@@ -0,0 +1,12 @@
+package com.databricks.labs.remorph.generators.py.rules
+
+import com.databricks.labs.remorph.{intermediate => ir}
+
+// Converts `F.col('a') and F.col('b')` to `F.col('a') & F.col('b')`
+class AndOrToBitwise extends ir.Rule[ir.Expression] {
+  override def apply(plan: ir.Expression): ir.Expression = plan match {
+    case ir.And(left, right) => ir.BitwiseAnd(left, right)
+    case ir.Or(left, right) => ir.BitwiseOr(left, right)
+    case _ => plan
+  }
+}
diff --git 
a/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/DotToFCol.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/DotToFCol.scala new file mode 100644 index 0000000000..34e0e99357 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/DotToFCol.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.generators.py.rules + +import com.databricks.labs.remorph.intermediate._ + +class DotToFCol extends Rule[Expression] with PyCommon { + override def apply(plan: Expression): Expression = plan transformUp { case Dot(Id(left, _), Id(right, _)) => + F("col", StringLiteral(s"$left.$right") :: Nil) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/ImportClasses.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/ImportClasses.scala new file mode 100644 index 0000000000..6dc41a6a85 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/ImportClasses.scala @@ -0,0 +1,35 @@ +package com.databricks.labs.remorph.generators.py.rules + +import com.databricks.labs.remorph.{intermediate => ir} +import com.databricks.labs.remorph.generators.py + +case class ImportAliasSideEffect(expr: ir.Expression, module: String, alias: Option[String] = None) + extends ir.Expression { + override def children: Seq[ir.Expression] = Seq(expr) + override def dataType: ir.DataType = ir.UnresolvedType +} + +case class ImportClassSideEffect(expr: ir.Expression, module: String, klass: String) extends ir.Expression { + override def children: Seq[ir.Expression] = Seq(expr) + override def dataType: ir.DataType = ir.UnresolvedType +} + +// to be called after PySparkExpressions +class ImportClasses extends ir.Rule[py.Statement] { + override def apply(plan: py.Statement): py.Statement = plan match { + case py.Module(children) => + var imports = Seq.empty[py.Import] + var importsFrom = Seq.empty[py.ImportFrom] + val body = children map { statement => + statement transformAllExpressions { + case ImportAliasSideEffect(expr, module, alias) => + imports = imports :+ py.Import(Seq(py.Alias(ir.Name(module), alias.map(ir.Name)))) + expr + case ImportClassSideEffect(expr, module, klass) => + importsFrom = importsFrom :+ py.ImportFrom(Some(ir.Name(module)), Seq(py.Alias(ir.Name(klass)))) + expr + } + } + py.Module(imports.distinct ++ importsFrom.distinct ++ body) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PyCommon.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PyCommon.scala new file mode 100644 index 0000000000..66f0cfc4f9 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PyCommon.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.generators.py.rules + +import com.databricks.labs.remorph.{intermediate => ir} +import com.databricks.labs.remorph.generators.py + +trait PyCommon { + protected def methodOf(value: ir.Expression, name: String, args: Seq[ir.Expression]): ir.Expression = { + py.Call(py.Attribute(value, ir.Name(name)), args) + } + + protected def F(name: String, args: Seq[ir.Expression]): ir.Expression = { + ImportAliasSideEffect(methodOf(ir.Name("F"), name, args), "pyspark.sql.functions", alias = Some("F")) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PySparkExpressions.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PySparkExpressions.scala new file mode 100644 index 
0000000000..c6469f156a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PySparkExpressions.scala @@ -0,0 +1,124 @@ +package com.databricks.labs.remorph.generators.py.rules + +import com.databricks.labs.remorph.{intermediate => ir} +import com.databricks.labs.remorph.generators.py + +import java.time.format.DateTimeFormatter +import java.time.{Instant, LocalDate, LocalDateTime, ZoneId, ZonedDateTime} +import java.util.Locale + +// F.expr(...) +case class RawExpr(expr: ir.Expression) extends ir.LeafExpression { + override def dataType: ir.DataType = ir.UnresolvedType +} + +class PySparkExpressions extends ir.Rule[ir.Expression] with PyCommon { + private[this] val dateFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd") + private[this] val timeFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.of("UTC")) + + override def apply(expr: ir.Expression): ir.Expression = expr transformUp { + case _: ir.Bitwise => bitwise(expr) + case ir.Like(col, pattern, escape) => methodOf(col, "like", Seq(pattern) ++ escape) + case ir.RLike(col, pattern) => methodOf(col, "rlike", Seq(pattern)) + case ir.Between(exp, lower, upper) => methodOf(exp, "between", Seq(lower, upper)) + case ir.Literal(epochDay: Long, ir.DateType) => dateLiteral(epochDay) + case ir.Literal(epochSecond: Long, ir.TimestampType) => timestampLiteral(epochSecond) + case ir.ArrayExpr(children, _) => F("array", children) + case ir.IsNull(col) => methodOf(col, "isNull", Seq()) + case ir.IsNotNull(col) => methodOf(col, "isNotNull", Seq()) + case ir.UnresolvedAttribute(name, _, _, _, _, _, _) => F("col", ir.StringLiteral(name) :: Nil) + case ir.Id(name, _) => F("col", ir.StringLiteral(name) :: Nil) + case ir.Alias(child, ir.Id(name, _)) => methodOf(child, "alias", Seq(ir.StringLiteral(name))) + case o: ir.SortOrder => sortOrder(o) + case _: ir.Star => F("col", ir.StringLiteral("*") :: Nil) + case i: ir.KnownInterval => RawExpr(i) + case ir.Case(None, branches, otherwise) => caseWhenBranches(branches, otherwise) + case w: ir.Window => window(w) + // case ir.Exists(subquery) => F("exists", Seq(py.Lambda(py.Arguments(args = Seq(ir.Name("col"))), subquery))) + // case l: ir.LambdaFunction => py.Lambda(py.Arguments(args = Seq(ir.Name("col"))), apply(l.body)) + case ir.ArrayAccess(array, index) => py.Subscript(apply(array), apply(index)) + case ir.Variable(name) => ir.Name(name) + case ir.Extract(field, child) => F("extract", Seq(apply(field), apply(child))) + case ir.Concat(children) => F("concat", children) + case ir.ConcatWs(children) => F("concat_ws", children) + case ir.In(value, list) => methodOf(value, "isin", list) + case fn: ir.Fn => F(fn.prettyName.toLowerCase(Locale.getDefault), fn.children.map(apply)) + } + + private def sortOrder(order: ir.SortOrder): ir.Expression = order match { + case ir.SortOrder(col, ir.Ascending, ir.NullsFirst) => methodOf(apply(col), "asc_nulls_first", Seq()) + case ir.SortOrder(col, ir.Ascending, ir.NullsLast) => methodOf(apply(col), "asc_nulls_last", Seq()) + case ir.SortOrder(col, ir.Ascending, _) => methodOf(apply(col), "asc", Seq()) + case ir.SortOrder(col, ir.Descending, ir.NullsFirst) => methodOf(apply(col), "desc_nulls_first", Seq()) + case ir.SortOrder(col, ir.Descending, ir.NullsLast) => methodOf(apply(col), "desc_nulls_last", Seq()) + case ir.SortOrder(col, ir.Descending, _) => methodOf(apply(col), "desc", Seq()) + case ir.SortOrder(col, _, _) => apply(col) + } + + private def window(w: ir.Window): ir.Expression = { + var windowSpec: 
ir.Expression = ir.Name("Window") + windowSpec = w.partition_spec match { + case Nil => windowSpec + case _ => methodOf(windowSpec, "partitionBy", w.partition_spec.map(apply)) + } + windowSpec = w.sort_order match { + case Nil => windowSpec + case _ => methodOf(windowSpec, "orderBy", w.sort_order.map(apply)) + } + windowSpec = w.frame_spec match { + case None => windowSpec + case Some(value) => windowFrame(windowSpec, value) + } + windowSpec = ImportClassSideEffect(windowSpec, module = "pyspark.sql.window", klass = "Window") + val fn = apply(w.window_function) + methodOf(fn, "over", Seq(windowSpec)) + } + + private def windowFrame(windowSpec: ir.Expression, frame: ir.WindowFrame): ir.Expression = frame match { + case ir.WindowFrame(ir.RangeFrame, left, right) => + methodOf(windowSpec, "rangeBetween", Seq(frameBoundary(left), frameBoundary(right))) + case ir.WindowFrame(ir.RowsFrame, left, right) => + methodOf(windowSpec, "rowsBetween", Seq(frameBoundary(left), frameBoundary(right))) + case _ => windowSpec + } + + private def frameBoundary(boundary: ir.FrameBoundary): ir.Expression = boundary match { + case ir.CurrentRow => py.Attribute(ir.Name("Window"), ir.Name("currentRow")) + case ir.UnboundedPreceding => py.Attribute(ir.Name("Window"), ir.Name("unboundedPreceding")) + case ir.UnboundedFollowing => py.Attribute(ir.Name("Window"), ir.Name("unboundedFollowing")) + case ir.PrecedingN(n) => n + case ir.FollowingN(n) => n + case ir.NoBoundary => ir.Name("noBoundary") + } + + private def caseWhenBranches(branches: Seq[ir.WhenBranch], otherwise: Option[ir.Expression]) = { + val when = F("when", Seq(apply(branches.head.condition), apply(branches.head.expression))) + val body = branches.foldLeft(when) { case (acc, branch) => + methodOf(acc, "when", Seq(apply(branch.condition), apply(branch.expression))) + } + otherwise match { + case Some(value) => methodOf(body, "otherwise", Seq(apply(value))) + case None => body + } + } + + private def dateLiteral(epochDay: Long): ir.Expression = { + val raw = ir.StringLiteral(LocalDate.ofEpochDay(epochDay).format(dateFormat)) + methodOf(F("lit", Seq(raw)), "cast", Seq(ir.StringLiteral("date"))) + } + + private def timestampLiteral(epochSecond: Long): ir.Expression = { + val raw = ir.StringLiteral( + LocalDateTime + .from(ZonedDateTime.ofInstant(Instant.ofEpochSecond(epochSecond), ZoneId.of("UTC"))) + .format(timeFormat)) + methodOf(F("lit", Seq(raw)), "cast", Seq(ir.StringLiteral("timestamp"))) + } + + private def bitwise(expr: ir.Expression): ir.Expression = expr match { + case ir.BitwiseOr(left, right) => methodOf(left, "bitwiseOR", Seq(right)) + case ir.BitwiseAnd(left, right) => methodOf(left, "bitwiseAND", Seq(right)) + case ir.BitwiseXor(left, right) => methodOf(left, "bitwiseXOR", Seq(right)) + case ir.BitwiseNot(child) => ir.Not(child) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PySparkStatements.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PySparkStatements.scala new file mode 100644 index 0000000000..cd53f8265d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/rules/PySparkStatements.scala @@ -0,0 +1,40 @@ +package com.databricks.labs.remorph.generators.py.rules + +import com.databricks.labs.remorph.{intermediate => ir} +import com.databricks.labs.remorph.generators.py + +object PySparkStatements { + def apply(plan: ir.LogicalPlan): py.Module = plan match { + case ir.Batch(children) => + val statements = children.map(Action) + 
py.Module(statements) + } +} + +case class Action(plan: ir.LogicalPlan) extends py.LeafStatement + +class PySparkStatements(val expr: ir.Rule[ir.Expression]) extends ir.Rule[py.Statement] with PyCommon { + override def apply(in: py.Statement): py.Statement = in match { + case py.Module(statements) => py.Module(statements.map(apply)) + case Action(logical) => py.ExprStatement(plan(pythonize(logical))) + } + + private def pythonize(logical: ir.LogicalPlan): ir.LogicalPlan = { + logical transformExpressionsDown { case e: ir.Expression => + expr(e) + } + } + + private def plan(logical: ir.LogicalPlan): ir.Expression = logical match { + case ir.PlanComment(input, text) => py.Comment(plan(input), text) + case ir.NamedTable(name, _, _) => methodOf(ir.Name("spark"), "table", Seq(ir.StringLiteral(name))) + case ir.NoTable => methodOf(ir.Name("spark"), "emptyDataFrame", Seq()) + case ir.Filter(input, condition) => methodOf(plan(input), "filter", Seq(condition)) + case ir.Project(input, projectList) => methodOf(plan(input), "select", projectList) + case ir.Limit(input, limit) => methodOf(plan(input), "limit", Seq(limit)) + case ir.Offset(input, offset) => methodOf(plan(input), "offset", Seq(offset)) + case ir.Sort(input, sortList, _) => methodOf(plan(input), "orderBy", sortList) + case ir.Aggregate(input, ir.GroupBy, exprs, _) => methodOf(plan(input), "groupBy", exprs) + case ir.Deduplicate(input, keys, _, _) => methodOf(plan(input), "dropDuplicates", keys) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/BaseSQLGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/BaseSQLGenerator.scala new file mode 100644 index 0000000000..792cbb292b --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/BaseSQLGenerator.scala @@ -0,0 +1,43 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.generators._ +import com.databricks.labs.remorph.intermediate.{RemorphError, TreeNode, UnexpectedNode} +import com.databricks.labs.remorph.{PartialResult, intermediate => ir} + +abstract class BaseSQLGenerator[In <: TreeNode[In]] extends CodeGenerator[In] { + def partialResult(tree: In): SQL = partialResult(tree, UnexpectedNode(tree.toString)) + def partialResult(trees: Seq[Any], err: RemorphError): SQL = lift( + PartialResult(s"!!! ${trees.mkString(" | ")} !!!", err)) + def partialResult(tree: Any, err: RemorphError): SQL = lift(PartialResult(s"!!! $tree !!!", err)) + + /** + * Generate an inline comment that describes the error that was detected in the unresolved relation, + * which could be parsing errors, or could be something that is not yet implemented. Implemented as + * a separate method as we may wish to do more with this in the future. 
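+   *
+   * For example, an unresolved relation produced by a parse failure is rendered as a SQL block comment
+   * that carries the error message followed by the original (unparsed) source text, so the problem stays
+   * visible inline in the generated output.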
+ */ + protected def describeError(relation: ir.Unresolved[_]): SQL = { + val ruleText = + if (relation.ruleText.trim.isEmpty) "" + else + relation.ruleText + .split("\n") + .map { + case line if line.trim.isEmpty => line.trim + case line => line.replaceAll(" +$", "") + } + .mkString(" ", "\n ", "") + + val message = + if (relation.message.trim.isEmpty) "" + else + relation.message + .split("\n") + .map { + case line if line.trim.isEmpty => line.trim + case line => line.replaceAll(" +$", "") + } + .mkString(" ", "\n ", "") + + code"/* The following issues were detected:\n\n$message\n$ruleText\n */" + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/DataTypeGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/DataTypeGenerator.scala new file mode 100644 index 0000000000..d546ebe321 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/DataTypeGenerator.scala @@ -0,0 +1,52 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.generators._ +import com.databricks.labs.remorph.{OkResult, PartialResult, TransformationConstructors, intermediate => ir} + +/** + * @see + * https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html + */ +object DataTypeGenerator extends TransformationConstructors { + + def generateDataType(dt: ir.DataType): SQL = dt match { + case ir.NullType => lift(OkResult("VOID")) + case ir.BooleanType => lift(OkResult("BOOLEAN")) + case ir.BinaryType => lift(OkResult("BINARY")) + case ir.ShortType => lift(OkResult("SMALLINT")) + case ir.TinyintType => lift(OkResult("TINYINT")) + case ir.IntegerType => lift(OkResult("INT")) + case ir.LongType => lift(OkResult("BIGINT")) + case ir.FloatType => lift(OkResult("FLOAT")) + case ir.DoubleType => lift(OkResult("DOUBLE")) + case ir.DecimalType(precision, scale) => + val arguments = precision.toSeq ++ scale.toSeq + if (arguments.isEmpty) { + lift(OkResult("DECIMAL")) + } else { + code"DECIMAL${arguments.mkString("(", ", ", ")")}" + } + case ir.StringType => lift(OkResult("STRING")) + case ir.DateType => lift(OkResult("DATE")) + case ir.TimestampType => lift(OkResult("TIMESTAMP")) + case ir.TimestampNTZType => lift(OkResult("TIMESTAMP_NTZ")) + case ir.ArrayType(elementType) => code"ARRAY<${generateDataType(elementType)}>" + case ir.StructType(fields) => + val fieldTypes = fields + .map { case ir.StructField(name, dataType, nullable, _) => + val isNullable = if (nullable) "" else " NOT NULL" + code"$name:${generateDataType(dataType)}$isNullable" + } + .mkCode(",") + code"STRUCT<$fieldTypes>" + case ir.MapType(keyType, valueType) => + code"MAP<${generateDataType(keyType)}, ${generateDataType(valueType)}>" + case ir.VarcharType(size) => code"VARCHAR${maybeSize(size)}" + case ir.CharType(size) => code"CHAR${maybeSize(size)}" + case ir.VariantType => code"VARIANT" + case ir.JinjaAsDataType(text) => code"$text" + case _ => lift(PartialResult(s"!!! 
$dt !!!", ir.UnsupportedDataType(dt.toString))) + } + + private def maybeSize(size: Option[Int]): String = size.map(s => s"($s)").getOrElse("") +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/ExpressionGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/ExpressionGenerator.scala new file mode 100644 index 0000000000..22fffa47e2 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/ExpressionGenerator.scala @@ -0,0 +1,502 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.generators._ +import com.databricks.labs.remorph.{Generating, OkResult, TransformationConstructors, intermediate => ir} + +import java.time._ +import java.time.format.DateTimeFormatter +import java.util.Locale + +class ExpressionGenerator extends BaseSQLGenerator[ir.Expression] with TransformationConstructors { + private[this] val dateFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd") + private[this] val timeFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.of("UTC")) + + override def generate(tree: ir.Expression): SQL = + expression(tree) + + def expression(expr: ir.Expression): SQL = { + val sql: SQL = expr match { + case ir.Like(subject, pattern, escape) => likeSingle(subject, pattern, escape, caseSensitive = true) + case ir.LikeAny(subject, patterns) => likeMultiple(subject, patterns, caseSensitive = true, all = false) + case ir.LikeAll(subject, patterns) => likeMultiple(subject, patterns, caseSensitive = true, all = true) + case ir.ILike(subject, pattern, escape) => likeSingle(subject, pattern, escape, caseSensitive = false) + case ir.ILikeAny(subject, patterns) => likeMultiple(subject, patterns, caseSensitive = false, all = false) + case ir.ILikeAll(subject, patterns) => likeMultiple(subject, patterns, caseSensitive = false, all = true) + case r: ir.RLike => rlike(r) + case _: ir.Bitwise => bitwise(expr) + case _: ir.Arithmetic => arithmetic(expr) + case b: ir.Between => between(b) + case _: ir.Predicate => predicate(expr) + case l: ir.Literal => literal(l) + case a: ir.ArrayExpr => arrayExpr(a) + case m: ir.MapExpr => mapExpr(m) + case s: ir.StructExpr => structExpr(s) + case i: ir.IsNull => isNull(i) + case i: ir.IsNotNull => isNotNull(i) + case ir.UnresolvedAttribute(name, _, _, _, _, _, _) => lift(OkResult(name)) + case d: ir.Dot => dot(d) + case i: ir.Id => nameOrPosition(i) + case o: ir.ObjectReference => objectReference(o) + case a: ir.Alias => alias(a) + case d: ir.Distinct => distinct(d) + case s: ir.Star => star(s) + case c: ir.Cast => cast(c) + case t: ir.TryCast => tryCast(t) + case col: ir.Column => column(col) + case _: ir.DeleteAction => code"DELETE" + case ia: ir.InsertAction => insertAction(ia) + case ua: ir.UpdateAction => updateAction(ua) + case a: ir.Assign => assign(a) + case opts: ir.Options => options(opts) + case i: ir.KnownInterval => interval(i) + case s: ir.ScalarSubquery => scalarSubquery(s) + case c: ir.Case => caseWhen(c) + case w: ir.Window => window(w) + case o: ir.SortOrder => sortOrder(o) + case ir.Exists(subquery) => + withGenCtx(ctx => code"EXISTS (${ctx.logical.generate(subquery)})") + case a: ir.ArrayAccess => arrayAccess(a) + case j: ir.JsonAccess => jsonAccess(j) + case l: ir.LambdaFunction => lambdaFunction(l) + case v: ir.Variable => variable(v) + case s: ir.SchemaReference => schemaReference(s) + case r: ir.RegExpExtract => regexpExtract(r) + case t: ir.TimestampDiff => timestampDiff(t) + case t: ir.TimestampAdd => 
timestampAdd(t) + case e: ir.Extract => extract(e) + case c: ir.Concat => concat(c) + case i: ir.In => in(i) + case ir.JinjaAsExpression(text) => code"$text" + + // keep this case after every case involving an `Fn`, otherwise it will make said case unreachable + case fn: ir.Fn => code"${fn.prettyName}(${commas(fn.children)})" + + // We see an unresolved for parsing errors, when we have no visitor for a given rule, + // when something went wrong with IR generation, or when we have a visitor but it is not + // yet implemented. + case u: ir.Unresolved[_] => describeError(u) + + case null => code"" // don't fail transpilation if the expression is null + case x => partialResult(x) + } + + updatePhase { case g: Generating => + g.copy(currentNode = expr) + }.flatMap(_ => sql) + } + + private def structExpr(s: ir.StructExpr): SQL = { + s.fields + .map { + case a: ir.Alias => generate(a) + case s: ir.Star => code"*" + } + .mkCode("STRUCT(", ", ", ")") + } + + private def jsonAccess(j: ir.JsonAccess): SQL = { + val path = jsonPath(j.path).mkCode + val anchorPath = path.map(p => if (p.head == '.') ':' +: p.tail else p) + code"${expression(j.json)}$anchorPath" + } + + private def jsonPath(j: ir.Expression): Seq[SQL] = { + j match { + case ir.Id(name, _) if isValidIdentifier(name) => Seq(code".$name") + case ir.Id(name, _) => Seq(s"['$name']".replace("'", "\"")).map(OkResult(_)).map(lift) + case ir.IntLiteral(value) => Seq(code"[$value]") + case ir.StringLiteral(value) => Seq(code"['$value']") + case ir.Dot(left, right) => jsonPath(left) ++ jsonPath(right) + case i: ir.Expression => Seq(partialResult(i)) + } + } + + private def isNull(i: ir.IsNull) = code"${expression(i.left)} IS NULL" + private def isNotNull(i: ir.IsNotNull) = code"${expression(i.left)} IS NOT NULL" + + private def interval(interval: ir.KnownInterval): SQL = { + val iType = interval.iType match { + case ir.YEAR_INTERVAL => "YEAR" + case ir.MONTH_INTERVAL => "MONTH" + case ir.WEEK_INTERVAL => "WEEK" + case ir.DAY_INTERVAL => "DAY" + case ir.HOUR_INTERVAL => "HOUR" + case ir.MINUTE_INTERVAL => "MINUTE" + case ir.SECOND_INTERVAL => "SECOND" + case ir.MILLISECOND_INTERVAL => "MILLISECOND" + case ir.MICROSECOND_INTERVAL => "MICROSECOND" + case ir.NANOSECOND_INTERVAL => "NANOSECOND" + } + code"INTERVAL ${generate(interval.value)} ${iType}" + } + + private def options(opts: ir.Options): SQL = { + // First gather the options that are set by expressions + val exprOptions = opts.expressionOpts + .map { case (key, expression) => + code" ${key} = ${generate(expression)}\n" + } + .toSeq + .mkCode + val exprStr = exprOptions.nonEmpty.flatMap { nonEmpty => + if (nonEmpty) { + code" Expression options:\n\n${exprOptions}\n" + } else { + code"" + } + } + val stringOptions = opts.stringOpts.map { case (key, value) => + s" ${key} = '${value}'\n" + }.mkString + val stringStr = if (stringOptions.nonEmpty) { + s" String options:\n\n${stringOptions}\n" + } else { + "" + } + + val boolOptions = opts.boolFlags.map { case (key, value) => + s" ${key} ${if (value) { "ON" } + else { "OFF" }}\n" + }.mkString + val boolStr = if (boolOptions.nonEmpty) { + s" Boolean options:\n\n${boolOptions}\n" + } else { + "" + } + + val autoOptions = opts.autoFlags.map { key => + s" ${key} AUTO\n" + }.mkString + val autoStr = if (autoOptions.nonEmpty) { + s" Auto options:\n\n${autoOptions}\n" + } else { + "" + } + val optString = code"${exprStr}${stringStr}${boolStr}${autoStr}" + optString.nonEmpty.flatMap { nonEmpty => + if (nonEmpty) { + code"/*\n The following statement was 
originally given the following OPTIONS:\n\n${optString}\n */\n" + } else { + code"" + } + } + } + + private def assign(assign: ir.Assign): SQL = { + code"${expression(assign.left)} = ${expression(assign.right)}" + } + + private def column(column: ir.Column): SQL = { + val objectRef = column.tableNameOrAlias.map(or => code"${generateObjectReference(or)}.").getOrElse("") + code"$objectRef${nameOrPosition(column.columnName)}" + } + + private def insertAction(insertAction: ir.InsertAction): SQL = { + val (cols, values) = insertAction.assignments.map { assign => + (assign.left, assign.right) + }.unzip + code"INSERT (${commas(cols)}) VALUES (${commas(values)})" + } + + private def updateAction(updateAction: ir.UpdateAction): SQL = { + code"UPDATE SET ${commas(updateAction.assignments)}" + } + + private def arithmetic(expr: ir.Expression): SQL = expr match { + case ir.UMinus(child) => code"-${expression(child)}" + case ir.UPlus(child) => code"+${expression(child)}" + case ir.Multiply(left, right) => code"${expression(left)} * ${expression(right)}" + case ir.Divide(left, right) => code"${expression(left)} / ${expression(right)}" + case ir.Mod(left, right) => code"${expression(left)} % ${expression(right)}" + case ir.Add(left, right) => code"${expression(left)} + ${expression(right)}" + case ir.Subtract(left, right) => code"${expression(left)} - ${expression(right)}" + } + + private def bitwise(expr: ir.Expression): SQL = expr match { + case ir.BitwiseOr(left, right) => code"${expression(left)} | ${expression(right)}" + case ir.BitwiseAnd(left, right) => code"${expression(left)} & ${expression(right)}" + case ir.BitwiseXor(left, right) => code"${expression(left)} ^ ${expression(right)}" + case ir.BitwiseNot(child) => code"~${expression(child)}" + } + + private def likeSingle( + subject: ir.Expression, + pattern: ir.Expression, + escapeChar: Option[ir.Expression], + caseSensitive: Boolean): SQL = { + val op = if (caseSensitive) { "LIKE" } + else { "ILIKE" } + val escape = escapeChar.map(char => code" ESCAPE ${expression(char)}").getOrElse(code"") + code"${expression(subject)} $op ${expression(pattern)}$escape" + } + + private def likeMultiple( + subject: ir.Expression, + patterns: Seq[ir.Expression], + caseSensitive: Boolean, + all: Boolean): SQL = { + val op = if (caseSensitive) { "LIKE" } + else { "ILIKE" } + val allOrAny = if (all) { "ALL" } + else { "ANY" } + code"${expression(subject)} $op $allOrAny ${patterns.map(expression(_)).mkCode("(", ", ", ")")}" + } + + private def rlike(r: ir.RLike): SQL = { + code"${expression(r.left)} RLIKE ${expression(r.right)}" + } + + private def predicate(expr: ir.Expression): SQL = expr match { + case a: ir.And => andPredicate(a) + case o: ir.Or => orPredicate(o) + case ir.Not(child) => code"NOT (${expression(child)})" + case ir.Equals(left, right) => code"${expression(left)} = ${expression(right)}" + case ir.NotEquals(left, right) => code"${expression(left)} != ${expression(right)}" + case ir.LessThan(left, right) => code"${expression(left)} < ${expression(right)}" + case ir.LessThanOrEqual(left, right) => code"${expression(left)} <= ${expression(right)}" + case ir.GreaterThan(left, right) => code"${expression(left)} > ${expression(right)}" + case ir.GreaterThanOrEqual(left, right) => code"${expression(left)} >= ${expression(right)}" + case _ => partialResult(expr) + } + + private def andPredicate(a: ir.And): SQL = a match { + case ir.And(ir.Or(ol, or), right) => + code"(${expression(ol)} OR ${expression(or)}) AND ${expression(right)}" + case ir.And(left, 
ir.Or(ol, or)) => + code"${expression(left)} AND (${expression(ol)} OR ${expression(or)})" + case ir.And(left, right) => code"${expression(left)} AND ${expression(right)}" + } + + private def orPredicate(a: ir.Or): SQL = a match { + case ir.Or(ir.And(ol, or), right) => + code"(${expression(ol)} AND ${expression(or)}) OR ${expression(right)}" + case ir.Or(left, ir.And(ol, or)) => + code"${expression(left)} OR (${expression(ol)} AND ${expression(or)})" + case ir.Or(left, right) => code"${expression(left)} OR ${expression(right)}" + } + + private def literal(l: ir.Literal): SQL = + l match { + case ir.Literal(_, ir.NullType) => code"NULL" + case ir.Literal(bytes: Array[Byte], ir.BinaryType) => lift(OkResult(bytes.map("%02X" format _).mkString)) + case ir.Literal(value, ir.BooleanType) => lift(OkResult(value.toString.toLowerCase(Locale.getDefault))) + case ir.Literal(value, ir.ShortType) => lift(OkResult(value.toString)) + case ir.IntLiteral(value) => lift(OkResult(value.toString)) + case ir.Literal(value, ir.LongType) => lift(OkResult(value.toString)) + case ir.FloatLiteral(value) => lift(OkResult(value.toString)) + case ir.DoubleLiteral(value) => lift(OkResult(value.toString)) + case ir.DecimalLiteral(value) => lift(OkResult(value.toString)) + case ir.Literal(value: String, ir.StringType) => singleQuote(value) + case ir.Literal(epochDay: Long, ir.DateType) => + val dateStr = singleQuote(LocalDate.ofEpochDay(epochDay).format(dateFormat)) + code"CAST($dateStr AS DATE)" + case ir.Literal(epochSecond: Long, ir.TimestampType) => + val timestampStr = singleQuote( + LocalDateTime + .from(ZonedDateTime.ofInstant(Instant.ofEpochSecond(epochSecond), ZoneId.of("UTC"))) + .format(timeFormat)) + code"CAST($timestampStr AS TIMESTAMP)" + case _ => partialResult(l, ir.UnsupportedDataType(l.dataType.toString)) + } + + private def arrayExpr(a: ir.ArrayExpr): SQL = { + val elementsStr = commas(a.children) + code"ARRAY($elementsStr)" + } + + private def mapExpr(m: ir.MapExpr): SQL = { + val entriesStr = m.map + .map { case (key, value) => + code"${expression(key)}, ${expression(value)}" + } + .toSeq + .mkCode(", ") + code"MAP($entriesStr)" + } + + private def nameOrPosition(id: ir.NameOrPosition): SQL = id match { + case ir.Id(name, true) => code"`$name`" + case ir.Id(name, false) => ok(name) + case ir.Name(name) => ok(name) + case p @ ir.Position(_) => partialResult(p) + } + + private def alias(alias: ir.Alias): SQL = { + code"${expression(alias.expr)} AS ${expression(alias.name)}" + } + + private def distinct(distinct: ir.Distinct): SQL = { + code"DISTINCT ${expression(distinct.expression)}" + } + + private def star(star: ir.Star): SQL = { + val objectRef = star.objectName.map(or => code"${generateObjectReference(or)}.").getOrElse(code"") + code"$objectRef*" + } + + private def generateObjectReference(reference: ir.ObjectReference): SQL = { + (reference.head +: reference.tail).map(nameOrPosition).mkCode(".") + } + + private def cast(cast: ir.Cast): SQL = { + castLike("CAST", cast.expr, cast.dataType) + } + + private def tryCast(tryCast: ir.TryCast): SQL = { + castLike("TRY_CAST", tryCast.expr, tryCast.dataType) + } + + private def castLike(prettyName: String, expr: ir.Expression, dataType: ir.DataType): SQL = { + val e = expression(expr) + val dt = DataTypeGenerator.generateDataType(dataType) + code"$prettyName($e AS $dt)" + } + + private def dot(dot: ir.Dot): SQL = { + code"${expression(dot.left)}.${expression(dot.right)}" + } + + private def objectReference(objRef: ir.ObjectReference): SQL = { + 
(objRef.head +: objRef.tail).map(nameOrPosition).mkCode(".") + } + + private def caseWhen(c: ir.Case): SQL = { + val expr = c.expression.map(expression).toSeq + val branches = c.branches.map { branch => + code"WHEN ${expression(branch.condition)} THEN ${expression(branch.expression)}" + } + val otherwise = c.otherwise.map { o => code"ELSE ${expression(o)}" }.toSeq + val chunks = expr ++ branches ++ otherwise + chunks.mkCode("CASE ", " ", " END") + } + + private def in(inExpr: ir.In): SQL = { + val values = commas(inExpr.other) + val enclosed = values.map { sql => + if (sql.charAt(0) == '(' && sql.charAt(sql.length - 1) == ')') { + sql + } else { + "(" + sql + ")" + } + } + code"${expression(inExpr.left)} IN ${enclosed}" + } + + private def scalarSubquery(subquery: ir.ScalarSubquery): SQL = { + withGenCtx(ctx => { + val subcode = ctx.logical.generate(subquery.relation) + code"(${subcode})" + }) + } + + private def window(window: ir.Window): SQL = { + val expr = expression(window.window_function) + val partition = if (window.partition_spec.isEmpty) { code"" } + else { window.partition_spec.map(expression).mkCode("PARTITION BY ", ", ", "") } + val orderBy = if (window.sort_order.isEmpty) { code"" } + else { window.sort_order.map(sortOrder).mkCode(" ORDER BY ", ", ", "") } + val windowFrame = window.frame_spec + .map { frame => + val mode = frame.frame_type match { + case ir.RowsFrame => "ROWS" + case ir.RangeFrame => "RANGE" + } + val boundaries = frameBoundary(frame.lower) ++ frameBoundary(frame.upper) + val frameBoundaries = if (boundaries.size < 2) { boundaries.mkCode } + else { boundaries.mkCode("BETWEEN ", " AND ", "") } + code" $mode $frameBoundaries" + } + .getOrElse(code"") + if (window.ignore_nulls) { + return code"$expr IGNORE NULLS OVER ($partition$orderBy$windowFrame)" + } + + code"$expr OVER ($partition$orderBy$windowFrame)" + } + + private def frameBoundary(boundary: ir.FrameBoundary): Seq[SQL] = boundary match { + case ir.NoBoundary => Seq.empty + case ir.CurrentRow => Seq(code"CURRENT ROW") + case ir.UnboundedPreceding => Seq(code"UNBOUNDED PRECEDING") + case ir.UnboundedFollowing => Seq(code"UNBOUNDED FOLLOWING") + case ir.PrecedingN(n) => Seq(code"${expression(n)} PRECEDING") + case ir.FollowingN(n) => Seq(code"${expression(n)} FOLLOWING") + } + + private def sortOrder(order: ir.SortOrder): SQL = { + val orderBy = expression(order.child) + val direction = order.direction match { + case ir.Ascending => Seq(code"ASC") + case ir.Descending => Seq(code"DESC") + case ir.UnspecifiedSortDirection => Seq() + } + val nulls = order.nullOrdering match { + case ir.NullsFirst => Seq(code"NULLS FIRST") + case ir.NullsLast => Seq(code"NULLS LAST") + case ir.SortNullsUnspecified => Seq() + } + (Seq(orderBy) ++ direction ++ nulls).mkCode(" ") + } + + private def regexpExtract(extract: ir.RegExpExtract): SQL = { + val c = if (extract.c.isEmpty || extract.c.contains(ir.Literal(1))) { code"" } + else { code", ${expression(extract.c.get)}" } + code"${extract.prettyName}(${expression(extract.left)}, ${expression(extract.right)}$c)" + } + + private def arrayAccess(access: ir.ArrayAccess): SQL = { + code"${expression(access.array)}[${expression(access.index)}]" + } + + private def timestampDiff(diff: ir.TimestampDiff): SQL = { + code"${diff.prettyName}(${diff.unit}, ${expression(diff.start)}, ${expression(diff.end)})" + } + + private def timestampAdd(tsAdd: ir.TimestampAdd): SQL = { + code"${tsAdd.prettyName}(${tsAdd.unit}, ${expression(tsAdd.quantity)}, ${expression(tsAdd.timestamp)})" + } + + 
private def extract(e: ir.Extract): SQL = { + code"EXTRACT(${expression(e.left)} FROM ${expression(e.right)})" + } + + private def lambdaFunction(l: ir.LambdaFunction): SQL = { + val parameterList = l.arguments.map(lambdaArgument) + val parameters = if (parameterList.size > 1) { parameterList.mkCode("(", ", ", ")") } + else { parameterList.mkCode } + val body = expression(l.function) + code"$parameters -> $body" + } + + private def lambdaArgument(arg: ir.UnresolvedNamedLambdaVariable): SQL = { + lift(OkResult(arg.name_parts.mkString("."))) + } + + private def variable(v: ir.Variable): SQL = code"$${${v.name}}" + + private def concat(c: ir.Concat): SQL = { + val args = c.children.map(expression(_)) + if (c.children.size > 2) { + args.mkCode(" || ") + } else { + args.mkCode("CONCAT(", ", ", ")") + } + } + + private def schemaReference(s: ir.SchemaReference): SQL = { + val ref = s.columnName match { + case d: ir.Dot => expression(d) + case i: ir.Id => expression(i) + case _ => code"JSON_COLUMN" + } + code"{${ref.map(_.toUpperCase(Locale.getDefault()))}_SCHEMA}" + } + private def singleQuote(s: String): SQL = code"'${s.replace("'", "\\'")}'" + private def isValidIdentifier(s: String): Boolean = + (s.head.isLetter || s.head == '_') && s.forall(x => x.isLetterOrDigit || x == '_') + + private def between(b: ir.Between): SQL = { + code"${expression(b.exp)} BETWEEN ${expression(b.lower)} AND ${expression(b.upper)}" + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGenerator.scala new file mode 100644 index 0000000000..2fb12b5059 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGenerator.scala @@ -0,0 +1,523 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.generators._ +import com.databricks.labs.remorph.{Generating, OkResult, TransformationConstructors, intermediate => ir} + +class LogicalPlanGenerator( + val expr: ExpressionGenerator, + val optGen: OptionGenerator, + val explicitDistinct: Boolean = false) + extends BaseSQLGenerator[ir.LogicalPlan] + with TransformationConstructors { + + override def generate(tree: ir.LogicalPlan): SQL = { + + val sql: SQL = tree match { + case b: ir.Batch => batch(b) + case w: ir.WithCTE => cte(w) + case p: ir.Project => project(p) + case ir.NamedTable(id, _, _) => lift(OkResult(id)) + case ir.Filter(input, condition) => { + val source = input match { + // enclose subquery in parenthesis + case project: ir.Project => code"(${generate(project)})" + case _ => code"${generate(input)}" + } + code"${source} WHERE ${expr.generate(condition)}" + } + case ir.Limit(input, limit) => + code"${generate(input)} LIMIT ${expr.generate(limit)}" + case ir.Offset(child, offset) => + code"${generate(child)} OFFSET ${expr.generate(offset)}" + case ir.Values(data) => + code"VALUES ${data.map(_.map(expr.generate(_)).mkCode("(", ",", ")")).mkCode(", ")}" + case ir.PlanComment(child, text) => code"/* $text */\n${generate(child)}" + case agg: ir.Aggregate => aggregate(agg) + case sort: ir.Sort => orderBy(sort) + case join: ir.Join => generateJoin(join) + case setOp: ir.SetOperation => setOperation(setOp) + case mergeIntoTable: ir.MergeIntoTable => merge(mergeIntoTable) + case withOptions: ir.WithOptions => generateWithOptions(withOptions) + case s: ir.SubqueryAlias => subQueryAlias(s) + case t: ir.TableAlias => tableAlias(t) + case d: ir.Deduplicate => 
deduplicate(d) + case u: ir.UpdateTable => updateTable(u) + case i: ir.InsertIntoTable => insert(i) + case ir.DeleteFromTable(target, None, where, None, None) => delete(target, where) + case c: ir.CreateTableCommand => createTable(c) + case rt: ir.ReplaceTableCommand => replaceTable(rt) + case t: ir.TableSample => tableSample(t) + case a: ir.AlterTableCommand => alterTable(a) + case l: ir.Lateral => lateral(l) + case c: ir.CreateTableParams => createTableParams(c) + case ir.JinjaAsStatement(text) => code"$text" + + // We see an unresolved for parsing errors, when we have no visitor for a given rule, + // when something went wrong with IR generation, or when we have a visitor but it is not + // yet implemented. + case u: ir.Unresolved[_] => describeError(u) + case ir.NoopNode => code"" + + case null => code"" // don't fail transpilation if the plan is null + case x => partialResult(x) + } + + updatePhase { case g: Generating => + g.copy(currentNode = tree) + }.flatMap(_ => sql) + } + + private def batch(b: ir.Batch): SQL = { + val seqSql = b.children.map(generate(_)).sequence + seqSql.map { seq => + seq + .map { elem => + if (!elem.endsWith("*/") && !elem.startsWith("_!Jinja")) s"$elem;" + else elem + } + .mkString("\n") + } + } + + private def createTableParams(crp: ir.CreateTableParams): SQL = { + + // We build the overall table creation statement differently depending on whether the primitive is + // a CREATE TABLE or a CREATE TABLE AS (SELECT ...) + crp.create match { + case ct: ir.CreateTable => + // we build the columns using the raw column declarations, adding in any col constraints + // and any column options + val columns = ct.schema match { + case ir.StructType(fields) => + fields + .map { col => + val constraints = crp.colConstraints.getOrElse(col.name, Seq.empty) + val options = crp.colOptions.getOrElse(col.name, Seq.empty) + genColumnDecl(col, constraints, options) + } + .mkCode(", ") + } + + // We now generate any table level constraints + val tableConstraintStr = crp.constraints.map(constraint).mkCode(", ") + val tableConstraintStrWithComma = + tableConstraintStr.nonEmpty.flatMap(nonEmpty => if (nonEmpty) code", $tableConstraintStr" else code"") + + // record any table level options + val tableOptions = crp.options.map(_.map(optGen.generateOption).mkCode("\n ")).getOrElse(code"") + + val tableOptionsComment = { + tableOptions.isEmpty.flatMap { isEmpty => + if (isEmpty) code"" else code" The following options are unsupported:\n\n $tableOptions\n" + } + } + val indicesStr = crp.indices.map(constraint).mkCode(" \n") + val indicesComment = + indicesStr.isEmpty.flatMap { isEmpty => + if (isEmpty) code"" + else code" The following index directives are unsupported:\n\n $indicesStr*/\n" + } + val leadingComment = { + for { + toc <- tableOptionsComment + ic <- indicesComment + } yield (toc, ic) match { + case ("", "") => "" + case (a, "") => s"/*\n$a*/\n" + case ("", b) => s"/*\n$b*/\n" + case (a, b) => s"/*\n$a\n$b*/\n" + } + } + + code"${leadingComment}CREATE TABLE ${ct.table_name} (${columns}${tableConstraintStrWithComma})" + + case ctas: ir.CreateTableAsSelect => code"CREATE TABLE ${ctas.table_name} AS ${generate(ctas.query)}" + case rtas: ir.ReplaceTableAsSelect => + code"CREATE OR REPLACE TABLE ${rtas.table_name} AS ${generate(rtas.query)}" + } + } + + private def genColumnDecl( + col: ir.StructField, + constraints: Seq[ir.Constraint], + options: Seq[ir.GenericOption]): SQL = { + val dataType = DataTypeGenerator.generateDataType(col.dataType) + val dataTypeStr = if 
(!col.nullable) code"$dataType NOT NULL" else dataType + val constraintsStr = constraints.map(constraint(_)).mkCode(" ") + val constraintsGen = constraintsStr.nonEmpty.flatMap { nonEmpty => + if (nonEmpty) code" $constraintsStr" else code"" + } + val optionsStr = options.map(optGen.generateOption(_)).mkCode(" ") + val optionsComment = optionsStr.nonEmpty.flatMap { nonEmpty => if (nonEmpty) code" /* $optionsStr */" else code"" } + code"${col.name} ${dataTypeStr}${constraintsGen}${optionsComment}" + } + + private def alterTable(a: ir.AlterTableCommand): SQL = { + val operation = buildTableAlteration(a.alterations) + code"ALTER TABLE ${a.tableName} $operation" + } + + private def buildTableAlteration(alterations: Seq[ir.TableAlteration]): SQL = { + // docs:https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-alter-table.html#parameters + // docs:https://learn.microsoft.com/en-us/azure/databricks/sql/language-manual/sql-ref-syntax-ddl-alter-table#syntax + // ADD COLUMN can be Seq[ir.TableAlteration] + // DROP COLUMN will be ir.TableAlteration since it stored the list of columns + // DROP CONSTRAINTS BY NAME is ir.TableAlteration + // RENAME COLUMN/ RENAME CONSTRAINTS Always be ir.TableAlteration + // ALTER COLUMN IS A Seq[ir.TableAlternations] Data Type Change, Constraint Changes etc + alterations map { + case ir.AddColumn(columns) => code"ADD COLUMN ${buildAddColumn(columns)}" + case ir.DropColumns(columns) => code"DROP COLUMN ${columns.mkString(", ")}" + case ir.DropConstraintByName(constraints) => code"DROP CONSTRAINT ${constraints}" + case ir.RenameColumn(oldName, newName) => code"RENAME COLUMN ${oldName} to ${newName}" + case x => partialResult(x, ir.UnexpectedTableAlteration(x.toString)) + } mkCode ", " + } + + private def buildAddColumn(columns: Seq[ir.ColumnDeclaration]): SQL = { + columns + .map { c => + val dataType = DataTypeGenerator.generateDataType(c.dataType) + val constraints = c.constraints.map(constraint(_)).mkCode(" ") + code"${c.name} $dataType $constraints" + } + .mkCode(", ") + } + + // @see https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-lateral-view.html + private def lateral(lateral: ir.Lateral): SQL = + lateral match { + case ir.Lateral(ir.TableFunction(fn), isOuter, isView) => + val outer = if (isOuter) " OUTER" else "" + val view = if (isView) " VIEW" else "" + code"LATERAL$view$outer ${expr.generate(fn)}" + case _ => partialResult(lateral) + } + + // @see https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-sampling.html + private def tableSample(t: ir.TableSample): SQL = { + val sampling = t.samplingMethod match { + case ir.RowSamplingProbabilistic(probability) => s"$probability PERCENT" + case ir.RowSamplingFixedAmount(amount) => s"$amount ROWS" + case ir.BlockSampling(probability) => s"BUCKET $probability OUT OF 1" + } + val seed = t.seed.map(s => s" REPEATABLE ($s)").getOrElse("") + code"(${generate(t.child)}) TABLESAMPLE ($sampling)$seed" + } + + private def createTable(createTable: ir.CreateTableCommand): SQL = { + val columns = createTable.columns + .map { col => + val dataType = DataTypeGenerator.generateDataType(col.dataType) + val constraints = col.constraints.map(constraint(_)).mkCode(" ") + code"${col.name} $dataType $constraints" + } + code"CREATE TABLE ${createTable.name} (${columns.mkCode(", ")})" + } + + private def replaceTable(createTable: ir.ReplaceTableCommand): SQL = { + val columns = createTable.columns + .map { col => + val dataType = 
DataTypeGenerator.generateDataType(col.dataType) + val constraints = col.constraints.map(constraint).mkCode(" ") + code"${col.name} $dataType $constraints" + } + code"CREATE OR REPLACE TABLE ${createTable.name} (${columns.mkCode(", ")})" + } + + private def constraint(c: ir.Constraint): SQL = c match { + case unique: ir.Unique => generateUniqueConstraint(unique) + case ir.Nullability(nullable) => lift(OkResult(if (nullable) "NULL" else "NOT NULL")) + case pk: ir.PrimaryKey => generatePrimaryKey(pk) + case fk: ir.ForeignKey => generateForeignKey(fk) + case ir.NamedConstraint(name, unnamed) => code"CONSTRAINT $name ${constraint(unnamed)}" + case ir.UnresolvedConstraint(inputText) => code"/** $inputText **/" + case ir.CheckConstraint(e) => code"CHECK (${expr.generate(e)})" + case ir.DefaultValueConstraint(value) => code"DEFAULT ${expr.generate(value)}" + case identity: ir.IdentityConstraint => generateIdentityConstraint(identity) + case ir.GeneratedAlways(expression) => code"GENERATED ALWAYS AS (${expr.generate(expression)})" + } + + private def generateIdentityConstraint(c: ir.IdentityConstraint): SQL = c match { + case ir.IdentityConstraint(None, None, true, false) => code"GENERATED ALWAYS AS IDENTITY" + case ir.IdentityConstraint(None, None, false, true) => code"GENERATED BY DEFAULT AS IDENTITY" + case ir.IdentityConstraint(Some(seed), Some(step), false, true) => + code"GENERATED BY DEFAULT AS IDENTITY (START WITH $seed INCREMENT BY $step)" + case ir.IdentityConstraint(Some(seed), None, false, true) => + code"GENERATED BY DEFAULT AS IDENTITY (START WITH $seed)" + case ir.IdentityConstraint(None, Some(step), false, true) => + code"GENERATED BY DEFAULT AS IDENTITY (INCREMENT BY $step)" + // IdentityConstraint(None, None, true, true) This is an incorrect representation parser should not generate this + // for IdentityConstraint(None, None, false, false) we will generate empty + case _ => code"" + + } + + private def generateForeignKey(fk: ir.ForeignKey): SQL = { + val colNames = fk.tableCols match { + case "" => "" + case cols => s"(${cols}) " + } + val commentOptions = optGen.generateOptionList(fk.options) match { + case "" => "" + case options => s" /* Unsupported: $options */" + } + code"FOREIGN KEY ${colNames}REFERENCES ${fk.refObject}(${fk.refCols})$commentOptions" + } + + private def generatePrimaryKey(key: ir.PrimaryKey): SQL = { + val columns = key.columns.map(_.mkString("(", ", ", ")")).getOrElse("") + val commentOptions = optGen.generateOptionList(key.options) match { + case "" => "" + case options => s" /* $options */" + } + val columnsStr = if (columns.isEmpty) "" else s" $columns" + code"PRIMARY KEY${columnsStr}${commentOptions}" + } + + private def generateUniqueConstraint(unique: ir.Unique): SQL = { + val columns = unique.columns.map(_.mkString("(", ", ", ")")).getOrElse("") + val columnStr = if (columns.isEmpty) "" else s" $columns" + val commentOptions = optGen.generateOptionList(unique.options) match { + case "" => "" + case options => s" /* $options */" + } + code"UNIQUE${columnStr}${commentOptions}" + } + + private def project(proj: ir.Project): SQL = { + val fromClause = if (proj.input != ir.NoTable) { + code" FROM ${generate(proj.input)}" + } else { + code"" + } + + // Don't put commas after unresolved expressions as they are error comments only + val sqlParts = proj.expressions + .map { + case u: ir.Unresolved[_] => expr.generate(u) + case exp: ir.Expression => expr.generate(exp).map(_ + ", ") + } + .sequence + .map(_.mkString.stripSuffix(", ")) + + code"SELECT 
$sqlParts$fromClause" + } + + private def orderBy(sort: ir.Sort): SQL = { + val orderStr = sort.order + .map { case ir.SortOrder(child, direction, nulls) => + val dir = direction match { + case ir.Ascending => "" + case ir.Descending => " DESC" + } + code"${expr.generate(child)}$dir ${nulls.sql}" + } + + code"${generate(sort.child)} ORDER BY ${orderStr.mkCode(", ")}" + } + + private def isLateralView(lp: ir.LogicalPlan): Boolean = { + lp.find { + case ir.Lateral(_, _, isView) => isView + case _ => false + }.isDefined + } + + private def generateJoin(join: ir.Join): SQL = { + val left = generate(join.left) + val right = generate(join.right) + if (join.join_condition.isEmpty && join.using_columns.isEmpty && join.join_type == ir.InnerJoin) { + if (isLateralView(join.right)) { + code"$left $right" + } else { + code"$left, $right" + } + } else { + val joinType = generateJoinType(join.join_type) + val joinClause = if (joinType.isEmpty) { "JOIN" } + else { joinType + " JOIN" } + val conditionOpt = join.join_condition.map(expr.generate(_)) + val condition = join.join_condition match { + case None => code"" + case Some(_: ir.And) | Some(_: ir.Or) => code"ON (${conditionOpt.get})" + case Some(_) => code"ON ${conditionOpt.get}" + } + val usingColumns = join.using_columns.mkString(", ") + val using = if (usingColumns.isEmpty) "" else s"USING ($usingColumns)" + for { + l <- left + r <- right + cond <- condition + } yield { + Seq(l, joinClause, r, cond, using).filterNot(_.isEmpty).mkString(" ") + } + } + } + + private def generateJoinType(joinType: ir.JoinType): String = joinType match { + case ir.InnerJoin => "INNER" + case ir.FullOuterJoin => "FULL OUTER" + case ir.LeftOuterJoin => "LEFT" + case ir.LeftSemiJoin => "LEFT SEMI" + case ir.LeftAntiJoin => "LEFT ANTI" + case ir.RightOuterJoin => "RIGHT" + case ir.CrossJoin => "CROSS" + case ir.NaturalJoin(ir.UnspecifiedJoin) => "NATURAL" + case ir.NaturalJoin(jt) => s"NATURAL ${generateJoinType(jt)}" + case ir.UnspecifiedJoin => "" + } + + private def setOperation(setOp: ir.SetOperation): SQL = { + if (setOp.allow_missing_columns) { + return partialResult(setOp) + } + if (setOp.by_name) { + return partialResult(setOp) + } + val op = setOp.set_op_type match { + case ir.UnionSetOp => "UNION" + case ir.IntersectSetOp => "INTERSECT" + case ir.ExceptSetOp => "EXCEPT" + case _ => return partialResult(setOp) + } + val duplicates = if (setOp.is_all) " ALL" else if (explicitDistinct) " DISTINCT" else "" + code"(${generate(setOp.left)}) $op$duplicates (${generate(setOp.right)})" + } + + // @see https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-dml-insert-into.html + private def insert(insert: ir.InsertIntoTable): SQL = { + val target = generate(insert.target) + val columns = + insert.columns.map(cols => cols.map(expr.generate(_)).mkCode("(", ", ", ")")).getOrElse(code"") + val values = generate(insert.values) + val output = insert.outputRelation.map(generate(_)).getOrElse(code"") + val options = insert.options.map(expr.generate(_)).getOrElse(code"") + val overwrite = if (insert.overwrite) "OVERWRITE TABLE" else "INTO" + code"INSERT $overwrite $target $columns $values$output$options" + } + + // @see https://docs.databricks.com/en/sql/language-manual/delta-update.html + private def updateTable(update: ir.UpdateTable): SQL = { + val target = generate(update.target) + val set = expr.commas(update.set) + val where = update.where.map(cond => code" WHERE ${expr.generate(cond)}").getOrElse("") + code"UPDATE $target SET $set$where" + } + + // @see 
https://docs.databricks.com/en/sql/language-manual/delta-delete-from.html + private def delete(target: ir.LogicalPlan, where: Option[ir.Expression]): SQL = { + val whereStr = where.map(cond => code" WHERE ${expr.generate(cond)}").getOrElse(code"") + code"DELETE FROM ${generate(target)}$whereStr" + } + + // @see https://docs.databricks.com/en/sql/language-manual/delta-merge-into.html + private def merge(mergeIntoTable: ir.MergeIntoTable): SQL = { + val target = generate(mergeIntoTable.targetTable) + val source = generate(mergeIntoTable.sourceTable) + val condition = expr.generate(mergeIntoTable.mergeCondition) + val matchedActions = + mergeIntoTable.matchedActions.map { action => + val conditionText = action.condition.map(cond => code" AND ${expr.generate(cond)}").getOrElse(code"") + code" WHEN MATCHED${conditionText} THEN ${expr.generate(action)}" + }.mkCode + + val notMatchedActions = + mergeIntoTable.notMatchedActions.map { action => + val conditionText = action.condition.map(cond => code" AND ${expr.generate(cond)}").getOrElse(code"") + code" WHEN NOT MATCHED${conditionText} THEN ${expr.generate(action)}" + }.mkCode + val notMatchedBySourceActions = + mergeIntoTable.notMatchedBySourceActions.map { action => + val conditionText = action.condition.map(cond => code" AND ${expr.generate(cond)}").getOrElse(code"") + code" WHEN NOT MATCHED BY SOURCE${conditionText} THEN ${expr.generate(action)}" + }.mkCode + code"""MERGE INTO $target + |USING $source + |ON $condition + |$matchedActions + |$notMatchedActions + |$notMatchedBySourceActions + |""".map(_.stripMargin) + } + + private def aggregate(aggregate: ir.Aggregate): SQL = { + val child = generate(aggregate.child) + val expressions = expr.commas(aggregate.grouping_expressions) + aggregate.group_type match { + case ir.GroupByAll => code"$child GROUP BY ALL" + case ir.GroupBy => + code"$child GROUP BY $expressions" + case ir.Pivot if aggregate.pivot.isDefined => + val pivot = aggregate.pivot.get + val col = expr.generate(pivot.col) + val values = pivot.values.map(expr.generate).mkCode(" IN(", ", ", ")") + code"$child PIVOT($expressions FOR $col$values)" + case a => partialResult(a, ir.UnsupportedGroupType(a.toString)) + } + } + private def generateWithOptions(withOptions: ir.WithOptions): SQL = { + val optionComments = expr.generate(withOptions.options) + val plan = generate(withOptions.input) + code"${optionComments}${plan}" + } + + private def cte(withCte: ir.WithCTE): SQL = { + val ctes = withCte.ctes + .map { + case ir.SubqueryAlias(child, alias, cols) => + val columns = cols.map(expr.generate(_)).mkCode("(", ", ", ")") + val columnsStr = if (cols.isEmpty) code"" else code" $columns" + val id = expr.generate(alias) + val sub = generate(child) + code"$id$columnsStr AS ($sub)" + case x => generate(x) + } + val query = generate(withCte.query) + code"WITH ${ctes.mkCode(", ")} $query" + } + + private def subQueryAlias(subQAlias: ir.SubqueryAlias): SQL = { + val subquery = subQAlias.child match { + case l: ir.Lateral => lateral(l) + case _ => code"(${generate(subQAlias.child)})" + } + val tableName = expr.generate(subQAlias.alias) + val table = + if (subQAlias.columnNames.isEmpty) { + code"AS $tableName" + } else { + // Added this to handle the case for POS Explode + // We have added two columns index and value as an alias for pos explode default columns (pos, col) + // these column will be added to the databricks query + subQAlias.columnNames match { + case Seq(ir.Id("value", _), ir.Id("index", _)) => + val columnNamesStr = + 
subQAlias.columnNames.sortBy(_.nodeName).reverse.map(expr.generate(_)) + code"$tableName AS ${columnNamesStr.mkCode(", ")}" + case _ => + val columnNamesStr = subQAlias.columnNames.map(expr.generate(_)) + code"AS $tableName${columnNamesStr.mkCode("(", ", ", ")")}" + } + } + code"$subquery $table" + } + + private def tableAlias(alias: ir.TableAlias): SQL = { + val target = generate(alias.child) + val columns = if (alias.columns.isEmpty) { code"" } + else { + expr.commas(alias.columns).map("(" + _ + ")") + } + code"$target AS ${alias.alias}$columns" + } + + private def deduplicate(dedup: ir.Deduplicate): SQL = { + val table = generate(dedup.child) + val columns = if (dedup.all_columns_as_keys) { code"*" } + else { + expr.commas(dedup.column_names) + } + code"SELECT DISTINCT $columns FROM $table" + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/OptionGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/OptionGenerator.scala new file mode 100644 index 0000000000..a1ac81585e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/OptionGenerator.scala @@ -0,0 +1,29 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.generators._ +import com.databricks.labs.remorph.{intermediate => ir} + +class OptionGenerator(expr: ExpressionGenerator) { + + def generateOption(option: ir.GenericOption): SQL = + option match { + case ir.OptionExpression(id, value, supplement) => + code"$id = ${expr.generate(value)} ${supplement.map(s => s" $s").getOrElse("")}" + case ir.OptionString(id, value) => + code"$id = '$value'" + case ir.OptionOn(id) => + code"$id = ON" + case ir.OptionOff(id) => + code"$id = OFF" + case ir.OptionAuto(id) => + code"$id = AUTO" + case ir.OptionDefault(id) => + code"$id = DEFAULT" + case ir.OptionUnresolved(text) => + code"$text" + } + + def generateOptionList(options: Seq[ir.GenericOption]): String = + options.map(generateOption(_)).mkString(", ") + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/package.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/package.scala new file mode 100644 index 0000000000..a45b124c82 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/package.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.generators + +import com.databricks.labs.remorph.Transformation + +package object sql { + + type SQL = Transformation[String] +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/graph/TableGraph.scala b/core/src/main/scala/com/databricks/labs/remorph/graph/TableGraph.scala new file mode 100644 index 0000000000..94db9a9d70 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/graph/TableGraph.scala @@ -0,0 +1,191 @@ +package com.databricks.labs.remorph.graph + +import com.databricks.labs.remorph.discovery.{ExecutedQuery, QueryHistory, TableDefinition} +import com.databricks.labs.remorph.parsers.PlanParser +import com.typesafe.scalalogging.LazyLogging +import com.databricks.labs.remorph.{KoResult, OkResult, Parsing, PartialResult, TranspilerState, intermediate => ir} + +protected case class Node(tableDefinition: TableDefinition, metadata: Map[String, Set[String]]) +// `from` is the table which is sourced to create `to` table +protected case class Edge(from: Node, to: Option[Node], metadata: Map[String, String]) + +class TableGraph(parser: PlanParser[_]) extends DependencyGraph with LazyLogging { + private[this] val nodes = 
scala.collection.mutable.Set.empty[Node] + private[this] val edges = scala.collection.mutable.Set.empty[Edge] + + override protected def addNode(id: TableDefinition, metadata: Map[String, Set[String]]): Unit = { + // Metadata list of query ids and add node only if it is not already present. + // for Example if a table is used in multiple queries, we need to consolidate the metadata + // here we are storing only query hash id as metadata not the query itself + val existingNode = nodes.find(_.tableDefinition == id) + existingNode match { + case Some(node) => + val consolidatedMetadata = (node.metadata.toSeq ++ metadata.toSeq) + .groupBy(_._1) + .mapValues(_.flatten(_._2).toSet) + nodes -= node + nodes += node.copy(metadata = consolidatedMetadata) + case None => + nodes += Node(id, metadata) + } + } + + override protected def addEdge( + from: TableDefinition, + to: Option[TableDefinition], + metadata: Map[String, String]): Unit = { + val fromNode = nodes.find(_.tableDefinition == from).get + val toNode = to.flatMap(td => nodes.find(_.tableDefinition == td)) + edges += Edge(fromNode, toNode, metadata) + } + + private def getTableName(plan: ir.LogicalPlan): String = { + plan collectFirst { case x: ir.NamedTable => + x.unparsed_identifier + } + }.getOrElse("None") + + private def generateEdges(plan: ir.LogicalPlan, tableDefinition: Set[TableDefinition], queryId: String): Unit = { + var toTable: Option[TableDefinition] = None + var action = "SELECT" + var fromTable: Seq[TableDefinition] = Seq.empty + + def collectTables(node: ir.LogicalPlan): Unit = { + node match { + case _: ir.CreateTable => + toTable = Some(tableDefinition.filter(_.table == getTableName(plan)).head) + action = "CREATE" + case _: ir.InsertIntoTable => + toTable = Some(tableDefinition.filter(_.table == getTableName(plan)).head) + action = "INSERT" + case _: ir.DeleteFromTable => + toTable = Some(tableDefinition.filter(_.table == getTableName(plan)).head) + action = "DELETE" + case _: ir.UpdateTable => + toTable = Some(tableDefinition.filter(_.table == getTableName(plan)).head) + action = "UPDATE" + case _: ir.MergeIntoTable => + toTable = Some(tableDefinition.filter(_.table == getTableName(plan)).head) + action = "MERGE" + case _: ir.Project | _: ir.Join | _: ir.SubqueryAlias | _: ir.Filter => + val tableList = plan collect { case x: ir.NamedTable => + x.unparsed_identifier + } + fromTable = tableDefinition.toSeq.filter(x => tableList.contains(x.table)) + case _ => // Do nothing + } + node.children.foreach(collectTables) + } + + collectTables(plan) + + if (fromTable.nonEmpty) { + fromTable.foreach(f => { + if (toTable.isDefined && f != toTable.get) { + addEdge(f, toTable, Map("query" -> queryId, "action" -> action)) + } else { + logger.debug(s"Ignoring reference detected for table ${f.table}") + } + }) + } else { + logger.debug(s"No tables found for insert into table values query") + } + } + + private def buildNode(plan: ir.LogicalPlan, tableDefinition: Set[TableDefinition], query: ExecutedQuery): Unit = { + plan collect { case x: ir.NamedTable => + tableDefinition.find(_.table == x.unparsed_identifier) match { + case Some(name) => addNode(name, Map("query" -> Set(query.id))) + case None => + logger.warn( + s"Table ${x.unparsed_identifier} not found in table definitions " + + s"or is it a subquery alias") + } + } + } + + def buildDependency(queryHistory: QueryHistory, tableDefinition: Set[TableDefinition]): Unit = { + queryHistory.queries.foreach { query => + try { + val plan = 
parser.parse.flatMap(parser.visit).run(TranspilerState(Parsing(query.source))) + plan match { + case KoResult(_, error) => + logger.warn(s"Failed to produce plan from query: ${query.id}") + logger.debug(s"Error: ${error.msg}") + case PartialResult((_, p), error) => + logger.warn(s"Errors occurred while producing plan from query ${query.id}") + logger.debug(s"Error: ${error.msg}") + buildNode(p, tableDefinition, query) + generateEdges(p, tableDefinition, query.id) + case OkResult((_, p)) => + buildNode(p, tableDefinition, query) + generateEdges(p, tableDefinition, query.id) + } + } catch { + // TODO Null Pointer Exception is thrown as OkResult, need to investigate for Merge Query. + case e: Exception => logger.warn(s"Failed to produce plan from query: ${query.source}") + } + } + } + + private def countInDegrees(): Map[TableDefinition, Int] = { + val inDegreeMap = scala.collection.mutable.Map[TableDefinition, Int]().withDefaultValue(0) + + // Initialize inDegreeMap with all nodes + nodes.foreach { node => + inDegreeMap(node.tableDefinition) = 0 + } + + edges.foreach { edge => + edge.to.foreach { toTable => + inDegreeMap(toTable.tableDefinition) += 1 + } + } + inDegreeMap.toMap + } + + // TODO Implement logic for fetching edges(parents) only upto certain level + def getRootTables(): Set[TableDefinition] = { + val inDegreeMap = countInDegrees() + nodes + .map(_.tableDefinition) + .filter(table => inDegreeMap.getOrElse(table, 0) == 0) + .toSet + } + + override def getUpstreamTables(table: TableDefinition): Set[TableDefinition] = { + def getUpstreamTablesRec(node: Node, visited: Set[Node]): Set[TableDefinition] = { + if (visited.contains(node)) { + Set.empty + } else { + val parents = edges.filter(_.to.get.tableDefinition == node.tableDefinition).map(_.from).toSet + parents.flatMap(parent => getUpstreamTablesRec(parent, visited + node)) + node.tableDefinition + } + } + + nodes.find(_.tableDefinition == table) match { + case Some(n) => getUpstreamTablesRec(n, Set.empty) - table + case None => + logger.warn(s"Table ${table.table} not found in the graph") + Set.empty[TableDefinition] + } + } + + override def getDownstreamTables(table: TableDefinition): Set[TableDefinition] = { + def getDownstreamTablesRec(node: Node, visited: Set[Node]): Set[TableDefinition] = { + if (visited.contains(node)) { + Set.empty + } else { + val children = edges.filter(_.from.tableDefinition == node.tableDefinition).flatMap(_.to).toSet + children.flatMap(child => getDownstreamTablesRec(child, visited + node)) + node.tableDefinition + } + } + nodes.find(_.tableDefinition == table) match { + case Some(n) => + getDownstreamTablesRec(n, Set.empty) - table + case None => + logger.warn(s"Table ${table.table} not found in the graph") + Set.empty[TableDefinition] + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/graph/lineage.scala b/core/src/main/scala/com/databricks/labs/remorph/graph/lineage.scala new file mode 100644 index 0000000000..72f5d90c57 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/graph/lineage.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.remorph.graph + +import com.databricks.labs.remorph.discovery.TableDefinition + +trait DependencyGraph { + protected def addNode(id: TableDefinition, metadata: Map[String, Set[String]]): Unit + protected def addEdge(from: TableDefinition, to: Option[TableDefinition], metadata: Map[String, String]): Unit + def getUpstreamTables(table: TableDefinition): Set[TableDefinition] + def getDownstreamTables(table: TableDefinition): 
Set[TableDefinition] +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/IRHelpers.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/IRHelpers.scala new file mode 100644 index 0000000000..c945b7b60c --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/IRHelpers.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.intermediate + +trait IRHelpers { + + protected def namedTable(name: String): LogicalPlan = NamedTable(name, Map.empty, is_streaming = false) + protected def simplyNamedColumn(name: String): Column = Column(None, Id(name)) + protected def crossJoin(left: LogicalPlan, right: LogicalPlan): LogicalPlan = + Join(left, right, None, CrossJoin, Seq(), JoinDataType(is_left_struct = false, is_right_struct = false)) + + protected def withNormalizedName(call: Fn): Fn = call match { + case CallFunction(name, args) => CallFunction(name.toUpperCase(), args) + case other => other + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/commands.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/commands.scala new file mode 100644 index 0000000000..1c505f3e33 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/commands.scala @@ -0,0 +1,95 @@ +package com.databricks.labs.remorph.intermediate + +trait Command extends LogicalPlan { + def output: Seq[Attribute] = Seq.empty +} + +case class SqlCommand(sql: String, named_arguments: Map[String, Expression], pos_arguments: Seq[Expression]) + extends LeafNode + with Command + +case class CreateDataFrameViewCommand(child: Relation, name: String, is_global: Boolean, replace: Boolean) + extends LeafNode + with Command + +abstract class TableSaveMethod +case object UnspecifiedSaveMethod extends TableSaveMethod +case object SaveAsTableSaveMethod extends TableSaveMethod +case object InsertIntoSaveMethod extends TableSaveMethod + +abstract class SaveMode +case object UnspecifiedSaveMode extends SaveMode +case object AppendSaveMode extends SaveMode +case object OverwriteSaveMode extends SaveMode +case object ErrorIfExistsSaveMode extends SaveMode +case object IgnoreSaveMode extends SaveMode + +case class SaveTable(table_name: String, save_method: TableSaveMethod) extends LeafNode with Command + +case class BucketBy(bucket_column_names: Seq[String], num_buckets: Int) + +case class WriteOperation( + child: Relation, + source: Option[String], + path: Option[String], + table: Option[SaveTable], + mode: SaveMode, + sort_column_names: Seq[String], + partitioning_columns: Seq[String], + bucket_by: Option[BucketBy], + options: Map[String, String], + clustering_columns: Seq[String]) + extends LeafNode + with Command + +abstract class Mode +case object UnspecifiedMode extends Mode +case object CreateMode extends Mode +case object OverwriteMode extends Mode +case object OverwritePartitionsMode extends Mode +case object AppendMode extends Mode +case object ReplaceMode extends Mode +case object CreateOrReplaceMode extends Mode + +case class WriteOperationV2( + child: Relation, + table_name: String, + provider: Option[String], + partitioning_columns: Seq[Expression], + options: Map[String, String], + table_properties: Map[String, String], + mode: Mode, + overwrite_condition: Option[Expression], + clustering_columns: Seq[String]) + extends LeafNode + with Command + +case class Trigger( + processing_time_interval: Option[String], + available_now: Boolean = false, + once: Boolean = false, + continuous_checkpoint_interval: 
Option[String]) + +case class SinkDestination(path: Option[String], table_name: Option[String]) + +case class StreamingForeachFunction(python_udf: Option[PythonUDF], scala_function: Option[ScalarScalaUDF]) + +case class WriteStreamOperationStart( + child: Relation, + format: String, + options: Map[String, String], + partitioning_column_names: Seq[String], + trigger: Trigger, + output_mode: String, + query_name: String, + sink_destination: SinkDestination, + foreach_writer: Option[StreamingForeachFunction], + foreach_batch: Option[StreamingForeachFunction]) + extends LeafNode + with Command + +// TODO: align snowflake and common IR implementations for `CreateVariable` +case class CreateVariable(name: Id, dataType: DataType, defaultExpr: Option[Expression] = None, replace: Boolean) + extends LeafNode + with Command + diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/common.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/common.scala new file mode 100644 index 0000000000..83481e3946 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/common.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.remorph.intermediate + +case class StorageLevel( + use_disk: Boolean, + use_memory: Boolean, + use_off_heap: Boolean, + deserialized: Boolean, + replication: Int) + +case class ResourceInformation(name: String, addresses: Seq[String]) diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/ddl.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/ddl.scala new file mode 100644 index 0000000000..d63246c0b8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/ddl.scala @@ -0,0 +1,215 @@ +package com.databricks.labs.remorph.intermediate + +abstract class DataType { + def isPrimitive: Boolean = this match { + case BooleanType => true + case ByteType(_) => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case StringType => true + case _ => false + } +} + +case object NullType extends DataType +case object BooleanType extends DataType +case object BinaryType extends DataType + +// Numeric types +case class ByteType(size: Option[Int]) extends DataType +case object ShortType extends DataType +case object TinyintType extends DataType +case object IntegerType extends DataType +case object LongType extends DataType + +case object FloatType extends DataType +case object DoubleType extends DataType + +object DecimalType { + def apply(): DecimalType = DecimalType(None, None) + def apply(precision: Int, scale: Int): DecimalType = DecimalType(Some(precision), Some(scale)) + def fromBigDecimal(d: BigDecimal): DecimalType = DecimalType(Some(d.precision), Some(d.scale)) +} + +case class DecimalType(precision: Option[Int], scale: Option[Int]) extends DataType + +// String types +case object StringType extends DataType +case class CharType(size: Option[Int]) extends DataType +case class VarcharType(size: Option[Int]) extends DataType + +// Datatime types +case object DateType extends DataType +case object TimeType extends DataType +case object TimestampType extends DataType +case object TimestampNTZType extends DataType + +// Interval types +case object IntervalType extends DataType +case object CalendarIntervalType extends DataType +case object YearMonthIntervalType extends DataType +case object DayTimeIntervalType extends DataType + +// Complex types +case class ArrayType(elementType: DataType) extends 
DataType +case class StructField(name: String, dataType: DataType, nullable: Boolean = true, metadata: Option[Metadata] = None) +case class StructType(fields: Seq[StructField]) extends DataType +case class MapType(keyType: DataType, valueType: DataType) extends DataType +case object VariantType extends DataType +case class Metadata(comment: Option[String]) +// While Databricks SQL does not DIRECTLY support IDENTITY in the way some other dialects do, it does support +// Id BIGINT GENERATED ALWAYS AS IDENTITY +case class IdentityType(start: Option[Int], increment: Option[Int]) extends DataType + +// UserDefinedType +case class UDTType() extends DataType + +case class UnparsedType(text: String) extends DataType + +case object UnresolvedType extends DataType + +// These are likely to change in a not-so-remote future. Spark SQL does not have constraints +// as it is not a database in its own right. Databricks SQL supports Key constraints and +// also allows the definition of CHECK constraints via ALTER TABLE after table creation. Spark +// does support nullability but stores that as a boolean in the column definition, as well as an +// expression for default values. +// +// So we will store the column constraints with the column definition and then use them to generate +// Databricks SQL CHECK constraints where we can, and comment the rest. +sealed trait Constraint +sealed trait UnnamedConstraint extends Constraint +case class Unique(options: Seq[GenericOption] = Seq.empty, columns: Option[Seq[String]] = None) + extends UnnamedConstraint +// Nullability is kept in case the NOT NULL constraint is named and we must generate a CHECK constraint +case class Nullability(nullable: Boolean) extends UnnamedConstraint +case class PrimaryKey(options: Seq[GenericOption] = Seq.empty, columns: Option[Seq[String]] = None) + extends UnnamedConstraint +case class ForeignKey(tableCols: String, refObject: String, refCols: String, options: Seq[GenericOption]) + extends UnnamedConstraint +case class DefaultValueConstraint(value: Expression) extends UnnamedConstraint +case class CheckConstraint(expression: Expression) extends UnnamedConstraint +// Extended based on https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-table-using.html#syntax +case class IdentityConstraint( + start: Option[String] = None, + increment: Option[String] = None, + always: Boolean = false, + default: Boolean = false) + extends UnnamedConstraint +case class NamedConstraint(name: String, constraint: UnnamedConstraint) extends Constraint +case class UnresolvedConstraint(inputText: String) extends UnnamedConstraint +case class GeneratedAlways(expression: Expression) extends UnnamedConstraint + +// This, and the above, are likely to change in a not-so-remote future. +// There's already a CreateTable case defined in catalog.scala but its structure seems too different from +// the information Snowflake grammar carries. +// In future changes, we'll have to reconcile this CreateTableCommand with the "Sparkier" CreateTable somehow.
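+// Illustrative sketch (editorial note; the SQL below is hypothetical, the IR names are the ones defined in this file):
+// a source column such as
+//   id BIGINT NOT NULL CONSTRAINT pk_id PRIMARY KEY
+// would be expected to surface in this IR roughly as
+//   ColumnDeclaration(
+//     name = "id",
+//     dataType = LongType,
+//     constraints = Seq(Nullability(nullable = false), NamedConstraint("pk_id", PrimaryKey())))
+// leaving the generator to decide which constraints map directly onto Databricks SQL and which must be
+// emitted as comments.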
+case class ColumnDeclaration( + name: String, + dataType: DataType, + virtualColumnDeclaration: Option[Expression] = Option.empty, + constraints: Seq[Constraint] = Seq.empty) + +case class CreateTableCommand(name: String, columns: Seq[ColumnDeclaration]) extends Catalog {} + +// TODO Need to introduce TableSpecBase, TableSpec and UnresolvedTableSpec + +case class ReplaceTableCommand(name: String, columns: Seq[ColumnDeclaration], orCreate: Boolean) extends Catalog + +case class ReplaceTableAsSelect( + table_name: String, + query: LogicalPlan, + writeOptions: Map[String, String], + orCreate: Boolean, + isAnalyzed: Boolean = false) + extends Catalog + +sealed trait TableAlteration +case class AddColumn(columnDeclaration: Seq[ColumnDeclaration]) extends TableAlteration +case class AddConstraint(columnName: String, constraint: Constraint) extends TableAlteration +case class ChangeColumnDataType(columnName: String, newDataType: DataType) extends TableAlteration +case class UnresolvedTableAlteration( + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends TableAlteration + with UnwantedInGeneratorInput + with Unresolved[UnresolvedTableAlteration] { + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedTableAlteration = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class DropConstraintByName(constraintName: String) extends TableAlteration +// When constraintName is None, drop the constraint on every relevant column +case class DropConstraint(columnName: Option[String], constraint: Constraint) extends TableAlteration +case class DropColumns(columnNames: Seq[String]) extends TableAlteration +case class RenameConstraint(oldName: String, newName: String) extends TableAlteration +case class RenameColumn(oldName: String, newName: String) extends TableAlteration + +case class AlterTableCommand(tableName: String, alterations: Seq[TableAlteration]) extends Catalog {} + +// Catalog API (experimental / unstable) +abstract class Catalog extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case class SetCurrentDatabase(db_name: String) extends Catalog {} +case class ListDatabases(pattern: Option[String]) extends Catalog {} +case class ListTables(db_name: Option[String], pattern: Option[String]) extends Catalog {} +case class ListFunctions(db_name: Option[String], pattern: Option[String]) extends Catalog {} +case class ListColumns(table_name: String, db_name: Option[String]) extends Catalog {} +case class GetDatabase(db_name: String) extends Catalog {} +case class GetTable(table_name: String, db_name: Option[String]) extends Catalog {} +case class GetFunction(function_name: String, db_name: Option[String]) extends Catalog {} +case class DatabaseExists(db_name: String) extends Catalog {} +case class TableExists(table_name: String, db_name: Option[String]) extends Catalog {} +case class FunctionExists(function_name: String, db_name: Option[String]) extends Catalog {} +case class CreateExternalTable( + table_name: String, + path: Option[String], + source: Option[String], + description: Option[String], + override val schema: DataType) + extends Catalog {} + +// As per Spark v2Commands +case class CreateTable( + table_name: String, + path: Option[String], + source: Option[String], + description: Option[String], + override val schema: DataType) + extends Catalog {} + +// As per Spark v2Commands +case class CreateTableAsSelect( + table_name: String, + query: 
LogicalPlan, + path: Option[String], + source: Option[String], + description: Option[String]) + extends Catalog {} + +case class DropTempView(view_name: String) extends Catalog {} +case class DropGlobalTempView(view_name: String) extends Catalog {} +case class RecoverPartitions(table_name: String) extends Catalog {} +case class IsCached(table_name: String) extends Catalog {} +case class CacheTable(table_name: String, storage_level: StorageLevel) extends Catalog {} +case class UncachedTable(table_name: String) extends Catalog {} +case class ClearCache() extends Catalog {} +case class RefreshTable(table_name: String) extends Catalog {} +case class RefreshByPath(path: String) extends Catalog {} +case class SetCurrentCatalog(catalog_name: String) extends Catalog {} +case class ListCatalogs(pattern: Option[String]) extends Catalog {} + +case class TableIdentifier(table: String, database: Option[String]) +case class CatalogTable( + identifier: TableIdentifier, + schema: StructType, + partitionColumnNames: Seq[String], + viewText: Option[String], + comment: Option[String], + unsupportedFeatures: Seq[String]) diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/dml.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/dml.scala new file mode 100644 index 0000000000..4a6c11cfcd --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/dml.scala @@ -0,0 +1,71 @@ +package com.databricks.labs.remorph.intermediate + +// Used for DML other than SELECT +abstract class Modification extends LogicalPlan + +case class InsertIntoTable( + target: LogicalPlan, + columns: Option[Seq[NameOrPosition]], + values: LogicalPlan, + outputRelation: Option[LogicalPlan] = None, + options: Option[Expression] = None, + overwrite: Boolean = false) + extends Modification { + override def children: Seq[LogicalPlan] = Seq(target, values, outputRelation.getOrElse(NoopNode)) + override def output: Seq[Attribute] = target.output +} + +case class DeleteFromTable( + target: LogicalPlan, + source: Option[LogicalPlan] = None, + where: Option[Expression] = None, + outputRelation: Option[LogicalPlan] = None, + options: Option[Expression] = None) + extends Modification { + override def children: Seq[LogicalPlan] = Seq(target, source.getOrElse(NoopNode), outputRelation.getOrElse(NoopNode)) + override def output: Seq[Attribute] = target.output +} + +case class UpdateTable( + target: LogicalPlan, + source: Option[LogicalPlan], + set: Seq[Expression], + where: Option[Expression] = None, + outputRelation: Option[LogicalPlan] = None, + options: Option[Expression] = None) + extends Modification { + override def children: Seq[LogicalPlan] = Seq(target, source.getOrElse(NoopNode), outputRelation.getOrElse(NoopNode)) + override def output: Seq[Attribute] = target.output +} + +/** + * The logical plan of the MERGE INTO command, aligned with SparkSQL + */ +case class MergeIntoTable( + targetTable: LogicalPlan, + sourceTable: LogicalPlan, + mergeCondition: Expression, + matchedActions: Seq[MergeAction] = Seq.empty, + notMatchedActions: Seq[MergeAction] = Seq.empty, + notMatchedBySourceActions: Seq[MergeAction] = Seq.empty) + extends Modification { + + override def children: Seq[LogicalPlan] = Seq(targetTable, sourceTable) + override def output: Seq[Attribute] = targetTable.output +} + +abstract class MergeAction extends Expression { + def condition: Option[Expression] + override def dataType: DataType = UnresolvedType + override def children: Seq[Expression] = condition.toSeq +} + +case class 
DeleteAction(condition: Option[Expression] = None) extends MergeAction + +case class UpdateAction(condition: Option[Expression], assignments: Seq[Assign]) extends MergeAction { + override def children: Seq[Expression] = condition.toSeq ++ assignments +} + +case class InsertAction(condition: Option[Expression], assignments: Seq[Assign]) extends MergeAction { + override def children: Seq[Expression] = assignments +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/errors.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/errors.scala new file mode 100644 index 0000000000..8e4fef13bd --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/errors.scala @@ -0,0 +1,112 @@ +package com.databricks.labs.remorph.intermediate + +import com.databricks.labs.remorph.Phase +import com.databricks.labs.remorph.utils.Strings + +sealed trait RemorphError { + def msg: String +} + +sealed trait SingleError extends RemorphError + +sealed trait MultipleErrors extends RemorphError { + def errors: Seq[SingleError] +} + +object RemorphError { + def merge(l: RemorphError, r: RemorphError): RemorphError = (l, r) match { + case (ls: MultipleErrors, rs: MultipleErrors) => RemorphErrors(ls.errors ++ rs.errors) + case (ls: MultipleErrors, r: SingleError) => RemorphErrors(ls.errors :+ r) + case (l: SingleError, rs: MultipleErrors) => RemorphErrors(l +: rs.errors) + case (l: SingleError, r: SingleError) => RemorphErrors(Seq(l, r)) + } +} + +case class RemorphErrors(errors: Seq[SingleError]) extends RemorphError with MultipleErrors { + override def msg: String = s"Multiple errors: ${errors.map(_.msg).mkString(", ")}" +} + +case class PreParsingError(line: Int, charPositionInLine: Int, offendingTokenText: String, message: String) + extends RemorphError + with SingleError { + override def msg: String = s"Pre-parsing error starting at $line:$charPositionInLine: $message" +} + +case class ParsingError( + line: Int, + charPositionInLine: Int, + message: String, + offendingTokenWidth: Int, + offendingTokenText: String, + offendingTokenName: String, + ruleName: String) + extends RemorphError + with SingleError { + override def msg: String = + s"Parsing error starting at $line:$charPositionInLine involving rule '$ruleName' and" + + s" token '$offendingTokenText'($offendingTokenName): $message" +} + +case class ParsingErrors(errors: Seq[ParsingError]) extends RemorphError with MultipleErrors { + override def msg: String = s"Parsing errors: ${errors.map(_.msg).mkString(", ")}" +} + +// TODO: If we wish to preserve the whole node in, say, JSON output, we will need to accept TreeNode[_] and deal with +// implicits for TreeNode[_] as well +case class UnexpectedNode(offendingNode: String) extends RemorphError with SingleError { + override def msg: String = s"Unexpected node of class ${offendingNode}" +} + +case class UnexpectedTableAlteration(offendingTableAlteration: String) extends RemorphError with SingleError { + override def msg: String = s"Unexpected table alteration $offendingTableAlteration" +} + +case class UnsupportedGroupType(offendingGroupType: String) extends RemorphError with SingleError { + override def msg: String = s"Unsupported group type $offendingGroupType" +} + +case class UnsupportedDataType(offendingDataType: String) extends RemorphError with SingleError { + override def msg: String = s"Unsupported data type $offendingDataType" +} + +case class WrongNumberOfArguments(functionName: String, got: Int, expectationMessage: String) + extends RemorphError
+ with SingleError { + override def msg: String = + s"Wrong number of arguments for $functionName: got $got, expected $expectationMessage" +} + +case class UnsupportedArguments(functionName: String, arguments: Seq[Expression]) + extends RemorphError + with SingleError { + override def msg: String = s"Unsupported argument(s) to $functionName" +} + +case class UnsupportedDateTimePart(expression: Expression) extends RemorphError with SingleError { + override def msg: String = s"Unsupported date/time part specification: $expression" +} + +case class PlanGenerationFailure(exception: Throwable) extends RemorphError with SingleError { + override def msg: String = s"PlanGenerationFailure: ${exception.getClass.getSimpleName}, ${exception.getMessage}" +} + +case class TranspileFailure(exception: Throwable) extends RemorphError with SingleError { + override def msg: String = s"TranspileFailure: ${exception.getClass.getSimpleName}, ${exception.getMessage}" +} + +case class UncaughtException(exception: Throwable) extends RemorphError with SingleError { + override def msg: String = exception.getMessage +} + +case class UnexpectedOutput(expected: String, actual: String) extends RemorphError with SingleError { + override def msg: String = + s""" + |=== Unexpected output (expected vs actual) === + |${Strings.sideBySide(expected, actual).mkString("\n")} + |""".stripMargin +} + +case class IncoherentState(currentPhase: Phase, expectedPhase: Class[_]) extends RemorphError with SingleError { + override def msg: String = + s"Incoherent state: current phase is ${currentPhase.getClass.getSimpleName} but should be ${expectedPhase.getSimpleName}" +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/expressions.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/expressions.scala new file mode 100644 index 0000000000..11ae9a8001 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/expressions.scala @@ -0,0 +1,240 @@ +package com.databricks.labs.remorph.intermediate + +import java.util.UUID + +// Expression used to refer to fields, functions and similar. This can be used everywhere +// expressions in SQL appear. +abstract class Expression extends TreeNode[Expression] { + lazy val resolved: Boolean = childrenResolved + + def dataType: DataType + + def childrenResolved: Boolean = children.forall(_.resolved) + + def references: AttributeSet = new AttributeSet(children.flatMap(_.references): _*) +} + +/** Expression without any child expressions */ +abstract class LeafExpression extends Expression { + override final def children: Seq[Expression] = Nil +} + +object NamedExpression { + private[this] val curId = new java.util.concurrent.atomic.AtomicLong() + private[intermediate] val jvmId = UUID.randomUUID() + def newExprId: ExprId = ExprId(curId.getAndIncrement(), jvmId) + def unapply(expr: NamedExpression): Option[(String, DataType)] = Some((expr.name, expr.dataType)) +} + +case class ExprId(id: Long, jvmId: UUID) { + override def hashCode(): Int = id.hashCode() + override def equals(other: Any): Boolean = other match { + case ExprId(id, jvmId) => this.id == id && this.jvmId == jvmId + case _ => false + } +} + +object ExprId { + def apply(id: Long): ExprId = ExprId(id, NamedExpression.jvmId) +} + +trait NamedExpression extends Expression { + def name: String + def exprId: ExprId + + /** + * Returns a dot separated fully qualified name for this attribute. 
Given that there can be multiple qualifiers, it is + * possible that there are other possible ways to refer to this attribute. + */ + def qualifiedName: String = (qualifier :+ name).mkString(".") + + /** + * Optional qualifier for the expression. Qualifier can also contain the fully qualified information, e.g., + * a sequence of strings containing the database and the table name + * + * For now, since we do not allow using original table name to qualify a column name once the table is aliased, this + * can only be: + * + * 1. Empty Seq: when an attribute doesn't have a qualifier, e.g. top level attributes aliased in the SELECT clause, + * or column from a LocalRelation. 2. Seq with a single element: either the table name or the alias name of the + * table. 3. Seq with 2 elements: database name and table name 4. Seq with 3 elements: catalog name, database + * name and table name + */ + def qualifier: Seq[String] + + def toAttribute: Attribute + + /** Returns a copy of this expression with a new `exprId`. */ + def newInstance(): NamedExpression +} + +class AttributeSet(val attrs: NamedExpression*) extends Set[NamedExpression] { + def this(attrs: Set[NamedExpression]) = this(attrs.toSeq: _*) + + override def iterator: Iterator[NamedExpression] = attrs.iterator + + override def +(elem: NamedExpression): AttributeSet = new AttributeSet(attrs :+ elem: _*) + + override def -(elem: NamedExpression): AttributeSet = new AttributeSet(attrs.filterNot(_ == elem): _*) + + def --(other: AttributeSet): AttributeSet = new AttributeSet(attrs.filterNot(other.contains): _*) + + override def contains(key: NamedExpression): Boolean = attrs.contains(key) +} + +abstract class Attribute extends LeafExpression with NamedExpression { + + @transient + override lazy val references: AttributeSet = new AttributeSet(this) + + override def toAttribute: Attribute = this +} + +case class AttributeReference( + name: String, + dataType: DataType, + nullable: Boolean = true, + exprId: ExprId = NamedExpression.newExprId, + qualifier: Seq[String] = Seq.empty[String]) + extends Attribute { + override def newInstance(): NamedExpression = copy(exprId = NamedExpression.newExprId) +} + +abstract class Unary(val child: Expression) extends Expression { + override def children: Seq[Expression] = Seq(child) +} + +abstract class Binary(left: Expression, right: Expression) extends Expression { + override def children: Seq[Expression] = Seq(left, right) +} + +case class WhenBranch(condition: Expression, expression: Expression) extends Binary(condition, expression) { + override def dataType: DataType = expression.dataType +} + +case class Case(expression: Option[Expression], branches: Seq[WhenBranch], otherwise: Option[Expression]) + extends Expression { + override def children: Seq[Expression] = expression.toSeq ++ + branches.flatMap(b => Seq(b.condition, b.expression)) ++ otherwise + override def dataType: DataType = branches.head.dataType +} + +/** isnotnull(expr) - Returns true if `expr` is not null, or false otherwise. */ +case class IsNotNull(left: Expression) extends Unary(left) { + override def dataType: DataType = left.dataType +} + +/** isnull(expr) - Returns true if `expr` is null, or false otherwise.
*/ +case class IsNull(left: Expression) extends Unary(left) { + override def dataType: DataType = left.dataType +} + +abstract class FrameType +case object UndefinedFrame extends FrameType +case object RangeFrame extends FrameType +case object RowsFrame extends FrameType + +sealed trait FrameBoundary +case object CurrentRow extends FrameBoundary +case object UnboundedPreceding extends FrameBoundary +case object UnboundedFollowing extends FrameBoundary +case class PrecedingN(n: Expression) extends FrameBoundary +case class FollowingN(n: Expression) extends FrameBoundary +case object NoBoundary extends FrameBoundary +case class WindowFrame(frame_type: FrameType, lower: FrameBoundary, upper: FrameBoundary) + +case class Window( + window_function: Expression, + partition_spec: Seq[Expression] = Seq.empty, + sort_order: Seq[SortOrder] = Seq.empty, + frame_spec: Option[WindowFrame] = None, + ignore_nulls: Boolean = false) // TODO: this is a property of Last(), not Window + extends Expression { + override def children: Seq[Expression] = Seq(window_function) ++ partition_spec ++ sort_order + override def dataType: DataType = window_function.dataType +} + +/** cast(expr AS type) - Casts the value `expr` to the target data type `type`. */ +case class Cast( + expr: Expression, + dataType: DataType, + type_str: String = "", + returnNullOnError: Boolean = false, + timeZoneId: Option[String] = None) + extends Unary(expr) + +case class CalendarInterval(months: Int, days: Int, microseconds: Long) extends LeafExpression { + override def dataType: DataType = CalendarIntervalType +} + +case class StructExpr(fields: Seq[StarOrAlias]) extends Expression { + override def children: Seq[Expression] = fields.map { + case a: Alias => a + case s: Star => s + } + + override def dataType: DataType = fields match { + case Nil => UnresolvedType + case Seq(Star(_)) => UnresolvedType + case _ => + StructType(fields.map { case Alias(child, Id(name, _)) => + StructField(name, child.dataType) + }) + } +} + +case class UpdateFields(struct_expression: Expression, field_name: String, value_expression: Option[Expression]) + extends Expression { + override def children: Seq[Expression] = struct_expression :: value_expression.toList + override def dataType: DataType = UnresolvedType // TODO: Fix this +} + +trait StarOrAlias + +case class Alias(expr: Expression, name: Id) extends Unary(expr) with StarOrAlias { + override def dataType: DataType = expr.dataType +} + +case class LambdaFunction(function: Expression, arguments: Seq[UnresolvedNamedLambdaVariable]) extends Expression { + override def children: Seq[Expression] = function +: arguments + override def dataType: DataType = UnresolvedType // TODO: Fix this +} + +case class UnresolvedNamedLambdaVariable(name_parts: Seq[String]) extends Expression { + override def children: Seq[Expression] = Nil + override def dataType: DataType = UnresolvedType +} + +case class PythonUDF(output_type: DataType, eval_type: Int, command: Array[Byte], python_ver: String) + extends LeafExpression { + override def dataType: DataType = output_type +} + +case class ScalarScalaUDF(payload: Array[Byte], inputTypes: Seq[DataType], outputType: DataType, nullable: Boolean) + extends LeafExpression { + override def dataType: DataType = outputType +} + +case class JavaUDF(class_name: String, output_type: Option[DataType], aggregate: Boolean) extends LeafExpression { + override def dataType: DataType = output_type.getOrElse(UnresolvedType) +} + +case class CommonInlineUserDefinedFunction( + function_name: 
String, + deterministic: Boolean, + arguments: Seq[Expression], + python_udf: Option[PythonUDF], + scalar_scala_udf: Option[ScalarScalaUDF], + java_udf: Option[JavaUDF]) + extends Expression { + override def children: Seq[Expression] = arguments ++ python_udf.toSeq ++ scalar_scala_udf.toSeq ++ java_udf.toSeq + override def dataType: DataType = UnresolvedType +} + +case class Variable(name: String) extends LeafExpression { + override def dataType: DataType = UnresolvedType +} + +case class SchemaReference(columnName: Expression) extends Unary(columnName) { + override def dataType: DataType = UnresolvedType +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/extensions.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/extensions.scala new file mode 100644 index 0000000000..8052a7947e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/extensions.scala @@ -0,0 +1,261 @@ +package com.databricks.labs.remorph.intermediate + +trait AstExtension + +abstract class ToRefactor extends LeafExpression { + override def dataType: DataType = UnresolvedType +} + +sealed trait NameOrPosition extends LeafExpression + +// TODO: (nfx) refactor to align more with catalyst, replace with Name +case class Id(id: String, caseSensitive: Boolean = false) extends ToRefactor with NameOrPosition + +case class Name(name: String) extends NameOrPosition { + override def dataType: DataType = UnresolvedType +} + +case class Position(index: Int) extends ToRefactor with NameOrPosition {} + +// TODO: (nfx) refactor to align more with catalyst +case class ObjectReference(head: NameOrPosition, tail: NameOrPosition*) extends ToRefactor + +// TODO: (nfx) refactor to align more with catalyst +case class Column(tableNameOrAlias: Option[ObjectReference], columnName: NameOrPosition) + extends ToRefactor + with AstExtension {} + +case class Identifier(name: String, isQuoted: Boolean) extends ToRefactor with AstExtension {} +case object DollarAction extends ToRefactor with AstExtension {} +case class Distinct(expression: Expression) extends ToRefactor + +case object Noop extends LeafExpression { + override def dataType: DataType = UnresolvedType +} + +case object NoopNode extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +// TODO: (nfx) refactor to align more with catalyst, UnaryNode +// case class UnresolvedWith(child: LogicalPlan, ctes: Seq[(String, SubqueryAlias)]) +case class WithCTE(ctes: Seq[LogicalPlan], query: LogicalPlan) extends RelationCommon { + override def output: Seq[Attribute] = query.output + override def children: Seq[LogicalPlan] = ctes :+ query +} + +case class WithRecursiveCTE(ctes: Seq[LogicalPlan], query: LogicalPlan) extends RelationCommon { + override def output: Seq[Attribute] = query.output + override def children: Seq[LogicalPlan] = ctes :+ query +} + +// TODO: (nfx) refactor to align more with catalyst, rename to UnresolvedStar +case class Star(objectName: Option[ObjectReference] = None) extends LeafExpression with StarOrAlias { + override def dataType: DataType = UnresolvedType +} + +// Assignment operators +// TODO: (ji) This needs to be renamed to Assignment as per Catalyst +case class Assign(left: Expression, right: Expression) extends Binary(left, right) { + override def dataType: DataType = UnresolvedType +} + +// Some statements, such as SELECT, do not require a table specification +case object NoTable extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case class LocalVarTable(id: 
Id) extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +// Table hints are not directly supported in Databricks SQL, but at least some of +// them will have direct equivalents for the Catalyst optimizer. Hence they are +// included in the AST for the code generator to use them if it can. At worst, +// a comment can be generated with the hint text to guide the conversion. +abstract class TableHint +case class FlagHint(name: String) extends TableHint +case class IndexHint(indexes: Seq[Expression]) extends TableHint +case class ForceSeekHint(index: Option[Expression], indexColumns: Option[Seq[Expression]]) extends TableHint + +// It was not clear whether the NamedTable options should be used for the alias. I'm assuming it is not what +// they are for. +case class TableAlias(child: LogicalPlan, alias: String, columns: Seq[Id] = Seq.empty) extends UnaryNode { + override def output: Seq[Attribute] = columns.map(c => AttributeReference(c.id, StringType)) +} + +// TODO: (nfx) refactor to align more with catalyst +// TODO: remove this and replace with Hint(Hint(...), ...) +case class TableWithHints(child: LogicalPlan, hints: Seq[TableHint]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Batch(children: Seq[LogicalPlan]) extends LogicalPlan { + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Seq()) +} + +case class FunctionParameter(name: String, dataType: DataType, defaultValue: Option[Expression]) + +sealed trait RuntimeInfo +case class JavaRuntimeInfo(runtimeVersion: Option[String], imports: Seq[String], handler: String) extends RuntimeInfo +case class PythonRuntimeInfo(runtimeVersion: Option[String], packages: Seq[String], handler: String) extends RuntimeInfo +case object JavaScriptRuntimeInfo extends RuntimeInfo +case class ScalaRuntimeInfo(runtimeVersion: Option[String], imports: Seq[String], handler: String) extends RuntimeInfo +case class SQLRuntimeInfo(memoizable: Boolean) extends RuntimeInfo + +case class CreateInlineUDF( + name: String, + returnType: DataType, + parameters: Seq[FunctionParameter], + runtimeInfo: RuntimeInfo, + acceptsNullParameters: Boolean, + comment: Option[String], + body: String) + extends Catalog {} + +// Used for raw expressions that have no context +case class Dot(left: Expression, right: Expression) extends Binary(left, right) { + override def dataType: DataType = UnresolvedType +} + +case class ArrayAccess(array: Expression, index: Expression) extends Binary(array, index) { + override def dataType: DataType = array.dataType +} + +case class JsonAccess(json: Expression, path: Expression) extends Binary(json, path) { + override def dataType: DataType = VariantType +} + +case class Collate(string: Expression, specification: String) extends Unary(string) { + override def dataType: DataType = StringType +} + +case class Timezone(expression: Expression, timeZone: Expression) extends Binary(expression, timeZone) { + override def dataType: DataType = expression.dataType +} + +case class WithinGroup(expression: Expression, order: Seq[SortOrder]) extends Unary(expression) { + override def dataType: DataType = expression.dataType +} + +sealed trait SamplingMethod +case class RowSamplingProbabilistic(probability: BigDecimal) extends SamplingMethod +case class RowSamplingFixedAmount(amount: BigDecimal) extends SamplingMethod +case class BlockSampling(probability: BigDecimal) extends SamplingMethod + +// TODO: (nfx) refactor to align more with catalyst +case class 
TableSample(input: LogicalPlan, samplingMethod: SamplingMethod, seed: Option[BigDecimal]) extends UnaryNode { + override def child: LogicalPlan = input + override def output: Seq[Attribute] = input.output +} + +// Note that Databricks SQL supports FILTER() used as an expression. +case class FilterExpr(input: Seq[Expression], lambdaFunction: LambdaFunction) extends Expression { + override def children: Seq[Expression] = input :+ lambdaFunction + override def dataType: DataType = UnresolvedType +} + +case class ValueArray(expressions: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = expressions + override def dataType: DataType = UnresolvedType +} + +case class NamedStruct(keys: Seq[Expression], values: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = keys ++ values + override def dataType: DataType = UnresolvedType +} + +case class FilterStruct(input: NamedStruct, lambdaFunction: LambdaFunction) extends Expression { + override def children: Seq[Expression] = Seq(input, lambdaFunction) + override def dataType: DataType = UnresolvedType +} + +// TSQL has some join types that are not natively supported in Databricks SQL, but can possibly be emulated +// using LATERAL VIEW and an explode function. Some things like functions are translatable at IR production +// time, but complex joins are better done at the translation from IR, via an optimizer rule as they are more involved +// than some simple prescribed action such as a rename +case object CrossApply extends JoinType +case object OuterApply extends JoinType + +// TODO: fix +// @see https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-tvf.html +case class TableFunction(functionCall: Expression) extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case class Lateral(expr: LogicalPlan, outer: Boolean = false, isView: Boolean = false) extends UnaryNode { + override def child: LogicalPlan = expr + override def output: Seq[Attribute] = expr.output +} + +case class PlanComment(child: LogicalPlan, text: String) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Options( + expressionOpts: Map[String, Expression], + stringOpts: Map[String, String], + boolFlags: Map[String, Boolean], + autoFlags: List[String]) + extends Expression { + override def children: Seq[Expression] = expressionOpts.values.toSeq + override def dataType: DataType = UnresolvedType +} + +case class WithOptions(input: LogicalPlan, options: Expression) extends UnaryNode { + override def child: LogicalPlan = input + override def output: Seq[Attribute] = input.output +} + +case class WithModificationOptions(input: Modification, options: Expression) extends Modification { + override def children: Seq[Modification] = Seq(input) + override def output: Seq[Attribute] = input.output +} + +// TSQL allows the definition of everything including constraints and indexes in CREATE TABLE, +// whereas Databricks SQL does not. We will store the constraints, indexes etc., separately from the +// spark like CreateTable and then deal with them in the generator. This is because some TSQL stuff will +// be column constraints, some become table constraints, some need to be generated as ALTER statements after +// the CREATE TABLE, etc. 
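+// Illustrative sketch (editorial note; the T-SQL below is hypothetical and the field values are only indicative):
+//   CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(10))
+// might be carried here roughly as
+//   CreateTableParams(
+//     create = /* the Spark-like CreateTable node for `t` */,
+//     colConstraints = Map("id" -> Seq(PrimaryKey())),
+//     colOptions = Map.empty,
+//     constraints = Seq.empty,
+//     indices = Seq.empty,
+//     partition = None,
+//     options = None)
+// with the generator deciding which entries become column constraints, table constraints, or trailing
+// ALTER TABLE statements.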
+case class CreateTableParams( + create: Catalog, // The base create table command + colConstraints: Map[String, Seq[Constraint]], // Column constraints + colOptions: Map[String, Seq[GenericOption]], // Column constraints + constraints: Seq[Constraint], // Table constraints + indices: Seq[Constraint], // Index Definitions (currently all unresolved) + partition: Option[String], // Partitioning information but unsupported + options: Option[Seq[GenericOption]] // Command level options +) extends Catalog + +// Though at least TSQL only needs the time based intervals, we are including all the interval types +// supported by Spark SQL for completeness and future proofing +sealed trait KnownIntervalType +case object NANOSECOND_INTERVAL extends KnownIntervalType +case object MICROSECOND_INTERVAL extends KnownIntervalType +case object MILLISECOND_INTERVAL extends KnownIntervalType +case object SECOND_INTERVAL extends KnownIntervalType +case object MINUTE_INTERVAL extends KnownIntervalType +case object HOUR_INTERVAL extends KnownIntervalType +case object DAY_INTERVAL extends KnownIntervalType +case object WEEK_INTERVAL extends KnownIntervalType +case object MONTH_INTERVAL extends KnownIntervalType +case object YEAR_INTERVAL extends KnownIntervalType + +// TSQL - For translation purposes, we cannot use the standard Catalyst CalendarInterval as it is not +// meant for code generation and converts everything to microseconds. It is much easier to use an extension +// to the AST to represent the interval as it is required in TSQL, where we need to know if we were dealing with +// MONTHS, HOURS, etc. +case class KnownInterval(value: Expression, iType: KnownIntervalType) extends Expression { + override def children: Seq[Expression] = Seq(value) + override def dataType: DataType = UnresolvedType +} + +case class JinjaAsStatement(text: String) extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case class JinjaAsExpression(text: String) extends LeafExpression { + override def dataType: DataType = UnresolvedType +} + +case class JinjaAsDataType(text: String) extends DataType diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/functions.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/functions.scala new file mode 100644 index 0000000000..cd86ccfc93 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/functions.scala @@ -0,0 +1,2563 @@ +package com.databricks.labs.remorph.intermediate + +import com.databricks.labs.remorph.transpilers.TranspileException + +import java.util.Locale + +trait Fn extends Expression { + def prettyName: String +} + +case class CallFunction(function_name: String, arguments: Seq[Expression]) extends Expression with Fn { + override def children: Seq[Expression] = arguments + override def dataType: DataType = UnresolvedType + override def prettyName: String = function_name.toUpperCase(Locale.getDefault) +} + +class CallMapper extends Rule[LogicalPlan] with IRHelpers { + + override final def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { case fn: Fn => + try { + convert(fn) + } catch { + case e: IndexOutOfBoundsException => + throw TranspileException(WrongNumberOfArguments(fn.prettyName, fn.children.size, e.getMessage)) + + } + } + } + + /** This function is supposed to be overridden by dialects */ + def convert(call: Fn): Expression = withNormalizedName(call) match { + case CallFunction("ABS", args) => Abs(args.head) + case CallFunction("ACOS", args) => Acos(args.head) 
+ case CallFunction("ACOSH", args) => Acosh(args.head) + case CallFunction("ADD_MONTHS", args) => AddMonths(args.head, args(1)) + case CallFunction("AGGREGATE", args) => ArrayAggregate(args.head, args(1), args(2), args(3)) + case CallFunction("ANY", args) => BoolOr(args.head) + case CallFunction("APPROX_COUNT_DISTINCT", args) => HyperLogLogPlusPlus(args.head, args(1)) + case CallFunction("ARRAYS_OVERLAP", args) => ArraysOverlap(args.head, args(1)) + case CallFunction("ARRAYS_ZIP", args) => ArraysZip(args) + case CallFunction("ARRAY_CONTAINS", args) => ArrayContains(args.head, args(1)) + case CallFunction("ARRAY_DISTINCT", args) => ArrayDistinct(args.head) + case CallFunction("ARRAY_EXCEPT", args) => ArrayExcept(args.head, args(1)) + case CallFunction("ARRAY_INTERSECT", args) => ArrayIntersect(args.head, args(1)) + case CallFunction("ARRAY_JOIN", args) => + val delim = if (args.size >= 3) Some(args(2)) else None + ArrayJoin(args.head, args(1), delim) + case CallFunction("ARRAY_MAX", args) => ArrayMax(args.head) + case CallFunction("ARRAY_MIN", args) => ArrayMin(args.head) + case CallFunction("ARRAY_POSITION", args) => ArrayPosition(args.head, args(1)) + case CallFunction("ARRAY_REMOVE", args) => ArrayRemove(args.head, args(1)) + case CallFunction("ARRAY_REPEAT", args) => ArrayRepeat(args.head, args(1)) + case CallFunction("ARRAY_SORT", args) => ArraySort(args.head, args(1)) + case CallFunction("ARRAY_UNION", args) => ArrayUnion(args.head, args(1)) + case CallFunction("ASCII", args) => Ascii(args.head) + case CallFunction("ASIN", args) => Asin(args.head) + case CallFunction("ASINH", args) => Asinh(args.head) + case CallFunction("ASSERT_TRUE", args) => AssertTrue(args.head, args(1)) + case CallFunction("ATAN", args) => Atan(args.head) + case CallFunction("ATAN2", args) => Atan2(args.head, args(1)) + case CallFunction("ATANH", args) => Atanh(args.head) + case CallFunction("AVG", args) => Average(args.head) + case CallFunction("BASE64", args) => Base64(args.head) + case CallFunction("BIN", args) => Bin(args.head) + case CallFunction("BIT_AND", args) => BitAndAgg(args.head) + case CallFunction("BIT_COUNT", args) => BitwiseCount(args.head) + case CallFunction("BIT_GET", args) => BitwiseGet(args.head, args(1)) + case CallFunction("GETBIT", args) => BitwiseGet(args.head, args(1)) // Synonym for BIT_GET + case CallFunction("BIT_LENGTH", args) => BitLength(args.head) + case CallFunction("BIT_OR", args) => BitOrAgg(args.head) + case CallFunction("BIT_XOR", args) => BitXorAgg(args.head) + case CallFunction("BOOL_AND", args) => BoolAnd(args.head) + case CallFunction("BROUND", args) => BRound(args.head, args(1)) + case CallFunction("CBRT", args) => Cbrt(args.head) + case CallFunction("CEIL", args) => Ceil(args.head) + case CallFunction("CHAR", args) => Chr(args.head) + case CallFunction("COALESCE", args) => Coalesce(args) + case CallFunction("COLLECT_LIST", args) => CollectList(args.head, args.tail.headOption) + case CallFunction("COLLECT_SET", args) => CollectSet(args.head) + case CallFunction("CONCAT", args) => Concat(args) + case CallFunction("CONCAT_WS", args) => ConcatWs(args) + case CallFunction("CONV", args) => Conv(args.head, args(1), args(2)) + case CallFunction("CORR", args) => Corr(args.head, args(1)) + case CallFunction("COS", args) => Cos(args.head) + case CallFunction("COSH", args) => Cosh(args.head) + case CallFunction("COT", args) => Cot(args.head) + case CallFunction("COUNT", args) => Count(args) + case CallFunction("COUNT_IF", args) => CountIf(args.head) + case 
CallFunction("COUNT_MIN_SKETCH", args) => + CountMinSketchAgg(args.head, args(1), args(2), args(3)) + case CallFunction("COVAR_POP", args) => CovPopulation(args.head, args(1)) + case CallFunction("COVAR_SAMP", args) => CovSample(args.head, args(1)) + case CallFunction("CRC32", args) => Crc32(args.head) + case CallFunction("CUBE", args) => Cube(args) + case CallFunction("CUME_DIST", _) => CumeDist() + case CallFunction("CURRENT_CATALOG", _) => CurrentCatalog() + case CallFunction("CURRENT_DATABASE", _) => CurrentDatabase() + case CallFunction("CURRENT_DATE", _) => CurrentDate() + case CallFunction("CURRENT_TIMESTAMP", _) => CurrentTimestamp() + case CallFunction("CURRENT_TIMEZONE", _) => CurrentTimeZone() + case CallFunction("DATEDIFF", args) => DateDiff(args.head, args(1)) + case CallFunction("DATE_ADD", args) => DateAdd(args.head, args(1)) + case CallFunction("DATE_FORMAT", args) => DateFormatClass(args.head, args(1)) + case CallFunction("DATE_FROM_UNIX_DATE", args) => DateFromUnixDate(args.head) + case CallFunction("DATE_PART", args) => DatePart(args.head, args(1)) + case CallFunction("DATE_SUB", args) => DateSub(args.head, args(1)) + case CallFunction("DATE_TRUNC", args) => TruncTimestamp(args.head, args(1)) + case CallFunction("DAYOFMONTH", args) => DayOfMonth(args.head) + case CallFunction("DAYOFWEEK", args) => DayOfWeek(args.head) + case CallFunction("DAYOFYEAR", args) => DayOfYear(args.head) + case CallFunction("DECODE", args) => Decode(args.head, args(1)) + case CallFunction("DEGREES", args) => ToDegrees(args.head) + case CallFunction("DENSE_RANK", args) => DenseRank(args) + case CallFunction("DIV", args) => IntegralDivide(args.head, args(1)) + case CallFunction("E", _) => EulerNumber() + case CallFunction("ELEMENT_AT", args) => ElementAt(args.head, args(1)) + case CallFunction("ELT", args) => Elt(args) + case CallFunction("ENCODE", args) => Encode(args.head, args(1)) + case CallFunction("EXISTS", args) => ArrayExists(args.head, args(1)) + case CallFunction("EXP", args) => Exp(args.head) + case CallFunction("EXPLODE", args) => Explode(args.head) + case CallFunction("EXPM1", args) => Expm1(args.head) + case CallFunction("EXTRACT", args) => Extract(args.head, args(1)) + case CallFunction("FACTORIAL", args) => Factorial(args.head) + case CallFunction("FILTER", args) => ArrayFilter(args.head, args(1)) + case CallFunction("FIND_IN_SET", args) => FindInSet(args.head, args(1)) + case CallFunction("FIRST", args) => First(args.head, args.lift(1)) + case CallFunction("FLATTEN", args) => Flatten(args.head) + case CallFunction("FLOOR", args) => Floor(args.head) + case CallFunction("FORALL", args) => ArrayForAll(args.head, args(1)) + case CallFunction("FORMAT_NUMBER", args) => FormatNumber(args.head, args(1)) + case CallFunction("FORMAT_STRING", args) => FormatString(args) + case CallFunction("FROM_CSV", args) => CsvToStructs(args.head, args(1), args(2)) + case CallFunction("FROM_JSON", args) => JsonToStructs(args.head, args(1), args.lift(2)) + case CallFunction("FROM_UNIXTIME", args) => FromUnixTime(args.head, args(1)) + case CallFunction("FROM_UTC_TIMESTAMP", args) => FromUTCTimestamp(args.head, args(1)) + case CallFunction("GET_JSON_OBJECT", args) => GetJsonObject(args.head, args(1)) + case CallFunction("GREATEST", args) => Greatest(args) + case CallFunction("GROUPING", args) => Grouping(args.head) + case CallFunction("GROUPING_ID", args) => GroupingID(args) + case CallFunction("HASH", args) => Murmur3Hash(args) + case CallFunction("HEX", args) => Hex(args.head) + case CallFunction("HOUR", 
args) => Hour(args.head) + case CallFunction("HYPOT", args) => Hypot(args.head, args(1)) + case CallFunction("IF", args) => If(args.head, args(1), args(2)) + case CallFunction("IFNULL", args) => IfNull(args.head, args(1)) + case CallFunction("IN", args) => In(args.head, args.tail) // TODO: not a function + case CallFunction("INITCAP", args) => InitCap(args.head) + case CallFunction("INLINE", args) => Inline(args.head) + case CallFunction("INPUT_FILE_BLOCK_LENGTH", _) => InputFileBlockLength() + case CallFunction("INPUT_FILE_BLOCK_START", _) => InputFileBlockStart() + case CallFunction("INPUT_FILE_NAME", _) => InputFileName() + case CallFunction("INSTR", args) => StringInstr(args.head, args(1)) + case CallFunction("ISNAN", args) => IsNaN(args.head) + case CallFunction("JAVA_METHOD", args) => CallMethodViaReflection(args) + case CallFunction("JSON_ARRAY_LENGTH", args) => LengthOfJsonArray(args.head) + case CallFunction("JSON_OBJECT_KEYS", args) => JsonObjectKeys(args.head) + case CallFunction("JSON_TUPLE", args) => JsonTuple(args) + case CallFunction("KURTOSIS", args) => Kurtosis(args.head) + case CallFunction("LAG", args) => Lag(args.head, args.lift(1), args.lift(2)) + case CallFunction("LAST", args) => Last(args.head, args.lift(1)) + case CallFunction("LAST_DAY", args) => LastDay(args.head) + case CallFunction("LEAD", args) => Lead(args.head, args.lift(1), args.lift(2)) + case CallFunction("LEAST", args) => Least(args) + case CallFunction("LEFT", args) => Left(args.head, args(1)) + case CallFunction("LENGTH", args) => Length(args.head) + case CallFunction("LEVENSHTEIN", args) => Levenshtein(args.head, args(1), args.lift(2)) + case CallFunction("LN", args) => Log(args.head) + case CallFunction("LOG", args) => Logarithm(args.head, args(1)) + case CallFunction("LOG10", args) => Log10(args.head) + case CallFunction("LOG1P", args) => Log1p(args.head) + case CallFunction("LOG2", args) => Log2(args.head) + case CallFunction("LOWER", args) => Lower(args.head) + case CallFunction("LPAD", args) => + StringLPad(args.head, args(1), args.lastOption.getOrElse(Literal(" "))) + case CallFunction("LTRIM", args) => StringTrimLeft(args.head, args.lift(1)) + case CallFunction("MAKE_DATE", args) => MakeDate(args.head, args(1), args(2)) + case CallFunction("MAKE_INTERVAL", args) => + MakeInterval(args.head, args(1), args(2), args(3), args(4), args(5)) + case CallFunction("MAKE_TIMESTAMP", args) => + MakeTimestamp(args.head, args(1), args(2), args(3), args(4), args(5), Some(args(6))) + case CallFunction("MAP", args) => CreateMap(args, useStringTypeWhenEmpty = false) + case CallFunction("MAP_CONCAT", args) => MapConcat(args) + case CallFunction("MAP_ENTRIES", args) => MapEntries(args.head) + case CallFunction("MAP_FILTER", args) => MapFilter(args.head, args(1)) + case CallFunction("MAP_FROM_ARRAYS", args) => MapFromArrays(args.head, args(1)) + case CallFunction("MAP_FROM_ENTRIES", args) => MapFromEntries(args.head) + case CallFunction("MAP_KEYS", args) => MapKeys(args.head) + case CallFunction("MAP_VALUES", args) => MapValues(args.head) + case CallFunction("MAP_ZIP_WITH", args) => MapZipWith(args.head, args(1), args(2)) + case CallFunction("MAX", args) => Max(args.head) + case CallFunction("MAX_BY", args) => MaxBy(args.head, args(1)) + case CallFunction("MD5", args) => Md5(args.head) + case CallFunction("MIN", args) => Min(args.head) + case CallFunction("MINUTE", args) => Minute(args.head) + case CallFunction("MIN_BY", args) => MinBy(args.head, args(1)) + case CallFunction("MOD", args) => Remainder(args.head, 
args(1)) + case CallFunction("MONOTONICALLY_INCREASING_ID", _) => MonotonicallyIncreasingID() + case CallFunction("MONTH", args) => Month(args.head) + case CallFunction("MONTHS_BETWEEN", args) => MonthsBetween(args.head, args(1), args(2)) + case CallFunction("NAMED_STRUCT", args) => CreateNamedStruct(args) + case CallFunction("NANVL", args) => NaNvl(args.head, args(1)) + case CallFunction("NEGATIVE", args) => UnaryMinus(args.head) + case CallFunction("NEXT_DAY", args) => NextDay(args.head, args(1)) + case CallFunction("NOW", _) => Now() + case CallFunction("NTH_VALUE", args) => NthValue(args.head, args(1), args.lift(2)) + case CallFunction("NTILE", args) => NTile(args.head) + case CallFunction("NULLIF", args) => NullIf(args.head, args(1)) + case CallFunction("NVL", args) => Nvl(args.head, args(1)) + case CallFunction("NVL2", args) => Nvl2(args.head, args(1), args(2)) + case CallFunction("OCTET_LENGTH", args) => OctetLength(args.head) + case CallFunction("OVERLAY", args) => Overlay(args.head, args(1), args(2), args(3)) + case CallFunction("PARSE_URL", args) => ParseUrl(args) + case CallFunction("PERCENTILE", args) => Percentile(args.head, args(1), args(2)) + case CallFunction("PERCENTILE_APPROX", args) => ApproximatePercentile(args.head, args(1), args(2)) + case CallFunction("PERCENT_RANK", args) => PercentRank(args) + case CallFunction("PI", _) => Pi() + case CallFunction("PMOD", args) => Pmod(args.head, args(1)) + case CallFunction("POSEXPLODE", args) => PosExplode(args.head) + case CallFunction("POSITION", args) => + StringLocate(args.head, args(1), args.lastOption.getOrElse(Literal(1))) + case CallFunction("POSITIVE", args) => UnaryPositive(args.head) + case CallFunction("POW", args) => Pow(args.head, args(1)) + case CallFunction("POWER", args) => Pow(args.head, args(1)) + case CallFunction("QUARTER", args) => Quarter(args.head) + case CallFunction("RADIANS", args) => ToRadians(args.head) + case CallFunction("RAISE_ERROR", args) => RaiseError(args.head) + case CallFunction("RAND", args) => Rand(args.head) + case CallFunction("RANDN", args) => Randn(args.head) + case CallFunction("RANK", args) => Rank(args) + case CallFunction("REGEXP_EXTRACT", args) => RegExpExtract(args.head, args(1), args.lift(2)) + case CallFunction("REGEXP_EXTRACT_ALL", args) => RegExpExtractAll(args.head, args(1), args.lift(2)) + case CallFunction("REGEXP_REPLACE", args) => RegExpReplace(args.head, args(1), args(2), args.lift(3)) + case CallFunction("REPEAT", args) => StringRepeat(args.head, args(1)) + case CallFunction("REPLACE", args) => StringReplace(args.head, args(1), args(2)) + case CallFunction("REVERSE", args) => Reverse(args.head) + case CallFunction("RIGHT", args) => Right(args.head, args(1)) + case CallFunction("RINT", args) => Rint(args.head) + case CallFunction("RLIKE", args) => RLike(args.head, args(1)) + case CallFunction("ROLLUP", args) => Rollup(args) + case CallFunction("ROUND", args) => Round(args.head, args.lift(1)) + case CallFunction("ROW_NUMBER", _) => RowNumber() + case CallFunction("RPAD", args) => StringRPad(args.head, args(1), args(2)) + case CallFunction("RTRIM", args) => StringTrimRight(args.head, args.lift(1)) + case CallFunction("SCHEMA_OF_CSV", args) => SchemaOfCsv(args.head, args(1)) + case CallFunction("SCHEMA_OF_JSON", args) => SchemaOfJson(args.head, args(1)) + case CallFunction("SECOND", args) => Second(args.head) + case CallFunction("SENTENCES", args) => Sentences(args.head, args(1), args(2)) + case CallFunction("SEQUENCE", args) => Sequence(args.head, args(1), args(2)) + case 
CallFunction("SHA", args) => Sha1(args.head) + case CallFunction("SHA2", args) => Sha2(args.head, args(1)) + case CallFunction("SHIFTLEFT", args) => ShiftLeft(args.head, args(1)) + case CallFunction("SHIFTRIGHT", args) => ShiftRight(args.head, args(1)) + case CallFunction("SHIFTRIGHTUNSIGNED", args) => ShiftRightUnsigned(args.head, args(1)) + case CallFunction("SHUFFLE", args) => Shuffle(args.head) + case CallFunction("SIGN", args) => Signum(args.head) + case CallFunction("SIN", args) => Sin(args.head) + case CallFunction("SINH", args) => Sinh(args.head) + case CallFunction("SIZE", args) => Size(args.head) + case CallFunction("SKEWNESS", args) => Skewness(args.head) + case CallFunction("SLICE", args) => Slice(args.head, args(1), args(2)) + case CallFunction("SORT_ARRAY", args) => SortArray(args.head, args.lift(1)) + case CallFunction("SOUNDEX", args) => SoundEx(args.head) + case CallFunction("SPACE", args) => StringSpace(args.head) + case CallFunction("SPARK_PARTITION_ID", _) => SparkPartitionID() + case CallFunction("SPLIT", args) => + val delim = if (args.size >= 3) Some(args(2)) else None + StringSplit(args.head, args(1), delim) + case CallFunction("SPLIT_PART", args) => StringSplitPart(args.head, args(1), args(2)) + case CallFunction("SQRT", args) => Sqrt(args.head) + case CallFunction("STACK", args) => Stack(args) + case CallFunction("STD", args) => StdSamp(args.head) + case CallFunction("STDDEV", args) => StddevSamp(args.head) + case CallFunction("STDDEV_POP", args) => StddevPop(args.head) + case CallFunction("STR_TO_MAP", args) => StringToMap(args.head, args(1), args(2)) + case CallFunction("SUBSTR", args) => Substring(args.head, args(1), args.lift(2)) + case CallFunction("SUBSTRING_INDEX", args) => SubstringIndex(args.head, args(1), args(2)) + case CallFunction("SUM", args) => Sum(args.head) + case CallFunction("TAN", args) => Tan(args.head) + case CallFunction("TANH", args) => Tanh(args.head) + case CallFunction("TIMESTAMP_MICROS", args) => MicrosToTimestamp(args.head) + case CallFunction("TIMESTAMP_MILLIS", args) => MillisToTimestamp(args.head) + case CallFunction("TIMESTAMP_SECONDS", args) => SecondsToTimestamp(args.head) + case CallFunction("TO_CSV", args) => StructsToCsv(args.head, args(1)) + case CallFunction("TO_DATE", args) => ParseToDate(args.head, args.lift(1)) + case CallFunction("TO_JSON", args) => StructsToJson(args.head, args.lift(1)) + case CallFunction("TO_NUMBER", args) => ToNumber(args.head, args(1)) + case CallFunction("TO_TIMESTAMP", args) => ParseToTimestamp(args.head, args.lift(1)) + case CallFunction("TO_UNIX_TIMESTAMP", args) => ToUnixTimestamp(args.head, args(1)) + case CallFunction("TO_UTC_TIMESTAMP", args) => ToUTCTimestamp(args.head, args(1)) + case CallFunction("TRANSFORM", args) => ArrayTransform(args.head, args(1)) + case CallFunction("TRANSFORM_KEYS", args) => TransformKeys(args.head, args(1)) + case CallFunction("TRANSFORM_VALUES", args) => TransformValues(args.head, args(1)) + case CallFunction("TRANSLATE", args) => StringTranslate(args.head, args(1), args(2)) + case CallFunction("TRIM", args) => StringTrim(args.head, args.lift(1)) + case CallFunction("TRUNC", args) => TruncDate(args.head, args(1)) + case CallFunction("TRY_TO_NUMBER", args) => TryToNumber(args.head, args(1)) + case CallFunction("TYPEOF", args) => TypeOf(args.head) + case CallFunction("UCASE", args) => Upper(args.head) + case CallFunction("UNBASE64", args) => UnBase64(args.head) + case CallFunction("UNHEX", args) => Unhex(args.head) + case CallFunction("UNIX_DATE", args) => 
UnixDate(args.head) + case CallFunction("UNIX_MICROS", args) => UnixMicros(args.head) + case CallFunction("UNIX_MILLIS", args) => UnixMillis(args.head) + case CallFunction("UNIX_SECONDS", args) => UnixSeconds(args.head) + case CallFunction("UNIX_TIMESTAMP", args) => UnixTimestamp(args.head, args(1)) + case CallFunction("UUID", _) => Uuid() + case CallFunction("VAR_POP", args) => VariancePop(args.head) + case CallFunction("VAR_SAMP", args) => VarianceSamp(args.head) + case CallFunction("VERSION", _) => SparkVersion() + case CallFunction("WEEKDAY", args) => WeekDay(args.head) + case CallFunction("WEEKOFYEAR", args) => WeekOfYear(args.head) + case CallFunction("WHEN", _) => + throw new IllegalArgumentException("WHEN (CaseWhen) should be handled separately") + case CallFunction("WIDTH_BUCKET", args) => WidthBucket(args.head, args(1), args(2), args(3)) + case CallFunction("WINDOW", _) => + throw new IllegalArgumentException("WINDOW (TimeWindow) should be handled separately") + case CallFunction("XPATH", args) => XPathList(args.head, args(1)) + case CallFunction("XPATH_BOOLEAN", args) => XPathBoolean(args.head, args(1)) + case CallFunction("XPATH_DOUBLE", args) => XPathDouble(args.head, args(1)) + case CallFunction("XPATH_FLOAT", args) => XPathFloat(args.head, args(1)) + case CallFunction("XPATH_INT", args) => XPathInt(args.head, args(1)) + case CallFunction("XPATH_LONG", args) => XPathLong(args.head, args(1)) + case CallFunction("XPATH_SHORT", args) => XPathShort(args.head, args(1)) + case CallFunction("XPATH_STRING", args) => XPathString(args.head, args(1)) + case CallFunction("XXHASH64", args) => XxHash64(args) + case CallFunction("YEAR", args) => Year(args.head) + case CallFunction("ZIP_WITH", args) => ZipWith(args.head, args(1), args(2)) + case _ => call // fallback + } +} + +/** abs(expr) - Returns the absolute value of the numeric value. */ +case class Abs(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ABS" + override def dataType: DataType = left.dataType +} + +/** acos(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by `java.lang.Math.acos`. */ +case class Acos(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ACOS" + override def dataType: DataType = left.dataType +} + +/** acosh(expr) - Returns inverse hyperbolic cosine of `expr`. */ +case class Acosh(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ACOSH" + override def dataType: DataType = left.dataType +} + +/** + * aggregate(expr, start, merge, finish) - Applies a binary operator to an initial state and all elements in the array, + * and reduces this to a single state. The final state is converted into the final result by applying a finish function. + */ +case class ArrayAggregate(left: Expression, right: Expression, merge: Expression, finish: Expression) + extends Expression + with Fn { + override def prettyName: String = "AGGREGATE" + override def children: Seq[Expression] = Seq(left, right, merge, finish) + override def dataType: DataType = ArrayType(right.dataType) +} + +/** array(expr, ...) - Returns an array with the given elements. 
*/ +case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolean = false) extends Expression with Fn { + override def prettyName: String = "ARRAY" + override def dataType: DataType = ArrayType( + children.headOption + .map(_.dataType) + .getOrElse(if (useStringTypeWhenEmpty) StringType else NullType)) +} + +/** + * array_sort(expr, func) - Sorts the input array. If func is omitted, sort in ascending order. The elements of the + * input array must be orderable. Null elements will be placed at the end of the returned array. Since 3.0.0 this + * function also sorts and returns the array based on the given comparator function. The comparator will take two + * arguments representing two elements of the array. It returns -1, 0, or 1 as the first element is less than, equal to, + * or greater than the second element. If the comparator function returns other values (including null), the function + * will fail and raise an error. + */ +case class ArraySort(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_SORT" + override def dataType: DataType = left.dataType +} + +/** ascii(str) - Returns the numeric value of the first character of `str`. */ +case class Ascii(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ASCII" + override def dataType: DataType = LongType +} + +/** + * asin(expr) - Returns the inverse sine (a.k.a. arc sine) the arc sin of `expr`, as if computed by + * `java.lang.Math.asin`. + */ +case class Asin(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ASIN" + override def dataType: DataType = DoubleType +} + +/** asinh(expr) - Returns inverse hyperbolic sine of `expr`. */ +case class Asinh(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ASINH" + override def dataType: DataType = DoubleType +} + +/** assert_true(expr) - Throws an exception if `expr` is not true. */ +case class AssertTrue(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ASSERT_TRUE" + override def dataType: DataType = UnresolvedType +} + +/** + * atan(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by `java.lang.Math.atan` + */ +case class Atan(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ATAN" + override def dataType: DataType = DoubleType +} + +/** + * atan2(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane and the point given by the + * coordinates (`exprX`, `exprY`), as if computed by `java.lang.Math.atan2`. + */ +case class Atan2(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ATAN2" + override def dataType: DataType = DoubleType +} + +/** atanh(expr) - Returns inverse hyperbolic tangent of `expr`. */ +case class Atanh(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ATANH" + override def dataType: DataType = DoubleType +} + +/** base64(bin) - Converts the argument from a binary `bin` to a base64 string. */ +case class Base64(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BASE64" + override def dataType: DataType = StringType +} + +/** bin(expr) - Returns the string representation of the long value `expr` represented in binary. 
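+ * For example, `bin(13)` evaluates to the string `'1101'`.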
*/ +case class Bin(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BIN" + override def dataType: DataType = StringType +} + +/** + * bit_count(expr) - Returns the number of bits that are set in the argument expr as an unsigned 64-bit integer, or NULL + * if the argument is NULL. + */ +case class BitwiseCount(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BIT_COUNT" + override def dataType: DataType = LongType +} + +/** bit_length(expr) - Returns the bit length of string data or number of bits of binary data. */ +case class BitLength(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BIT_LENGTH" + override def dataType: DataType = LongType +} + +/** bit_get(expr, bit) and getbit(expr, bit) - Returns the bit value at position `bit` in `expr`. */ +case class BitwiseGet(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "GETBIT" + override def dataType: DataType = UnresolvedType +} + +/** bround(expr, d) - Returns `expr` rounded to `d` decimal places using HALF_EVEN rounding mode. */ +case class BRound(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "BROUND" + override def dataType: DataType = DoubleType +} + +/** + * cardinality(expr) - Returns the size of an array or a map. The function returns null for null input if + * spark.sql.legacy.sizeOfNull is set to false or spark.sql.ansi.enabled is set to true. Otherwise, the function returns + * -1 for null input. With the default settings, the function returns -1 for null input. + */ +case class Size(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SIZE" + override def dataType: DataType = UnresolvedType +} + +/** cbrt(expr) - Returns the cube root of `expr`. */ +case class Cbrt(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "CBRT" + override def dataType: DataType = DoubleType +} + +/** ceil(expr) - Returns the smallest integer not smaller than `expr`. */ +case class Ceil(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "CEIL" + override def dataType: DataType = LongType +} + +/** + * char(expr) - Returns the ASCII character having the binary equivalent to `expr`. If n is larger than 256 the result + * is equivalent to chr(n % 256). + */ +case class Chr(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "CHAR" + override def dataType: DataType = StringType +} + +/** + * char_length(expr) - Returns the character length of string data or number of bytes of binary data. The length of + * string data includes the trailing spaces. The length of binary data includes binary zeros. + */ +case class Length(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "LENGTH" + override def dataType: DataType = LongType +} + +/** coalesce(expr1, expr2, ...) - Returns the first non-null argument if one exists; otherwise, null. */ +case class Coalesce(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "COALESCE" + override def dataType: DataType = UnresolvedType +} + +/** concat_ws(sep[, str | array(str)]+) - Returns the concatenation of the strings separated by `sep`.
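+ * For example, `concat_ws('-', 'a', 'b')` evaluates to `'a-b'`.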
*/ +case class ConcatWs(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "CONCAT_WS" + override def dataType: DataType = StringType +} + +/** conv(num, from_base, to_base) - Convert `num` from `from_base` to `to_base`. */ +case class Conv(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "CONV" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** cos(expr) - Returns the cosine of `expr`, as if computed by `java.lang.Math.cos`. */ +case class Cos(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "COS" + override def dataType: DataType = DoubleType +} + +/** + * cosh(expr) - Returns the hyperbolic cosine of `expr`, as if computed by `java.lang.Math.cosh`. + */ +case class Cosh(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "COSH" + override def dataType: DataType = DoubleType +} + +/** cot(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math.tan`. */ +case class Cot(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "COT" + override def dataType: DataType = DoubleType +} + +/** crc32(expr) - Returns a cyclic redundancy check value of the `expr` as a bigint. */ +case class Crc32(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "CRC32" + override def dataType: DataType = UnresolvedType +} + +/** + * cube([col1[, col2 ..]]) - create a multi-dimensional cube using the specified columns so that we can run aggregation + * on them. + */ +case class Cube(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "CUBE" + override def dataType: DataType = UnresolvedType +} + +/** current_catalog() - Returns the current catalog. */ +case class CurrentCatalog() extends LeafExpression with Fn { + override def prettyName: String = "CURRENT_CATALOG" + override def dataType: DataType = StringType +} + +/** current_database() - Returns the current database. */ +case class CurrentDatabase() extends LeafExpression with Fn { + override def prettyName: String = "CURRENT_DATABASE" + override def dataType: DataType = StringType +} + +/** dayofmonth(date) - Returns the day of month of the date/timestamp. */ +case class DayOfMonth(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "DAYOFMONTH" + override def dataType: DataType = IntegerType +} + +/** decode(bin, charset) - Decodes the first argument using the second argument character set. */ +case class Decode(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DECODE" + override def dataType: DataType = BinaryType +} + +/** degrees(expr) - Converts radians to degrees. */ +case class ToDegrees(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "DEGREES" + override def dataType: DataType = DoubleType +} + +/** + * expr1 div expr2 - Divide `expr1` by `expr2`. It returns NULL if an operand is NULL or `expr2` is 0. The result is + * cast to long. + */ +case class IntegralDivide(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DIV" + override def dataType: DataType = LongType +} + +/** e() - Returns Euler's number, e.
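+ * For example, `e()` evaluates to 2.718281828459045.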
*/ +case class EulerNumber() extends LeafExpression with Fn { + override def prettyName: String = "E" + override def dataType: DataType = UnresolvedType +} + +/** + * element_at(array, index) - Returns element of array at given (1-based) index. If index < 0, accesses elements from + * the last to the first. The function returns NULL if the index exceeds the length of the array and + * `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true, it throws + * ArrayIndexOutOfBoundsException for invalid indices. + * + * element_at(map, key) - Returns value for given key. The function returns NULL if the key is not contained in the map + * and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true, it throws + * NoSuchElementException instead. + */ +case class ElementAt(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ELEMENT_AT" + override def dataType: DataType = UnresolvedType +} + +/** + * elt(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2. The function returns + * NULL if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false. If + * `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException for invalid indices. + */ +case class Elt(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "ELT" + override def dataType: DataType = UnresolvedType +} + +/** encode(str, charset) - Encodes the first argument using the second argument character set. */ +case class Encode(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ENCODE" + override def dataType: DataType = UnresolvedType +} + +/** exists(expr, pred) - Tests whether a predicate holds for one or more elements in the array. */ +case class ArrayExists(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "EXISTS" + override def dataType: DataType = UnresolvedType +} + +/** exp(expr) - Returns e to the power of `expr`. */ +case class Exp(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "EXP" + override def dataType: DataType = UnresolvedType +} + +/** + * explode(expr) - Separates the elements of array `expr` into multiple rows, or the elements of map `expr` into + * multiple rows and columns. Unless specified otherwise, uses the default column name `col` for elements of the array + * or `key` and `value` for the elements of the map. + */ +case class Explode(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "EXPLODE" + override def dataType: DataType = UnresolvedType +} + +/** + * variant_explode(expr) - Separates the elements of a variant type `expr` into multiple rows. This function is used + * specifically for handling variant data types, ensuring that each element is processed and outputted as a separate row. + * The default column name for the elements is `col`. + */ +case class VariantExplode(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "VARIANT_EXPLODE" + override def dataType: DataType = UnresolvedType +} + +/** + * variant_explode_outer(expr) - Separates the elements of a variant type `expr` into multiple rows, + * including null values. 
This function is used specifically for handling variant data types, ensuring that + * each element, including nulls, is processed and outputted as a separate row. The default column name for the + * elements is `col`. + */ +case class VariantExplodeOuter(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "VARIANT_EXPLODE_OUTER" + override def dataType: DataType = UnresolvedType +} + +/** expm1(expr) - Returns exp(`expr`) - 1. */ +case class Expm1(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "EXPM1" + override def dataType: DataType = UnresolvedType +} + +/** extract(field FROM source) - Extracts a part of the date/timestamp or interval source. */ +case class Extract(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "EXTRACT" + override def dataType: DataType = UnresolvedType +} + +/** factorial(expr) - Returns the factorial of `expr`. `expr` is [0..20]. Otherwise, null. */ +case class Factorial(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "FACTORIAL" + override def dataType: DataType = UnresolvedType +} + +/** filter(expr, func) - Filters the input array using the given predicate. */ +case class ArrayFilter(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "FILTER" + override def dataType: DataType = UnresolvedType +} + +/** + * find_in_set(str, str_array) - Returns the index (1-based) of the given string (`str`) in the comma-delimited list + * (`str_array`). Returns 0, if the string was not found or if the given string (`str`) contains a comma. + */ +case class FindInSet(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "FIND_IN_SET" + override def dataType: DataType = UnresolvedType +} + +/** floor(expr) - Returns the largest integer not greater than `expr`. */ +case class Floor(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "FLOOR" + override def dataType: DataType = UnresolvedType +} + +/** forall(expr, pred) - Tests whether a predicate holds for all elements in the array. */ +case class ArrayForAll(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "FORALL" + override def dataType: DataType = UnresolvedType +} + +/** + * format_number(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` decimal places. If + * `expr2` is 0, the result has no decimal point or fractional part. `expr2` also accept a user specified format. This + * is supposed to function like MySQL's FORMAT. + */ +case class FormatNumber(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "FORMAT_NUMBER" + override def dataType: DataType = UnresolvedType +} + +/** format_string(strfmt, obj, ...) - Returns a formatted string from printf-style format strings. */ +case class FormatString(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "FORMAT_STRING" + override def dataType: DataType = UnresolvedType +} + +/** from_csv(csvStr, schema[, options]) - Returns a struct value with the given `csvStr` and `schema`. 
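+ * For example, `from_csv('1, 0.8', 'a INT, b DOUBLE')` parses the CSV string into a struct with an integer field `a` + * and a double field `b`.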
*/ +case class CsvToStructs(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "FROM_CSV" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** greatest(expr, ...) - Returns the greatest value of all parameters, skipping null values. */ +case class Greatest(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "GREATEST" + override def dataType: DataType = UnresolvedType +} + +/** + * grouping(col) - indicates whether a specified column in a GROUP BY is aggregated or not, returns 1 for aggregated or + * 0 for not aggregated in the result set.", + */ +case class Grouping(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "GROUPING" + override def dataType: DataType = UnresolvedType +} + +/** + * grouping_id([col1[, col2 ..]]) - returns the level of grouping, equals to `(grouping(c1) << (n-1)) + (grouping(c2) << + * (n-2)) + ... + grouping(cn)` + */ +case class GroupingID(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "GROUPING_ID" + override def dataType: DataType = UnresolvedType +} + +/** hash(expr1, expr2, ...) - Returns a hash value of the arguments. */ +case class Murmur3Hash(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "HASH" + override def dataType: DataType = UnresolvedType +} + +/** hex(expr) - Converts `expr` to hexadecimal. */ +case class Hex(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "HEX" + override def dataType: DataType = UnresolvedType +} + +/** hypot(expr1, expr2) - Returns sqrt(`expr1`**2 + `expr2`**2). */ +case class Hypot(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "HYPOT" + override def dataType: DataType = UnresolvedType +} + +/** if(expr1, expr2, expr3) - If `expr1` evaluates to true, then returns `expr2`; otherwise returns `expr3`. */ +case class If(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "IF" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** ifnull(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise. */ +case class IfNull(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "IFNULL" + override def dataType: DataType = UnresolvedType +} + +/** expr1 in(expr2, expr3, ...) - Returns true if `expr` equals to any valN. */ +case class In(left: Expression, other: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "IN" + override def children: Seq[Expression] = left +: other + override def dataType: DataType = UnresolvedType +} + +/** + * initcap(str) - Returns `str` with the first letter of each word in uppercase. All other letters are in lowercase. + * Words are delimited by white space. + */ +case class InitCap(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "INITCAP" + override def dataType: DataType = UnresolvedType +} + +/** + * inline(expr) - Explodes an array of structs into a table. Uses column names col1, col2, etc. by default unless + * specified otherwise. 
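+ * For example, `inline(array(struct(1, 'a'), struct(2, 'b')))` produces two rows, `(1, 'a')` and `(2, 'b')`, with + * columns `col1` and `col2`.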
+ */ +case class Inline(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "INLINE" + override def dataType: DataType = UnresolvedType +} + +/** input_file_block_length() - Returns the length of the block being read, or -1 if not available. */ +case class InputFileBlockLength() extends LeafExpression with Fn { + override def prettyName: String = "INPUT_FILE_BLOCK_LENGTH" + override def dataType: DataType = UnresolvedType +} + +/** input_file_block_start() - Returns the start offset of the block being read, or -1 if not available. */ +case class InputFileBlockStart() extends LeafExpression with Fn { + override def prettyName: String = "INPUT_FILE_BLOCK_START" + override def dataType: DataType = UnresolvedType +} + +/** input_file_name() - Returns the name of the file being read, or empty string if not available. */ +case class InputFileName() extends LeafExpression with Fn { + override def prettyName: String = "INPUT_FILE_NAME" + override def dataType: DataType = UnresolvedType +} + +/** instr(str, substr) - Returns the (1-based) index of the first occurrence of `substr` in `str`. */ +case class StringInstr(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "INSTR" + override def dataType: DataType = UnresolvedType +} + +/** isnan(expr) - Returns true if `expr` is NaN, or false otherwise. */ +case class IsNaN(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ISNAN" + override def dataType: DataType = UnresolvedType +} + +/** java_method(class, method[, arg1[, arg2 ..]]) - Calls a method with reflection. */ +case class CallMethodViaReflection(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "JAVA_METHOD" + override def dataType: DataType = UnresolvedType +} + +/** lcase(str) - Returns `str` with all characters changed to lowercase. */ +case class Lower(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "LOWER" + override def dataType: DataType = UnresolvedType +} + +/** least(expr, ...) - Returns the least value of all parameters, skipping null values. */ +case class Least(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "LEAST" + override def dataType: DataType = UnresolvedType +} + +/** + * left(str, len) - Returns the leftmost `len` (`len` can be string type) characters from the string `str`; if `len` is + * less than or equal to 0 the result is an empty string. + */ +case class Left(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "LEFT" + override def dataType: DataType = UnresolvedType +} + +/** levenshtein(str1, str2) - Returns the Levenshtein distance between the two given strings. */ +case class Levenshtein(left: Expression, right: Expression, maxDistance: Option[Expression]) + extends Expression + with Fn { + override def prettyName: String = "LEVENSHTEIN" + override def children: Seq[Expression] = Seq(left, right) ++ maxDistance.toSeq + override def dataType: DataType = UnresolvedType +} + +/** ln(expr) - Returns the natural logarithm (base e) of `expr`. */ +case class Log(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "LN" + override def dataType: DataType = UnresolvedType +} + +/** + * locate(substr, str[, pos]) - Returns the position of the first occurrence of `substr` in `str` after position `pos`.
+ * The given `pos` and return value are 1-based. + */ +case class StringLocate(left: Expression, right: Expression, c: Expression = Literal(1)) extends Expression with Fn { + override def prettyName: String = "POSITION" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** log(base, expr) - Returns the logarithm of `expr` with `base`. */ +case class Logarithm(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "LOG" + override def dataType: DataType = UnresolvedType +} + +/** log10(expr) - Returns the logarithm of `expr` with base 10. */ +case class Log10(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "LOG10" + override def dataType: DataType = UnresolvedType +} + +/** log1p(expr) - Returns log(1 + `expr`). */ +case class Log1p(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "LOG1P" + override def dataType: DataType = UnresolvedType +} + +/** log2(expr) - Returns the logarithm of `expr` with base 2. */ +case class Log2(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "LOG2" + override def dataType: DataType = UnresolvedType +} + +/** + * lpad(str, len[, pad]) - Returns `str`, left-padded with `pad` to a length of `len`. If `str` is longer than `len`, + * the return value is shortened to `len` characters. If `pad` is not specified, `str` will be padded to the left with + * space characters. + */ +case class StringLPad(left: Expression, right: Expression, pad: Expression = Literal(" ")) extends Expression with Fn { + override def prettyName: String = "LPAD" + override def children: Seq[Expression] = Seq(left, right, pad) + override def dataType: DataType = UnresolvedType +} + +/** ltrim(str) - Removes the leading space characters from `str`. */ +case class StringTrimLeft(left: Expression, right: Option[Expression]) extends Expression with Fn { + override def prettyName: String = "LTRIM" + override def children: Seq[Expression] = Seq(left) ++ right + override def dataType: DataType = UnresolvedType +} + +/** + * make_interval(years, months, weeks, days, hours, mins, secs) - Make interval from years, months, weeks, days, hours, + * mins and secs. + */ +case class MakeInterval( + years: Expression, + months: Expression, + weeks: Expression, + hours: Expression, + mins: Expression, + secs: Expression) + extends Expression + with Fn { + override def prettyName: String = "MAKE_INTERVAL" + override def children: Seq[Expression] = Seq(years, months, weeks, hours, mins, secs) + override def dataType: DataType = UnresolvedType +} + +/** map(key0, value0, key1, value1, ...) - Creates a map with the given key/value pairs. */ +case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) extends Expression with Fn { + override def prettyName: String = "MAP" + override def dataType: DataType = UnresolvedType +} + +/** map_filter(expr, func) - Filters entries in a map using the function. */ +case class MapFilter(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "MAP_FILTER" + override def dataType: DataType = UnresolvedType +} + +/** + * map_from_arrays(keys, values) - Creates a map with a pair of the given key/value arrays. 
All elements in keys should + * not be null + */ +case class MapFromArrays(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "MAP_FROM_ARRAYS" + override def dataType: DataType = UnresolvedType +} + +/** + * map_zip_with(map1, map2, function) - Merges two given maps into a single map by applying function to the pair of + * values with the same key. For keys only presented in one map, NULL will be passed as the value for the missing key. + * If an input map contains duplicated keys, only the first entry of the duplicated key is passed into the lambda + * function. + */ +case class MapZipWith(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "MAP_ZIP_WITH" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** md5(expr) - Returns an MD5 128-bit checksum as a hex string of `expr`. */ +case class Md5(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MD5" + override def dataType: DataType = UnresolvedType +} + +/** expr1 mod expr2 - Returns the remainder after `expr1`/`expr2`. */ +case class Remainder(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "MOD" + override def dataType: DataType = left.dataType +} + +/** + * monotonically_increasing_id() - Returns monotonically increasing 64-bit integers. The generated ID is guaranteed to + * be monotonically increasing and unique, but not consecutive. The current implementation puts the partition ID in the + * upper 31 bits, and the lower 33 bits represent the record number within each partition. The assumption is that the + * data frame has less than 1 billion partitions, and each partition has less than 8 billion records. The function is + * non-deterministic because its result depends on partition IDs. + */ +case class MonotonicallyIncreasingID() extends LeafExpression with Fn { + override def prettyName: String = "MONOTONICALLY_INCREASING_ID" + override def dataType: DataType = UnresolvedType +} + +/** named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values. */ +case class CreateNamedStruct(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "NAMED_STRUCT" + override def dataType: DataType = UnresolvedType +} + +/** nanvl(expr1, expr2) - Returns `expr1` if it's not NaN, or `expr2` otherwise. */ +case class NaNvl(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "NANVL" + override def dataType: DataType = UnresolvedType +} + +/** negative(expr) - Returns the negated value of `expr`. */ +case class UnaryMinus(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "NEGATIVE" + override def dataType: DataType = UnresolvedType +} + +/** nullif(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise. */ +case class NullIf(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "NULLIF" + override def dataType: DataType = UnresolvedType +} + +/** nvl(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise. 
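+ * For example, `nvl(NULL, 2)` evaluates to 2, and `nvl(1, 2)` evaluates to 1.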
*/ +case class Nvl(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "NVL" + override def dataType: DataType = UnresolvedType +} + +/** nvl2(expr1, expr2, expr3) - Returns `expr2` if `expr1` is not null, or `expr3` otherwise. */ +case class Nvl2(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "NVL2" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** octet_length(expr) - Returns the byte length of string data or number of bytes of binary data. */ +case class OctetLength(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "OCTET_LENGTH" + override def dataType: DataType = UnresolvedType +} + +/** overlay(input, replace, pos[, len]) - Replace `input` with `replace` that starts at `pos` and is of length `len`. */ +case class Overlay(left: Expression, right: Expression, c: Expression, d: Expression) extends Expression with Fn { + override def prettyName: String = "OVERLAY" + override def children: Seq[Expression] = Seq(left, right, c, d) + override def dataType: DataType = UnresolvedType +} + +/** parse_url(url, partToExtract[, key]) - Extracts a part from a URL. */ +case class ParseUrl(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "PARSE_URL" + override def dataType: DataType = UnresolvedType +} + +/** pi() - Returns pi. */ +case class Pi() extends LeafExpression with Fn { + override def prettyName: String = "PI" + override def dataType: DataType = UnresolvedType +} + +/** pmod(expr1, expr2) - Returns the positive value of `expr1` mod `expr2`. */ +case class Pmod(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "PMOD" + override def dataType: DataType = UnresolvedType +} + +/** + * posexplode(expr) - Separates the elements of array `expr` into multiple rows with positions, or the elements of map + * `expr` into multiple rows and columns with positions. Unless specified otherwise, uses the column name `pos` for + * position, `col` for elements of the array or `key` and `value` for elements of the map. + */ +case class PosExplode(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "POSEXPLODE" + override def dataType: DataType = UnresolvedType +} + +/** positive(expr) - Returns the value of `expr`. */ +case class UnaryPositive(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "POSITIVE" + override def dataType: DataType = UnresolvedType +} + +/** + * pow(expr1, expr2) - Raises `expr1` to the power of `expr2`. + * @see + * https://docs.databricks.com/en/sql/language-manual/functions/pow.html + */ +case class Pow(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "POWER" // alias: POW + override def dataType: DataType = UnresolvedType +} + +/** radians(expr) - Converts degrees to radians. */ +case class ToRadians(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "RADIANS" + override def dataType: DataType = UnresolvedType +} + +/** raise_error(expr) - Throws an exception with `expr`. 
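+ * For example, `raise_error('custom error message')` fails the query with an error carrying that message.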
*/ +case class RaiseError(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "RAISE_ERROR" + override def dataType: DataType = UnresolvedType +} + +/** + * rand([seed]) - Returns a random value with independent and identically distributed (i.i.d.) uniformly distributed + * values in [0, 1). + */ +case class Rand(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "RAND" + override def dataType: DataType = UnresolvedType +} + +/** + * randn([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the + * standard normal distribution. + */ +case class Randn(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "RANDN" + override def dataType: DataType = UnresolvedType +} + +/** + * regexp_extract(str, regexp[, idx]) - Extracts the first string in `str` that matches the `regexp` expression and + * corresponds to the regex group index. + */ +case class RegExpExtract(left: Expression, right: Expression, c: Option[Expression] = None) extends Expression with Fn { + override def prettyName: String = "REGEXP_EXTRACT" + override def children: Seq[Expression] = Seq(left, right) ++ c.toSeq + override def dataType: DataType = UnresolvedType +} + +/** + * regexp_extract_all(str, regexp[, idx]) - Extracts all strings in `str` that match the `regexp` expression and + * correspond to the regex group index. + */ +case class RegExpExtractAll(left: Expression, right: Expression, c: Option[Expression] = None) + extends Expression + with Fn { + override def prettyName: String = "REGEXP_EXTRACT_ALL" + override def children: Seq[Expression] = Seq(left, right) ++ c.toSeq + override def dataType: DataType = UnresolvedType +} + +/** regexp_replace(str, regexp, rep[, position]) - Replaces all substrings of `str` that match `regexp` with `rep`. */ +case class RegExpReplace(left: Expression, right: Expression, c: Expression, d: Option[Expression]) + extends Expression + with Fn { + override def prettyName: String = "REGEXP_REPLACE" + override def children: Seq[Expression] = Seq(left, right, c) ++ d + override def dataType: DataType = UnresolvedType +} + +/** repeat(str, n) - Returns the string which repeats the given string value n times. */ +case class StringRepeat(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "REPEAT" + override def dataType: DataType = UnresolvedType +} + +/** replace(str, search[, replace]) - Replaces all occurrences of `search` with `replace`. */ +case class StringReplace(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "REPLACE" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** + * right(str, len) - Returns the rightmost `len` (`len` can be string type) characters from the string `str`; if `len` is + * less than or equal to 0 the result is an empty string. + */ +case class Right(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "RIGHT" + override def dataType: DataType = UnresolvedType +} + +/** + * rint(expr) - Returns the double value that is closest in value to the argument and is equal to a mathematical + * integer.
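+ * For example, `rint(12.3456)` evaluates to 12.0.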
+ */ +case class Rint(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "RINT" + override def dataType: DataType = UnresolvedType +} + +/** + * rollup([col1[, col2 ..]]) - create a multi-dimensional rollup using the specified columns so that we can run + * aggregation on them. + */ +case class Rollup(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "ROLLUP" + override def dataType: DataType = UnresolvedType +} + +/** round(expr, d) - Returns `expr` rounded to `d` decimal places using HALF_UP rounding mode. */ +case class Round(left: Expression, right: Option[Expression]) extends Expression with Fn { + override def children: Seq[Expression] = Seq(left) ++ right + override def prettyName: String = "ROUND" + override def dataType: DataType = UnresolvedType +} + +/** + * rpad(str, len[, pad]) - Returns `str`, right-padded with `pad` to a length of `len`. If `str` is longer than `len`, + * the return value is shortened to `len` characters. If `pad` is not specified, `str` will be padded to the right with + * space characters. + */ +case class StringRPad(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "RPAD" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** rtrim(str) - Removes the trailing space characters from `str`. */ +case class StringTrimRight(left: Expression, right: Option[Expression]) extends Expression with Fn { + override def children: Seq[Expression] = Seq(left) ++ right + override def prettyName: String = "RTRIM" + override def dataType: DataType = UnresolvedType +} + +/** schema_of_csv(csv[, options]) - Returns schema in the DDL format of CSV string. */ +case class SchemaOfCsv(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "SCHEMA_OF_CSV" + override def dataType: DataType = UnresolvedType +} + +/** sentences(str[, lang, country]) - Splits `str` into an array of array of words. */ +case class Sentences(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "SENTENCES" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** sha(expr) - Returns a sha1 hash value as a hex string of the `expr`. */ +case class Sha1(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SHA" + override def dataType: DataType = UnresolvedType +} + +/** + * sha2(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of `expr`. SHA-224, SHA-256, SHA-384, and + * SHA-512 are supported. Bit length of 0 is equivalent to 256. + */ +case class Sha2(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "SHA2" + override def dataType: DataType = UnresolvedType +} + +/** shiftleft(base, expr) - Bitwise left shift. */ +case class ShiftLeft(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "SHIFTLEFT" + override def dataType: DataType = UnresolvedType +} + +/** shiftright(base, expr) - Bitwise (signed) right shift. 
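+ * For example, `shiftright(4, 1)` evaluates to 2.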
*/ +case class ShiftRight(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "SHIFTRIGHT" + override def dataType: DataType = UnresolvedType +} + +/** shiftrightunsigned(base, expr) - Bitwise unsigned right shift. */ +case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "SHIFTRIGHTUNSIGNED" + override def dataType: DataType = UnresolvedType +} + +/** sign(expr) - Returns -1.0, 0.0 or 1.0 as `expr` is negative, 0 or positive. */ +case class Signum(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SIGN" + override def dataType: DataType = UnresolvedType +} + +/** sin(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math.sin`. */ +case class Sin(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SIN" + override def dataType: DataType = UnresolvedType +} + +/** sinh(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math.sinh`. */ +case class Sinh(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SINH" + override def dataType: DataType = UnresolvedType +} + +/** soundex(str) - Returns Soundex code of the string. */ +case class SoundEx(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SOUNDEX" + override def dataType: DataType = UnresolvedType +} + +/** space(n) - Returns a string consisting of `n` spaces. */ +case class StringSpace(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SPACE" + override def dataType: DataType = UnresolvedType +} + +/** spark_partition_id() - Returns the current partition id. */ +case class SparkPartitionID() extends LeafExpression with Fn { + override def prettyName: String = "SPARK_PARTITION_ID" + override def dataType: DataType = UnresolvedType +} + +/** + * split(str, regex, limit) - Splits `str` around occurrences that match `regex` and returns an array with a length of + * at most `limit` + */ +case class StringSplit(left: Expression, right: Expression, c: Option[Expression]) extends Expression with Fn { + override def prettyName: String = "SPLIT" + override def children: Seq[Expression] = Seq(left, right) ++ c.toSeq + override def dataType: DataType = UnresolvedType +} + +/** + * split_part(str, delim, partNum) - Splits str around occurrences of delim and returns the partNum part. + */ +case class StringSplitPart(str: Expression, delim: Expression, partNum: Expression) extends Expression with Fn { + override def prettyName: String = "SPLIT_PART" + override def children: Seq[Expression] = Seq(str, delim, partNum) + override def dataType: DataType = StringType +} + +/** sqrt(expr) - Returns the square root of `expr`. */ +case class Sqrt(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SQRT" + override def dataType: DataType = UnresolvedType +} + +/** + * stack(n, expr1, ..., exprk) - Separates `expr1`, ..., `exprk` into `n` rows. Uses column names col0, col1, etc. by + * default unless specified otherwise. + */ +case class Stack(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "STACK" + override def dataType: DataType = UnresolvedType +} + +/** + * str_to_map(text[, pairDelim[, keyValueDelim]]) - Creates a map after splitting the text into key/value pairs using + * delimiters. 
Default delimiters are ',' for `pairDelim` and ':' for `keyValueDelim`. Both `pairDelim` and + * `keyValueDelim` are treated as regular expressions. + */ +case class StringToMap(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "STR_TO_MAP" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** + * substr(str, pos[, len]) - Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of + * byte array that starts at `pos` and is of length `len`. + * + * substr(str FROM pos[ FOR len]]) - Returns the substring of `str` that starts at `pos` and is of length `len`, or the + * slice of byte array that starts at `pos` and is of length `len`. + */ +case class Substring(str: Expression, pos: Expression, len: Option[Expression] = None) extends Expression with Fn { + override def prettyName: String = "SUBSTR" + override def children: Seq[Expression] = Seq(str, pos) ++ len.toSeq + override def dataType: DataType = UnresolvedType +} + +/** + * substring_index(str, delim, count) - Returns the substring from `str` before `count` occurrences of the delimiter + * `delim`. If `count` is positive, everything to the left of the final delimiter (counting from the left) is returned. + * If `count` is negative, everything to the right of the final delimiter (counting from the right) is returned. The + * function substring_index performs a case-sensitive match when searching for `delim`. + */ +case class SubstringIndex(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "SUBSTRING_INDEX" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** tan(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math.tan`. */ +case class Tan(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "TAN" + override def dataType: DataType = UnresolvedType +} + +/** + * tanh(expr) - Returns the hyperbolic tangent of `expr`, as if computed by `java.lang.Math.tanh`. + */ +case class Tanh(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "TANH" + override def dataType: DataType = UnresolvedType +} + +/** to_csv(expr[, options]) - Returns a CSV string with a given struct value */ +case class StructsToCsv(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "TO_CSV" + override def dataType: DataType = UnresolvedType +} + +/** transform(expr, func) - Transforms elements in an array using the function. */ +case class ArrayTransform(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "TRANSFORM" + override def dataType: DataType = UnresolvedType +} + +/** transform_keys(expr, func) - Transforms elements in a map using the function. */ +case class TransformKeys(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "TRANSFORM_KEYS" + override def dataType: DataType = UnresolvedType +} + +/** transform_values(expr, func) - Transforms values in the map using the function. 
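+ * For example, `transform_values(map(1, 1, 2, 2), (k, v) -> v + 1)` increments every value, yielding the map + * `{1 -> 2, 2 -> 3}`.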
*/ +case class TransformValues(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "TRANSFORM_VALUES" + override def dataType: DataType = UnresolvedType +} + +/** + * translate(input, from, to) - Translates the `input` string by replacing the characters present in the `from` string + * with the corresponding characters in the `to` string. + */ +case class StringTranslate(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "TRANSLATE" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** + * trim(str) - Removes the leading and trailing space characters from `str`. + * + * trim(BOTH FROM str) - Removes the leading and trailing space characters from `str`. + * + * trim(LEADING FROM str) - Removes the leading space characters from `str`. + * + * trim(TRAILING FROM str) - Removes the trailing space characters from `str`. + * + * trim(trimStr FROM str) - Remove the leading and trailing `trimStr` characters from `str`. + * + * trim(BOTH trimStr FROM str) - Remove the leading and trailing `trimStr` characters from `str`. + * + * trim(LEADING trimStr FROM str) - Remove the leading `trimStr` characters from `str`. + * + * trim(TRAILING trimStr FROM str) - Remove the trailing `trimStr` characters from `str`. + */ +case class StringTrim(left: Expression, right: Option[Expression]) extends Expression with Fn { + override def children: Seq[Expression] = Seq(left) ++ right + override def prettyName: String = "TRIM" + override def dataType: DataType = UnresolvedType +} + +/** typeof(expr) - Return DDL-formatted type string for the data type of the input. */ +case class TypeOf(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "TYPEOF" + override def dataType: DataType = UnresolvedType +} + +/** ucase(str) - Returns `str` with all characters changed to uppercase. */ +case class Upper(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "UCASE" + override def dataType: DataType = UnresolvedType +} + +/** unbase64(str) - Converts the argument from a base 64 string `str` to a binary. */ +case class UnBase64(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "UNBASE64" + override def dataType: DataType = UnresolvedType +} + +/** unhex(expr) - Converts hexadecimal `expr` to binary. */ +case class Unhex(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "UNHEX" + override def dataType: DataType = UnresolvedType +} + +/** + * uuid() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID + * 36-character string. + */ +case class Uuid() extends LeafExpression with Fn { + override def prettyName: String = "UUID" + override def dataType: DataType = UnresolvedType +} + +/** + * version() - Returns the Spark version. The string contains 2 fields, the first being a release version and the second + * being a git revision. + */ +case class SparkVersion() extends LeafExpression with Fn { + override def prettyName: String = "VERSION" + override def dataType: DataType = UnresolvedType +} + +/** + * width_bucket(value, min_value, max_value, num_bucket) - Returns the bucket number to which `value` would be assigned + * in an equiwidth histogram with `num_bucket` buckets, in the range `min_value` to `max_value`." 
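+ * For example, `width_bucket(5.3, 0.2, 10.6, 5)` evaluates to 3.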
+ */ +case class WidthBucket(left: Expression, right: Expression, c: Expression, d: Expression) extends Expression with Fn { + override def prettyName: String = "WIDTH_BUCKET" + override def children: Seq[Expression] = Seq(left, right, c, d) + override def dataType: DataType = UnresolvedType +} + +/** N/A. */ +case class TimeWindow(left: Expression, windowDuration: Long, slideDuration: Long, startTime: Long) + extends Unary(left) + with Fn { + override def prettyName: String = "WINDOW" + override def dataType: DataType = UnresolvedType +} + +/** xpath(xml, xpath) - Returns a string array of values within the nodes of xml that match the XPath expression. */ +case class XPathList(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH" + override def dataType: DataType = UnresolvedType +} + +/** + * xpath_boolean(xml, xpath) - Returns true if the XPath expression evaluates to true, or if a matching node is found. + */ +case class XPathBoolean(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH_BOOLEAN" + override def dataType: DataType = UnresolvedType +} + +/** + * xpath_double(xml, xpath) - Returns a double value, the value zero if no match is found, or NaN if a match is found + * but the value is non-numeric. + */ +case class XPathDouble(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH_DOUBLE" + override def dataType: DataType = UnresolvedType +} + +/** + * xpath_float(xml, xpath) - Returns a float value, the value zero if no match is found, or NaN if a match is found but + * the value is non-numeric. + */ +case class XPathFloat(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH_FLOAT" + override def dataType: DataType = UnresolvedType +} + +/** + * xpath_int(xml, xpath) - Returns an integer value, or the value zero if no match is found, or a match is found but the + * value is non-numeric. + */ +case class XPathInt(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH_INT" + override def dataType: DataType = UnresolvedType +} + +/** + * xpath_long(xml, xpath) - Returns a long integer value, or the value zero if no match is found, or a match is found + * but the value is non-numeric. + */ +case class XPathLong(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH_LONG" + override def dataType: DataType = UnresolvedType +} + +/** + * xpath_short(xml, xpath) - Returns a short integer value, or the value zero if no match is found, or a match is found + * but the value is non-numeric. + */ +case class XPathShort(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH_SHORT" + override def dataType: DataType = UnresolvedType +} + +/** xpath_string(xml, xpath) - Returns the text contents of the first xml node that matches the XPath expression. */ +case class XPathString(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "XPATH_STRING" + override def dataType: DataType = UnresolvedType +} + +/** xxhash64(expr1, expr2, ...) - Returns a 64-bit hash value of the arguments. 
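+ * Note that, unlike most entries in the translation match above, `CallFunction("XXHASH64", args)` forwards the entire + * argument list, matching this variadic signature.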
*/ +case class XxHash64(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "XXHASH64" + override def dataType: DataType = UnresolvedType +} + +/** + * zip_with(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. If one + * array is shorter, nulls are appended at the end to match the length of the longer array, before applying function. + */ +case class ZipWith(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "ZIP_WITH" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** any(expr) - Returns true if at least one value of `expr` is true. */ +case class BoolOr(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ANY" + override def dataType: DataType = UnresolvedType +} + +/** + * approx_count_distinct(expr[, relativeSD]) - Returns the estimated cardinality by HyperLogLog++. `relativeSD` defines + * the maximum relative standard deviation allowed. + */ +case class HyperLogLogPlusPlus(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "APPROX_COUNT_DISTINCT" + override def dataType: DataType = UnresolvedType +} + +/** avg(expr) - Returns the mean calculated from values of a group. */ +case class Average(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "AVG" + override def dataType: DataType = UnresolvedType +} + +/** bit_and(expr) - Returns the bitwise AND of all non-null input values, or null if none. */ +case class BitAndAgg(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BIT_AND" + override def dataType: DataType = UnresolvedType +} + +/** bit_or(expr) - Returns the bitwise OR of all non-null input values, or null if none. */ +case class BitOrAgg(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BIT_OR" + override def dataType: DataType = UnresolvedType +} + +/** bit_xor(expr) - Returns the bitwise XOR of all non-null input values, or null if none. */ +case class BitXorAgg(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BIT_XOR" + override def dataType: DataType = UnresolvedType +} + +/** bool_and(expr) - Returns true if all values of `expr` are true. */ +case class BoolAnd(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "BOOL_AND" + override def dataType: DataType = UnresolvedType +} + +/** collect_list(expr) - Collects and returns a list of non-unique elements. */ +case class CollectList(expr: Expression, cond: Option[Expression] = None) extends Expression with Fn { + // COLLECT_LIST and ARRAY_AGG are synonyms, but ARRAY_AGG is used in the test examples + override def prettyName: String = "ARRAY_AGG" + override def dataType: DataType = UnresolvedType + override def children: Seq[Expression] = Seq(expr) ++ cond.toSeq +} + +/** collect_set(expr) - Collects and returns a set of unique elements. */ +case class CollectSet(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "COLLECT_SET" + override def dataType: DataType = UnresolvedType +} + +/** corr(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs. 
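+ * For example, `corr(c1, c2)` over the pairs (1, 1), (2, 2) and (3, 3) evaluates to 1.0.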
*/ +case class Corr(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "CORR" + override def dataType: DataType = UnresolvedType +} + +/** + * count(*) - Returns the total number of retrieved rows, including rows containing null. + * + * count(expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are all non-null. + * + * count(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and + * non-null. + */ +case class Count(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "COUNT" + override def dataType: DataType = UnresolvedType +} + +/** count_if(expr) - Returns the number of `TRUE` values for the expression. */ +case class CountIf(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "COUNT_IF" + override def dataType: DataType = UnresolvedType +} + +/** + * count_min_sketch(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given esp, confidence + * and seed. The result is an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min + * sketch is a probabilistic data structure used for cardinality estimation using sub-linear space. + */ +case class CountMinSketchAgg(left: Expression, right: Expression, c: Expression, d: Expression) + extends Expression + with Fn { + override def prettyName: String = "COUNT_MIN_SKETCH" + override def children: Seq[Expression] = Seq(left, right, c, d) + override def dataType: DataType = UnresolvedType +} + +/** covar_pop(expr1, expr2) - Returns the population covariance of a set of number pairs. */ +case class CovPopulation(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "COVAR_POP" + override def dataType: DataType = UnresolvedType +} + +/** covar_samp(expr1, expr2) - Returns the sample covariance of a set of number pairs. */ +case class CovSample(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "COVAR_SAMP" + override def dataType: DataType = UnresolvedType +} + +/** + * first(expr[, isIgnoreNull]) - Returns the first value of `expr` for a group of rows. If `isIgnoreNull` is true, + * returns only non-null values. + */ +case class First(left: Expression, right: Option[Expression] = None) extends Expression with Fn { + override def children: Seq[Expression] = Seq(left) ++ right + override def prettyName: String = "FIRST" + override def dataType: DataType = UnresolvedType +} + +/** kurtosis(expr) - Returns the kurtosis value calculated from values of a group. */ +case class Kurtosis(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "KURTOSIS" + override def dataType: DataType = UnresolvedType +} + +/** + * last(expr[, isIgnoreNull]) - Returns the last value of `expr` for a group of rows. If `isIgnoreNull` is true, returns + * only non-null values + */ +case class Last(left: Expression, right: Option[Expression] = None) extends Expression with Fn { + override def children: Seq[Expression] = Seq(left) ++ right + override def prettyName: String = "LAST" + override def dataType: DataType = UnresolvedType +} + +/** max(expr) - Returns the maximum value of `expr`. 
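Variadic aggregates such as COUNT carry every argument as a child, while FIRST models its optional isIgnoreNull flag as an extra trailing child. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object CountSketch {
  val countAll     = Count(Seq(Literal(1)))                  // COUNT(1) standing in for COUNT(*)
  val countPair    = Count(Seq(Literal("a"), Literal("b")))  // two-argument COUNT: two children
  val firstNonNull = First(Literal("a"), Some(Literal.True)) // FIRST(a, true)
  // firstNonNull.children == Seq(Literal("a"), Literal.True)
}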
*/ +case class Max(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MAX" + override def dataType: DataType = UnresolvedType +} + +/** max_by(x, y) - Returns the value of `x` associated with the maximum value of `y`. */ +case class MaxBy(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "MAX_BY" + override def dataType: DataType = UnresolvedType +} + +/** min(expr) - Returns the minimum value of `expr`. */ +case class Min(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MIN" + override def dataType: DataType = UnresolvedType +} + +/** min_by(x, y) - Returns the value of `x` associated with the minimum value of `y`. */ +case class MinBy(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "MIN_BY" + override def dataType: DataType = UnresolvedType +} + +/** + * percentile(col, percentage [, frequency]) - Returns the exact percentile value of numeric column `col` at the given + * percentage. The value of percentage must be between 0.0 and 1.0. The value of frequency should be positive integral + * + * percentile(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact percentile value array of + * numeric column `col` at the given percentage(s). Each value of the percentage array must be between 0.0 and 1.0. The + * value of frequency should be positive integral + */ +case class Percentile(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "PERCENTILE" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** + * percentile_approx(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric column `col` + * which is the smallest value in the ordered `col` values (sorted from least to greatest) such that no more than + * `percentage` of `col` values is less than the value or equal to that value. The value of percentage must be between + * 0.0 and 1.0. The `accuracy` parameter (default: 10000) is a positive numeric literal which controls approximation + * accuracy at the cost of memory. Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is the relative + * error of the approximation. When `percentage` is an array, each value of the percentage array must be between 0.0 and + * 1.0. In this case, returns the approximate percentile array of column `col` at the given percentage array. + */ +case class ApproximatePercentile(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "PERCENTILE_APPROX" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** skewness(expr) - Returns the skewness value calculated from values of a group. */ +case class Skewness(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SKEWNESS" + override def dataType: DataType = UnresolvedType +} + +/** std(expr) - Returns the sample standard deviation calculated from values of a group. */ +case class StdSamp(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "STD" + override def dataType: DataType = UnresolvedType +} + +/** stddev(expr) - Returns the sample standard deviation calculated from values of a group. 
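PERCENTILE_APPROX accepts either a single percentage or an array of them; the Literal factory in literals.scala later in this diff turns a Scala Seq into an ArrayExpr, so both shapes can be written directly. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object PercentileSketch {
  val latency   = Literal("latency_ms") // placeholder operand
  val p50       = ApproximatePercentile(latency, Literal(0.5), Literal(10000))
  val quartiles = ApproximatePercentile(latency, Literal(Seq(0.25, 0.5, 0.75)), Literal(10000))
  // Literal(Seq(...)) builds an ArrayExpr with ArrayType(DoubleType) rather than a plain Literal.
}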
*/ +case class StddevSamp(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "STDDEV" + override def dataType: DataType = UnresolvedType +} + +/** stddev_pop(expr) - Returns the population standard deviation calculated from values of a group. */ +case class StddevPop(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "STDDEV_POP" + override def dataType: DataType = UnresolvedType +} + +/** sum(expr) - Returns the sum calculated from values of a group. */ +case class Sum(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SUM" + override def dataType: DataType = UnresolvedType +} + +/** var_pop(expr) - Returns the population variance calculated from values of a group. */ +case class VariancePop(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "VAR_POP" + override def dataType: DataType = UnresolvedType +} + +/** var_samp(expr) - Returns the sample variance calculated from values of a group. */ +case class VarianceSamp(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "VAR_SAMP" + override def dataType: DataType = UnresolvedType +} + +/** array_contains(array, value) - Returns true if the array contains the value. */ +case class ArrayContains(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_CONTAINS" + override def dataType: DataType = UnresolvedType +} + +/** array_distinct(array) - Removes duplicate values from the array. */ +case class ArrayDistinct(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ARRAY_DISTINCT" + override def dataType: DataType = UnresolvedType +} + +/** + * array_except(array1, array2) - Returns an array of the elements in array1 but not in array2, without duplicates. + */ +case class ArrayExcept(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_EXCEPT" + override def dataType: DataType = UnresolvedType +} + +/** + * array_intersect(array1, array2) - Returns an array of the elements in the intersection of array1 and array2, without + * duplicates. + */ +case class ArrayIntersect(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_INTERSECT" + override def dataType: DataType = UnresolvedType +} + +/** + * array_join(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array using the delimiter + * and an optional string to replace nulls. If no value is set for nullReplacement, any null value is filtered. + */ +case class ArrayJoin(left: Expression, right: Expression, c: Option[Expression] = None) extends Expression with Fn { + override def prettyName: String = "ARRAY_JOIN" + override def children: Seq[Expression] = Seq(left, right) ++ c.toSeq + override def dataType: DataType = UnresolvedType +} + +/** array_max(array) - Returns the maximum value in the array. NULL elements are skipped. */ +case class ArrayMax(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ARRAY_MAX" + override def dataType: DataType = UnresolvedType +} + +/** array_min(array) - Returns the minimum value in the array. NULL elements are skipped. 
*/ +case class ArrayMin(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "ARRAY_MIN" + override def dataType: DataType = UnresolvedType +} + +/** array_position(array, element) - Returns the (1-based) index of the first element of the array as long. */ +case class ArrayPosition(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_POSITION" + override def dataType: DataType = UnresolvedType +} + +/** array_remove(array, element) - Remove all elements that equal to element from array. */ +case class ArrayRemove(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_REMOVE" + override def dataType: DataType = UnresolvedType +} + +/** array_repeat(element, count) - Returns the array containing element count times. */ +case class ArrayRepeat(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_REPEAT" + override def dataType: DataType = UnresolvedType +} + +/** + * array_union(array1, array2) - Returns an array of the elements in the union of array1 and array2, without duplicates. + */ +case class ArrayUnion(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAY_UNION" + override def dataType: DataType = UnresolvedType +} + +/** + * arrays_overlap(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays + * have no common element and they are both non-empty and either of them contains a null element null is returned, false + * otherwise. + */ +case class ArraysOverlap(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ARRAYS_OVERLAP" + override def dataType: DataType = UnresolvedType +} + +/** + * arrays_zip(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all N-th values of + * input arrays. + */ +case class ArraysZip(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "ARRAYS_ZIP" + override def dataType: DataType = UnresolvedType +} + +/** concat(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN. */ +case class Concat(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "CONCAT" + override def dataType: DataType = UnresolvedType +} + +/** flatten(arrayOfArrays) - Transforms an array of arrays into a single array. */ +case class Flatten(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "FLATTEN" + override def dataType: DataType = UnresolvedType +} + +/** reverse(array) - Returns a reversed string or an array with reverse order of elements. */ +case class Reverse(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "REVERSE" + override def dataType: DataType = UnresolvedType +} + +/** + * sequence(start, stop, step) - Generates an array of elements from start to stop (inclusive), incrementing by step. + * The type of the returned elements is the same as the type of argument expressions. + * + * Supported types are: byte, short, integer, long, date, timestamp. + * + * The start and stop expressions must resolve to the same type. 
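Optional trailing arguments are modelled as Option fields that are appended to children only when present; ARRAY_JOIN's nullReplacement is the typical case. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object ArrayJoinSketch {
  val parts  = Literal(Seq("a", "b", "c"))
  val joined = ArrayJoin(parts, Literal(","))                     // ARRAY_JOIN(parts, ',')
  val padded = ArrayJoin(parts, Literal(","), Some(Literal("?"))) // ARRAY_JOIN(parts, ',', '?')
  // joined.children.size == 2; padded.children.size == 3
}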
If start and stop expressions resolve to the 'date' or + * 'timestamp' type then the step expression must resolve to the 'interval' type, otherwise to the same type as the + * start and stop expressions. + */ +case class Sequence(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "SEQUENCE" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** shuffle(array) - Returns a random permutation of the given array. */ +case class Shuffle(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SHUFFLE" + override def dataType: DataType = UnresolvedType +} + +/** + * slice(x, start, length) - Subsets array x starting from index start (array indices start at 1, or starting from the + * end if start is negative) with the specified length. + */ +case class Slice(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "SLICE" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** + * sort_array(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural + * ordering of the array elements. Null elements will be placed at the beginning of the returned array in ascending + * order or at the end of the returned array in descending order. + */ +case class SortArray(left: Expression, right: Option[Expression] = None) extends Expression with Fn { + override def prettyName: String = "SORT_ARRAY" + override def children: Seq[Expression] = Seq(left) ++ right.toSeq + override def dataType: DataType = UnresolvedType +} + +/** add_months(start_date, num_months) - Returns the date that is `num_months` after `start_date`. */ +case class AddMonths(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "ADD_MONTHS" + override def dataType: DataType = UnresolvedType +} + +/** + * current_date() - Returns the current date at the start of query evaluation. All calls of current_date within the same + * query return the same value. + * + * current_date - Returns the current date at the start of query evaluation. + */ +case class CurrentDate() extends LeafExpression with Fn { + override def prettyName: String = "CURRENT_DATE" + override def dataType: DataType = UnresolvedType +} + +/** + * current_timestamp() - Returns the current timestamp at the start of query evaluation. All calls of current_timestamp + * within the same query return the same value. + * + * current_timestamp - Returns the current timestamp at the start of query evaluation. + */ +case class CurrentTimestamp() extends LeafExpression with Fn { + override def prettyName: String = "CURRENT_TIMESTAMP" + override def dataType: DataType = UnresolvedType +} + +/** current_timezone() - Returns the current session local timezone. */ +case class CurrentTimeZone() extends LeafExpression with Fn { + override def prettyName: String = "CURRENT_TIMEZONE" + override def dataType: DataType = UnresolvedType +} + +/** date_add(start_date, num_days) - Returns the date that is `num_days` after `start_date`. 
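Niladic functions such as CURRENT_DATE extend LeafExpression and therefore have no children at all, which keeps code generation trivial. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object NiladicSketch {
  val today      = CurrentDate()
  val asOf       = CurrentTimestamp()
  val sortedDesc = SortArray(Literal(Seq(3, 1, 2)), Some(Literal.False)) // SORT_ARRAY(..., false)
  // today.children.isEmpty && asOf.children.isEmpty
}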
*/ +case class DateAdd(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DATE_ADD" + override def dataType: DataType = UnresolvedType +} + +/** + * date_format(timestamp, fmt) - Converts `timestamp` to a value of string in the format specified by the date format + * `fmt`. + */ +case class DateFormatClass(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DATE_FORMAT" + override def dataType: DataType = UnresolvedType +} + +/** date_from_unix_date(days) - Create date from the number of days since 1970-01-01. */ +case class DateFromUnixDate(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "DATE_FROM_UNIX_DATE" + override def dataType: DataType = UnresolvedType +} + +/** date_part(field, source) - Extracts a part of the date/timestamp or interval source. */ +case class DatePart(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DATE_PART" + override def dataType: DataType = UnresolvedType +} + +/** date_sub(start_date, num_days) - Returns the date that is `num_days` before `start_date`. */ +case class DateSub(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DATE_SUB" + override def dataType: DataType = UnresolvedType +} + +/** date_trunc(fmt, ts) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`. */ +case class TruncTimestamp(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DATE_TRUNC" + override def dataType: DataType = UnresolvedType +} + +/** datediff(endDate, startDate) - Returns the number of days from `startDate` to `endDate`. */ +case class DateDiff(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "DATEDIFF" + override def dataType: DataType = UnresolvedType +} + +/** datediff(units, start, end) - Returns the difference between two timestamps measured in `units`. */ +case class TimestampDiff(unit: String, start: Expression, end: Expression, timeZoneId: Option[String] = None) + extends Binary(start, end) + with Fn { + // TIMESTAMPDIFF and DATEDIFF are synonyms, but DATEDIFF is used in the example queries, so we stick to it for now. + override def prettyName: String = "DATEDIFF" + override def dataType: DataType = UnresolvedType +} + +/** dayofweek(date) - Returns the day of the week for date/timestamp (1 = Sunday, 2 = Monday, ..., 7 = Saturday). */ +case class DayOfWeek(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "DAYOFWEEK" + override def dataType: DataType = UnresolvedType +} + +/** dayofyear(date) - Returns the day of year of the date/timestamp. */ +case class DayOfYear(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "DAYOFYEAR" + override def dataType: DataType = UnresolvedType +} + +/** from_unixtime(unix_time[, fmt]) - Returns `unix_time` in the specified `fmt`. */ +case class FromUnixTime(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "FROM_UNIXTIME" + override def dataType: DataType = UnresolvedType +} + +/** + * from_utc_timestamp(timestamp, timezone) - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in + * UTC, and renders that time as a timestamp in the given time zone. 
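TimestampDiff keeps its unit as a plain String constructor field rather than a child expression, and per the inline comment above it deliberately renders under the DATEDIFF name. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object DateDiffSketch {
  val days    = DateDiff(Literal("2024-02-01"), Literal("2024-01-01"))           // DATEDIFF(endDate, startDate)
  val inHours = TimestampDiff("hour", Literal("2024-01-01"), Literal("2024-01-02"))
  // inHours.prettyName == "DATEDIFF"; its children are only the two timestamps, not the unit.
}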
For example, 'GMT+1' would yield '2017-07-14 + * 03:40:00.0'. + */ +case class FromUTCTimestamp(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "FROM_UTC_TIMESTAMP" + override def dataType: DataType = UnresolvedType +} + +/** hour(timestamp) - Returns the hour component of the string/timestamp. */ +case class Hour(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "HOUR" + override def dataType: DataType = UnresolvedType +} + +/** last_day(date) - Returns the last day of the month which the date belongs to. */ +case class LastDay(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "LAST_DAY" + override def dataType: DataType = UnresolvedType +} + +/** make_date(year, month, day) - Create date from year, month and day fields. */ +case class MakeDate(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "MAKE_DATE" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** + * make_timestamp(year, month, day, hour, min, sec[, timezone]) - Create timestamp from year, month, day, hour, min, sec + * and timezone fields. + */ +case class MakeTimestamp( + left: Expression, + right: Expression, + c: Expression, + d: Expression, + e: Expression, + f: Expression, + g: Option[Expression]) + extends Expression + with Fn { + override def prettyName: String = "MAKE_TIMESTAMP" + override def children: Seq[Expression] = Seq(left, right, c, d, e, f) ++ g.toSeq + override def dataType: DataType = UnresolvedType +} + +/** minute(timestamp) - Returns the minute component of the string/timestamp. */ +case class Minute(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MINUTE" + override def dataType: DataType = UnresolvedType +} + +/** month(date) - Returns the month component of the date/timestamp. */ +case class Month(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MONTH" + override def dataType: DataType = UnresolvedType +} + +/** + * months_between(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result is + * positive. If `timestamp1` and `timestamp2` are on the same day of month, or both are the last day of month, time of + * day will be ignored. Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits + * unless roundOff=false. + */ +case class MonthsBetween(left: Expression, right: Expression, c: Expression) extends Expression with Fn { + override def prettyName: String = "MONTHS_BETWEEN" + override def children: Seq[Expression] = Seq(left, right, c) + override def dataType: DataType = UnresolvedType +} + +/** + * next_day(start_date, day_of_week) - Returns the first date which is later than `start_date` and named as indicated. + */ +case class NextDay(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "NEXT_DAY" + override def dataType: DataType = UnresolvedType +} + +/** now() - Returns the current timestamp at the start of query evaluation. */ +case class Now() extends LeafExpression with Fn { + override def prettyName: String = "NOW" + override def dataType: DataType = UnresolvedType +} + +/** quarter(date) - Returns the quarter of the year for date, in the range 1 to 4. 
*/ +case class Quarter(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "QUARTER" + override def dataType: DataType = UnresolvedType +} + +/** second(timestamp) - Returns the second component of the string/timestamp. */ +case class Second(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "SECOND" + override def dataType: DataType = UnresolvedType +} + +/** timestamp_micros(microseconds) - Creates timestamp from the number of microseconds since UTC epoch. */ +case class MicrosToTimestamp(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "TIMESTAMP_MICROS" + override def dataType: DataType = UnresolvedType +} + +/** timestamp_millis(milliseconds) - Creates timestamp from the number of milliseconds since UTC epoch. */ +case class MillisToTimestamp(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "TIMESTAMP_MILLIS" + override def dataType: DataType = UnresolvedType +} + +/** timestamp_seconds(seconds) - Creates timestamp from the number of seconds (can be fractional) since UTC epoch. */ +case class SecondsToTimestamp(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "TIMESTAMP_SECONDS" + override def dataType: DataType = UnresolvedType +} + +/** + * to_date(date_str[, fmt]) - Parses the `date_str` expression with the `fmt` expression to a date. Returns null with + * invalid input. By default, it follows casting rules to a date if the `fmt` is omitted. + */ +case class ParseToDate(left: Expression, right: Option[Expression]) extends Expression with Fn { + override def prettyName: String = "TO_DATE" + override def children: Seq[Expression] = Seq(left) ++ right.toSeq + override def dataType: DataType = UnresolvedType +} + +/** + * to_timestamp(timestamp_str[, fmt]) - Parses the `timestamp_str` expression with the `fmt` expression to a timestamp. + * Returns null with invalid input. By default, it follows casting rules to a timestamp if the `fmt` is omitted. + */ +case class ParseToTimestamp(left: Expression, right: Option[Expression] = None) extends Expression with Fn { + override def prettyName: String = "TO_TIMESTAMP" + override def dataType: DataType = UnresolvedType + override def children: Seq[Expression] = Seq(left) ++ right.toSeq +} + +/** to_unix_timestamp(timeExp[, fmt]) - Returns the UNIX timestamp of the given time. */ +case class ToUnixTimestamp(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "TO_UNIX_TIMESTAMP" + override def dataType: DataType = UnresolvedType +} + +/** + * to_utc_timestamp(timestamp, timezone) - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in + * the given time zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 + * 01:40:00.0'. + */ +case class ToUTCTimestamp(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "TO_UTC_TIMESTAMP" + override def dataType: DataType = UnresolvedType +} + +/** + * trunc(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format + * model `fmt`. + */ +case class TruncDate(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "TRUNC" + override def dataType: DataType = UnresolvedType +} + +/** unix_date(date) - Returns the number of days since 1970-01-01. 
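TRUNC and DATE_TRUNC take their format argument on opposite sides, and the two IR nodes mirror that: TruncDate is (date, fmt) while TruncTimestamp is (fmt, ts). A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object TruncSketch {
  val toMonth      = TruncDate(Literal("2024-03-17"), Literal("MM"))      // TRUNC('2024-03-17', 'MM')
  val toMonthStart = TruncTimestamp(Literal("MM"), Literal("2024-03-17")) // DATE_TRUNC('MM', '2024-03-17')
}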
*/ +case class UnixDate(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "UNIX_DATE" + override def dataType: DataType = UnresolvedType +} + +/** unix_micros(timestamp) - Returns the number of microseconds since 1970-01-01 00:00:00 UTC. */ +case class UnixMicros(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "UNIX_MICROS" + override def dataType: DataType = UnresolvedType +} + +/** + * unix_millis(timestamp) - Returns the number of milliseconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of + * precision. + */ +case class UnixMillis(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "UNIX_MILLIS" + override def dataType: DataType = UnresolvedType +} + +/** + * unix_seconds(timestamp) - Returns the number of seconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of + * precision. + */ +case class UnixSeconds(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "UNIX_SECONDS" + override def dataType: DataType = UnresolvedType +} + +/** unix_timestamp([timeExp[, fmt]]) - Returns the UNIX timestamp of current or specified time. */ +case class UnixTimestamp(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "UNIX_TIMESTAMP" + override def dataType: DataType = UnresolvedType +} + +/** weekday(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). */ +case class WeekDay(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "WEEKDAY" + override def dataType: DataType = UnresolvedType +} + +/** + * weekofyear(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week + * 1 is the first week with >3 days. + */ +case class WeekOfYear(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "WEEKOFYEAR" + override def dataType: DataType = UnresolvedType +} + +/** year(date) - Returns the year component of the date/timestamp. */ +case class Year(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "YEAR" + override def dataType: DataType = UnresolvedType +} + +/** from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. */ +case class JsonToStructs(left: Expression, right: Expression, c: Option[Expression]) extends Expression with Fn { + override def prettyName: String = "FROM_JSON" + override def children: Seq[Expression] = Seq(left, right) ++ c + override def dataType: DataType = UnresolvedType +} + +/** get_json_object(json_txt, path) - Extracts a json object from `path`. */ +case class GetJsonObject(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "GET_JSON_OBJECT" + override def dataType: DataType = UnresolvedType +} + +/** json_array_length(jsonArray) - Returns the number of elements in the outmost JSON array. */ +case class LengthOfJsonArray(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "JSON_ARRAY_LENGTH" + override def dataType: DataType = UnresolvedType +} + +/** json_object_keys(json_object) - Returns all the keys of the outmost JSON object as an array. 
*/ +case class JsonObjectKeys(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "JSON_OBJECT_KEYS" + override def dataType: DataType = UnresolvedType +} + +/** + * json_tuple(jsonStr, p1, p2, ..., pn) - Returns a tuple like the function get_json_object, but it takes multiple + * names. All the input parameters and output column types are string. + */ +case class JsonTuple(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "JSON_TUPLE" + override def dataType: DataType = UnresolvedType +} + +/** schema_of_json(json[, options]) - Returns schema in the DDL format of JSON string. */ +case class SchemaOfJson(left: Expression, right: Expression) extends Binary(left, right) with Fn { + override def prettyName: String = "SCHEMA_OF_JSON" + override def dataType: DataType = UnresolvedType +} + +/** to_json(expr[, options]) - Returns a JSON string with a given struct value */ +case class StructsToJson(left: Expression, right: Option[Expression]) extends Expression with Fn { + override def prettyName: String = "TO_JSON" + override def children: Seq[Expression] = Seq(left) ++ right.toSeq + override def dataType: DataType = UnresolvedType +} + +/** map_concat(map, ...) - Returns the union of all the given maps */ +case class MapConcat(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "MAP_CONCAT" + override def dataType: DataType = UnresolvedType +} + +/** map_entries(map) - Returns an unordered array of all entries in the given map. */ +case class MapEntries(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MAP_ENTRIES" + override def dataType: DataType = UnresolvedType +} + +/** map_from_entries(arrayOfEntries) - Returns a map created from the given array of entries. */ +case class MapFromEntries(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MAP_FROM_ENTRIES" + override def dataType: DataType = UnresolvedType +} + +/** map_keys(map) - Returns an unordered array containing the keys of the map. */ +case class MapKeys(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MAP_KEYS" + override def dataType: DataType = UnresolvedType +} + +/** map_values(map) - Returns an unordered array containing the values of the map. */ +case class MapValues(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "MAP_VALUES" + override def dataType: DataType = UnresolvedType +} + +/** cume_dist() - Computes the position of a value relative to all values in the partition. */ +case class CumeDist() extends LeafExpression with Fn { + override def prettyName: String = "CUME_DIST" + override def dataType: DataType = UnresolvedType +} + +/** + * dense_rank() - Computes the rank of a value in a group of values. The result is one plus the previously assigned rank + * value. Unlike the function rank, dense_rank will not produce gaps in the ranking sequence. + */ +case class DenseRank(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "DENSE_RANK" + override def dataType: DataType = UnresolvedType +} + +/** + * lag(input[, offset[, default]]) - Returns the value of `input` at the `offset`th row before the current row in the + * window. The default value of `offset` is 1 and the default value of `default` is null. If the value of `input` at the + * `offset`th row is null, null is returned. 
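Unlike most nodes in this file, the LAG/LEAD nodes defined just below forward the data type of their input, so some type information survives even before resolution. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object LagSketch {
  val lagged = Lag(Literal(42), offset = Some(Literal(2)))
  // Literal(42) is typed as IntegerType by the factory in literals.scala, and Lag.dataType echoes it:
  // lagged.dataType == IntegerType
}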
If there is no such offset row (e.g., when the offset is 1, the first row + * of the window does not have any previous row), `default` is returned. + */ +case class Lag(left: Expression, offset: Option[Expression] = None, default: Option[Expression] = None) + extends Expression + with Fn { + override def prettyName: String = "LAG" + override def children: Seq[Expression] = Seq(left) ++ offset ++ default + override def dataType: DataType = left.dataType +} + +/** + * lead(input[, offset[, default]]) - Returns the value of `input` at the `offset`th row after the current row in the + * window. The default value of `offset` is 1 and the default value of `default` is null. If the value of `input` at the + * `offset`th row is null, null is returned. If there is no such an offset row (e.g., when the offset is 1, the last row + * of the window does not have any subsequent row), `default` is returned. + */ +case class Lead(left: Expression, offset: Option[Expression] = None, default: Option[Expression] = None) + extends Expression + with Fn { + override def children: Seq[Expression] = Seq(left) ++ offset ++ default + override def prettyName: String = "LEAD" + override def dataType: DataType = left.dataType +} + +/** + * nth_value(input[, offset]) - Returns the value of `input` at the row that is the `offset`th row from beginning of the + * window frame. Offset starts at 1. If ignoreNulls=true, we will skip nulls when finding the `offset`th row. Otherwise, + * every row counts for the `offset`. If there is no such an `offset`th row (e.g., when the offset is 10, size of the + * window frame is less than 10), null is returned. + */ +case class NthValue(input: Expression, offset: Expression = Literal(1), ignoreNulls: Option[Expression] = None) + extends Expression + with Fn { + override def children: Seq[Expression] = Seq(input, offset) ++ ignoreNulls + override def prettyName: String = "NTH_VALUE" + override def dataType: DataType = input.dataType +} + +/** + * ntile(n) - Divides the rows for each window partition into `n` buckets ranging from 1 to at most `n`. + */ +case class NTile(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "NTILE" + override def dataType: DataType = UnresolvedType +} + +/** percent_rank() - Computes the percentage ranking of a value in a group of values. */ +case class PercentRank(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "PERCENT_RANK" + override def dataType: DataType = UnresolvedType +} + +/** + * rank() - Computes the rank of a value in a group of values. The result is one plus the number of rows preceding or + * equal to the current row in the ordering of the partition. The values will produce gaps in the sequence. + */ +case class Rank(children: Seq[Expression]) extends Expression with Fn { + override def prettyName: String = "RANK" + override def dataType: DataType = UnresolvedType +} + +/** + * row_number() - Assigns a unique, sequential number to each row, starting with one, according to the ordering of rows + * within the window partition. 
+ */ +case class RowNumber() extends LeafExpression with Fn { + override def prettyName: String = "ROW_NUMBER" + override def dataType: DataType = UnresolvedType +} + +/** + * to_number(expr, fmt) - Returns expr cast to DECIMAL using formatting fmt + */ +case class ToNumber(expr: Expression, fmt: Expression) extends Binary(expr, fmt) with Fn { + override def prettyName: String = "TO_NUMBER" + override def dataType: DataType = UnresolvedType +} + +/** + * try_to_number(expr, fmt) - Returns expr cast to DECIMAL using formatting fmt, or NULL if expr does not match the + * format. + */ +case class TryToNumber(expr: Expression, fmt: Expression) extends Binary(expr, fmt) with Fn { + override def prettyName: String = "TRY_TO_NUMBER" + override def dataType: DataType = UnresolvedType +} + +/** + * try_to_timestamp(expr, fmt) - Returns expr cast to a timestamp using an optional formatting, or NULL if the cast + * fails. + */ +case class TryToTimestamp(expr: Expression, fmt: Option[Expression] = None) extends Expression with Fn { + override def prettyName: String = "TRY_TO_TIMESTAMP" + override def dataType: DataType = TimestampType + override def children: Seq[Expression] = Seq(expr) ++ fmt.toSeq +} + +/** + * timestampadd(unit, value, expr) - Adds value units to a timestamp expr + */ +case class TimestampAdd(unit: String, quantity: Expression, timestamp: Expression) extends Expression with Fn { + // TIMESTAMPADD, DATE_ADD and DATEADD are synonyms, but the latter is used in the examples. + override def prettyName: String = "DATEADD" + override def children: Seq[Expression] = Seq(quantity, timestamp) + override def dataType: DataType = TimestampType +} + +/** + * try_cast(sourceExpr AS targetType) - Returns the value of sourceExpr cast to data type targetType if possible, or + * NULL if not possible. + */ +case class TryCast(expr: Expression, override val dataType: DataType) extends Expression { + override def children: Seq[Expression] = Seq(expr) +} + +/** + * parse_json(expr) - Parses the JSON string `expr` and returns the resulting structure. + */ +case class ParseJson(left: Expression) extends Unary(left) with Fn { + override def prettyName: String = "PARSE_JSON" + override def dataType: DataType = VariantType +} + +/** + * startswith(expr, startExpr) - Returns true if expr begins with startExpr. + */ +case class StartsWith(expr: Expression, startExpr: Expression) extends Binary(expr, startExpr) with Fn { + override def prettyName: String = "STARTSWITH" + override def dataType: DataType = BooleanType +} + +/** + * array_append(array, elem) - Returns array appended by elem. 
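TryCast is the only node in this last group that is not an Fn: it carries the target DataType directly instead of exposing a prettyName, so its type is known without any resolution step. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object TryCastSketch {
  val maybeInt = TryCast(Literal("42"), IntegerType) // try_cast('42' AS <integer>)
  // maybeInt.dataType == IntegerType; maybeInt.children == Seq(Literal("42"))
}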
+ */ +case class ArrayAppend(array: Expression, elem: Expression) extends Binary(array, elem) with Fn { + override def prettyName: String = "ARRAY_APPEND" + override def dataType: DataType = array.dataType +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/literals.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/literals.scala new file mode 100644 index 0000000000..28b54aa172 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/literals.scala @@ -0,0 +1,195 @@ +package com.databricks.labs.remorph.intermediate + +import java.time.ZoneOffset + +case class Literal(value: Any, dataType: DataType) extends LeafExpression + +object Literal { + val True: Literal = Literal(true, BooleanType) + val False: Literal = Literal(false, BooleanType) + val Null: Literal = Literal(null, NullType) + private[this] val byteType = ByteType(Some(1)) + private[intermediate] val defaultDecimal = DecimalType(38, 18) + + // this factory returns Expression instead of Literal, because array and map literals have children + // and Literal is a LeafExpression, which has no children by definition. + def apply(value: Any): Expression = value match { + case null => Null + case i: Int => Literal(i, IntegerType) + case l: Long => Literal(l, LongType) + case f: Float => NumericLiteral(f) + case d: Double => NumericLiteral(d) + case b: Byte => Literal(b, byteType) + case s: Short => Literal(s, ShortType) + case s: String => Literal(s, StringType) + case c: Char => Literal(c.toString, StringType) + case b: Boolean => Literal(b, BooleanType) + case d: BigDecimal => Literal(d, DecimalType.fromBigDecimal(d)) + // TODO: revise date type handling later + case d: java.sql.Date => Literal(d.toLocalDate.toEpochDay, DateType) + case d: java.time.LocalDate => Literal(d.toEpochDay, DateType) + case d: java.sql.Timestamp => Literal(d.getTime, TimestampType) + case d: java.time.LocalDateTime => Literal(d.toEpochSecond(ZoneOffset.UTC), TimestampType) + case d: Array[Byte] => Literal(d, BinaryType) + case a: Array[_] => + val elementType = componentTypeToDataType(a.getClass.getComponentType) + val dataType = ArrayType(elementType) + Literal(convert(a, dataType), dataType) + case s: Seq[_] => + val elementType = componentTypeToDataType(s.head.getClass) + val dataType = ArrayType(elementType) + convert(s, dataType) + case m: Map[_, _] => + val keyType = componentTypeToDataType(m.keys.head.getClass) + val valueType = componentTypeToDataType(m.values.head.getClass) + val dataType = MapType(keyType, valueType) + convert(m, dataType) + case _ => throw new IllegalStateException(s"Unsupported value: $value") + } + + private def convert(value: Any, dataType: DataType): Expression = (value, dataType) match { + case (Some(v), t) if t.isPrimitive => Literal(v, t) + case (None, t) if t.isPrimitive => Null + case (v, t) if t.isPrimitive => Literal(v, t) + case (v: Array[_], ArrayType(elementType)) => + val elements = v.map { e => convert(e, elementType) }.toList + ArrayExpr(elements, dataType) + case (v: Seq[_], ArrayType(elementType)) => + val elements = v.map { e => convert(e, elementType) }.toList + ArrayExpr(elements, dataType) + case (v: Map[_, _], MapType(keyType, valueType)) => + val map = v.map { case (k, v) => convert(k, keyType) -> convert(v, valueType) } + MapExpr(map, dataType) + case _ => + throw new IllegalStateException(s"Unsupported value: $value and dataType: $dataType") + } + + private[this] def componentTypeToDataType(clz: Class[_]): DataType = clz match { + // 
primitive types + case java.lang.Short.TYPE => ShortType + case java.lang.Integer.TYPE => IntegerType + case java.lang.Long.TYPE => LongType + case java.lang.Double.TYPE => DoubleType + case java.lang.Byte.TYPE => byteType + case java.lang.Float.TYPE => FloatType + case java.lang.Boolean.TYPE => BooleanType + case java.lang.Character.TYPE => StringType + + // java classes + case _ if clz == classOf[java.sql.Date] => DateType + case _ if clz == classOf[java.time.LocalDate] => DateType + case _ if clz == classOf[java.time.Instant] => TimestampType + case _ if clz == classOf[java.sql.Timestamp] => TimestampType + case _ if clz == classOf[java.time.LocalDateTime] => TimestampNTZType + case _ if clz == classOf[java.time.Duration] => DayTimeIntervalType + case _ if clz == classOf[java.time.Period] => YearMonthIntervalType + case _ if clz == classOf[java.math.BigDecimal] => defaultDecimal + case _ if clz == classOf[Array[Byte]] => BinaryType + case _ if clz == classOf[Array[Char]] => StringType + case _ if clz == classOf[java.lang.Short] => ShortType + case _ if clz == classOf[java.lang.Integer] => IntegerType + case _ if clz == classOf[java.lang.Long] => LongType + case _ if clz == classOf[java.lang.Double] => DoubleType + case _ if clz == classOf[java.lang.Byte] => byteType + case _ if clz == classOf[java.lang.Float] => FloatType + case _ if clz == classOf[java.lang.Boolean] => BooleanType + + // other scala classes + case _ if clz == classOf[String] => StringType + case _ if clz == classOf[BigInt] => defaultDecimal + case _ if clz == classOf[BigDecimal] => defaultDecimal + case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + + case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) + + case _ => + throw new IllegalStateException(s"Unsupported type: $clz") + } + + // TODO: validate the value and dataType +} + +case class ArrayExpr(children: Seq[Expression], dataType: DataType) extends Expression + +case class MapExpr(map: Map[Expression, Expression], dataType: DataType) extends Expression { + override def children: Seq[Expression] = (map.keys ++ map.values).toList +} + +object NumericLiteral { + def apply(v: String): Literal = convert(BigDecimal(v)) + def apply(v: Double): Literal = convert(BigDecimal(v)) + def apply(v: Float): Literal = apply(v.toString) + + private def convert(d: BigDecimal): Literal = d match { + case d if d.isValidInt => Literal(d.toInt, IntegerType) + case d if d.isValidLong => Literal(d.toLong, LongType) + case d if d.isDecimalFloat || d.isExactFloat => Literal(d.toFloat, FloatType) + case d if d.isDecimalDouble || d.isExactDouble => Literal(d.toDouble, DoubleType) + case d => DecimalLiteral.auto(d) + } +} + +object DecimalLiteral { + private[intermediate] def auto(d: BigDecimal): Literal = Literal(d, DecimalType.fromBigDecimal(d)) + def apply(v: Long): Literal = auto(BigDecimal(v)) + def apply(v: Double): Literal = auto(BigDecimal(v)) + def apply(v: String): Literal = auto(BigDecimal(v)) + def unapply(e: Expression): Option[BigDecimal] = e match { + case Literal(v: BigDecimal, _: DecimalType) => Some(v) + case _ => None + } +} + +object FloatLiteral { + def apply(f: Float): Literal = Literal(f, FloatType) + def unapply(a: Any): Option[Float] = a match { + case Literal(a: Float, FloatType) => Some(a) + case _ => None + } +} + +object DoubleLiteral { + def apply(d: Double): Literal = Literal(d, DoubleType) + def unapply(a: Any): Option[Double] = a match { + case Literal(a: Double, DoubleType) => Some(a) + case Literal(a: Float, 
FloatType) => Some(a.toDouble) + case _ => None + } +} + +object IntLiteral { + def apply(i: Int): Literal = Literal(i, IntegerType) + def unapply(a: Any): Option[Int] = a match { + case Literal(a: Int, IntegerType) => Some(a) + case Literal(a: Short, ShortType) => Some(a.toInt) + case _ => None + } +} + +object LongLiteral { + def apply(l: Long): Literal = Literal(l, LongType) + def apply(i: Int): Literal = Literal(i.toLong, LongType) + def apply(s: Short): Literal = Literal(s.toLong, LongType) + def unapply(a: Any): Option[Long] = a match { + case Literal(a: Long, LongType) => Some(a) + case Literal(a: Int, IntegerType) => Some(a.toLong) + case Literal(a: Short, ShortType) => Some(a.toLong) + case _ => None + } +} + +object StringLiteral { + def apply(s: String): Literal = Literal(s, StringType) + def unapply(a: Any): Option[String] = a match { + case Literal(s: String, StringType) => Some(s) + case _ => None + } +} + +object BooleanLiteral { + def apply(b: Boolean): Literal = if (b) Literal.True else Literal.False + def unapply(a: Any): Option[Boolean] = a match { + case Literal(true, BooleanType) => Some(true) + case _ => Some(false) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/operators.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/operators.scala new file mode 100644 index 0000000000..ddfa3296d9 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/operators.scala @@ -0,0 +1,111 @@ +package com.databricks.labs.remorph.intermediate + +trait Predicate extends AstExtension { + def dataType: DataType = BooleanType +} + +case class And(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class Or(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class Not(pred: Expression) extends Unary(pred) with Predicate + +case class Equals(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class NotEquals(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class LessThan(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class LessThanOrEqual(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class GreaterThan(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binary(left, right) with Predicate +case class Between(exp: Expression, lower: Expression, upper: Expression) extends Expression with Predicate { + override def children: Seq[Expression] = Seq(exp, lower, upper) +} + +trait Bitwise + +// Bitwise NOT is highest precedence after parens '(' ')' +case class BitwiseNot(expression: Expression) extends Unary(expression) with Bitwise { + override def dataType: DataType = expression.dataType +} + +// Binary bitwise expressions +case class BitwiseAnd(left: Expression, right: Expression) extends Binary(left, right) with Bitwise { + override def dataType: DataType = left.dataType +} + +case class BitwiseOr(left: Expression, right: Expression) extends Binary(left, right) with Bitwise { + override def dataType: DataType = left.dataType +} + +case class BitwiseXor(left: Expression, right: Expression) extends Binary(left, right) with Bitwise { + override def dataType: DataType = left.dataType +} + +trait Arithmetic + +// Unary arithmetic expressions +case class UMinus(expression: Expression) extends Unary(expression) 
with Arithmetic { + override def dataType: DataType = expression.dataType +} + +case class UPlus(expression: Expression) extends Unary(expression) with Arithmetic { + override def dataType: DataType = expression.dataType +} + +// Binary Arithmetic expressions +case class Multiply(left: Expression, right: Expression) extends Binary(left, right) with Arithmetic { + override def dataType: DataType = left.dataType +} + +case class Divide(left: Expression, right: Expression) extends Binary(left, right) with Arithmetic { + override def dataType: DataType = left.dataType +} + +case class Mod(left: Expression, right: Expression) extends Binary(left, right) with Arithmetic { + override def dataType: DataType = left.dataType +} + +case class Add(left: Expression, right: Expression) extends Binary(left, right) with Arithmetic { + override def dataType: DataType = left.dataType +} + +case class Subtract(left: Expression, right: Expression) extends Binary(left, right) with Arithmetic { + override def dataType: DataType = left.dataType +} + +/** + * str like pattern[ ESCAPE escape] - Returns true if str matches `pattern` with `escape`, null if any arguments are + * null, false otherwise. + * + * NB: escapeChar is a full expression that evaluates to a single char at runtime, not parse time + */ +case class Like(left: Expression, right: Expression, escapeChar: Option[Expression]) extends Binary(left, right) { + override def dataType: DataType = BooleanType +} + +case class LikeAll(child: Expression, patterns: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = child +: patterns + override def dataType: DataType = BooleanType +} + +case class LikeAny(child: Expression, patterns: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = child +: patterns + override def dataType: DataType = BooleanType +} + +// NB: escapeChar is a full expression that evaluates to a single char at runtime, not parse time +case class ILike(left: Expression, right: Expression, escapeChar: Option[Expression]) extends Binary(left, right) { + override def dataType: DataType = BooleanType +} + +case class ILikeAll(child: Expression, patterns: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = child +: patterns + override def dataType: DataType = BooleanType +} + +case class ILikeAny(child: Expression, patterns: Seq[Expression]) extends Expression { + override def children: Seq[Expression] = child +: patterns + override def dataType: DataType = BooleanType +} + +/** str rlike regexp - Returns true if `str` matches `regexp`, or false otherwise. 
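The typed extractors in literals.scala above pair naturally with these operator nodes for small pattern-based rewrites, such as integer constant folding or double-negation elimination. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object FoldSketch {
  def fold(e: Expression): Expression = e match {
    case Add(IntLiteral(a), IntLiteral(b)) => Literal(a + b) // fold 2 + 3 into 5
    case Not(Not(inner))                   => inner          // drop double negation
    case other                             => other
  }
  // fold(Add(Literal(2), Literal(3))) == Literal(5, IntegerType)
  // fold(Not(Not(Literal.True)))      == Literal.True
}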
*/ +case class RLike(left: Expression, right: Expression) extends Binary(left, right) { + override def dataType: DataType = BooleanType +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/options.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/options.scala new file mode 100644 index 0000000000..3be5344e41 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/options.scala @@ -0,0 +1,19 @@ +package com.databricks.labs.remorph.intermediate + +trait GenericOption { + def id: String +} + +case class OptionExpression(id: String, value: Expression, supplement: Option[String]) extends GenericOption +case class OptionString(id: String, value: String) extends GenericOption +case class OptionOn(id: String) extends GenericOption +case class OptionOff(id: String) extends GenericOption +case class OptionAuto(id: String) extends GenericOption +case class OptionDefault(id: String) extends GenericOption +case class OptionUnresolved(id: String) extends GenericOption + +class OptionLists( + val expressionOpts: Map[String, Expression], + val stringOpts: Map[String, String], + val boolFlags: Map[String, Boolean], + val autoFlags: List[String]) {} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/plans.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/plans.scala new file mode 100644 index 0000000000..dacaffe507 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/plans.scala @@ -0,0 +1,183 @@ +package com.databricks.labs.remorph.intermediate + +/** + * A [[Plan]] is the structure that carries the runtime information for the execution from the client to the server. A + * [[Plan]] can either be of the type [[Relation]] which is a reference to the underlying logical plan or it can be of + * the [[Command]] type that is used to execute commands on the server. [[Plan]] is a union of Spark's LogicalPlan and + * QueryPlan. + */ +abstract class Plan[PlanType <: Plan[PlanType]] extends TreeNode[PlanType] { + self: PlanType => + + def output: Seq[Attribute] + + /** + * Returns the set of attributes that are output by this node. + */ + @transient + lazy val outputSet: AttributeSet = new AttributeSet(output: _*) + + /** + * The set of all attributes that are child to this operator by its children. + */ + def inputSet: AttributeSet = new AttributeSet(children.flatMap(_.asInstanceOf[Plan[PlanType]].output): _*) + + /** + * The set of all attributes that are produced by this node. + */ + def producedAttributes: AttributeSet = new AttributeSet() + + /** + * All Attributes that appear in expressions from this operator. Note that this set does not include attributes that + * are implicitly referenced by being passed through to the output tuple. + */ + @transient + lazy val references: AttributeSet = new AttributeSet(expressions.flatMap(_.references): _*) -- producedAttributes + + /** + * Attributes that are referenced by expressions but not provided by this node's children. + */ + final def missingInput: AttributeSet = references -- inputSet + + /** + * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query operator. Users should not + * expect a specific directionality. If a specific directionality is needed, transformExpressionsDown or + * transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. 
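The GenericOption family from options.scala above gives a lossless home to dialect-specific option clauses before any later phase decides how to translate them. A sketch, not part of the patch; the option names below are hypothetical:

import com.databricks.labs.remorph.intermediate._

object OptionsSketch {
  // Roughly the shape of something like WITH (FILLFACTOR = 80, ONLINE = ON, SORT_IN_TEMPDB = DEFAULT, ...)
  val captured: Seq[GenericOption] = Seq(
    OptionExpression("FILLFACTOR", Literal(80), None),
    OptionOn("ONLINE"),
    OptionDefault("SORT_IN_TEMPDB"),
    OptionUnresolved("MYSTERY_FLAG"))
}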
+ */ + def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDown(rule) + } + + /** + * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformDown(rule)) + } + + /** + * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @return + */ + def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformUp(rule)) + } + + /** + * Apply a map function to each expression present in this query operator, and return a new query operator based on + * the mapped expressions. + */ + def mapExpressions(f: Expression => Expression): this.type = { + var changed = false + + @inline def transformExpression(e: Expression): Expression = { + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } + if (newE.fastEquals(e)) { + e + } else { + changed = true + newE + } + } + + def recursiveTransform(arg: Any): AnyRef = arg match { + case e: Expression => transformExpression(e) + case Some(value) => Some(recursiveTransform(value)) + case m: Map[_, _] => m + case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force + case seq: Iterable[_] => seq.map(recursiveTransform) + case other: AnyRef => other + case null => null + } + + val newArgs = mapProductIterator(recursiveTransform) + + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + + /** + * Returns the result of running [[transformExpressions]] on this node and all its children. Note that this method + * skips expressions inside subqueries. + */ + def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transform { case q: Plan[_] => + q.transformExpressions(rule).asInstanceOf[PlanType] + }.asInstanceOf[this.type] + } + + /** + * Returns all of the expressions present in this query (that is expression defined in this plan operator and in each + * its descendants). + */ + final def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case p: Plan[_] => p.expressions + case s: Iterable[_] => seqToExpressions(s) + case other => Nil + } + + productIterator.flatMap { + case e: Expression => e :: Nil + case s: Some[_] => seqToExpressions(s.toSeq) + case seq: Iterable[_] => seqToExpressions(seq) + case other => Nil + }.toSeq + } +} + +abstract class LogicalPlan extends Plan[LogicalPlan] { + + /** + * Returns true if this expression and all its children have been resolved to a specific schema and false if it still + * contains any unresolved placeholders. Implementations of LogicalPlan can override this (e.g.[[UnresolvedRelation]] + * should return `false`). + */ + lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved + + /** + * Returns true if all its children of this query plan have been resolved. 
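A hypothetical Filter-like node built on the UnaryNode helper defined just below shows how transformExpressionsDown rewrites expressions in place; the TreeNode plumbing (transformDown, makeCopy) lives elsewhere in this package and is assumed here. A sketch, not part of the patch:

import com.databricks.labs.remorph.intermediate._

object PlanRewriteSketch {
  case class NoopRelation() extends LeafNode {
    override def output: Seq[Attribute] = Seq.empty
  }
  case class SketchFilter(condition: Expression, child: LogicalPlan) extends UnaryNode {
    override def output: Seq[Attribute] = child.output
  }

  val plan   = SketchFilter(Equals(Literal(1), Literal(1)), NoopRelation())
  val folded = plan.transformExpressionsDown { case Equals(l, r) if l == r => Literal.True }
  // folded == SketchFilter(Literal.True, NoopRelation()); the child plan is left untouched.
}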
+ */ + def childrenResolved: Boolean = children.forall(_.resolved) + + def schema: DataType = { + val concrete = output.forall(_.isInstanceOf[Attribute]) + if (concrete) { + StructType(output.map { a: Attribute => + StructField(a.name, a.dataType) + }) + } else { + UnresolvedType + } + } + +} + +abstract class LeafNode extends LogicalPlan { + override def children: Seq[LogicalPlan] = Nil + override def producedAttributes: AttributeSet = outputSet +} + +abstract class UnaryNode extends LogicalPlan { + def child: LogicalPlan + override def children: Seq[LogicalPlan] = child :: Nil +} + +abstract class BinaryNode extends LogicalPlan { + def left: LogicalPlan + def right: LogicalPlan + override def children: Seq[LogicalPlan] = Seq(left, right) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/CaseStatement.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/CaseStatement.scala new file mode 100644 index 0000000000..4da682ef2a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/CaseStatement.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.LogicalPlan + +// there's Switch(..) node in this package just to represent a case statement with value match. Theoretically, +// we can merge these two, but semantics are clearer this way. +case class CaseStatement(when: Seq[WhenClause], orElse: Seq[LogicalPlan] = Seq.empty) extends Statement { + override def children: Seq[LogicalPlan] = when ++ orElse +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/CompoundStatement.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/CompoundStatement.scala new file mode 100644 index 0000000000..c0c0ccfe85 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/CompoundStatement.scala @@ -0,0 +1,6 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.LogicalPlan + +// aka BEGIN ... 
END +case class CompoundStatement(children: Seq[LogicalPlan], label: Option[String] = None) extends Statement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareCondition.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareCondition.scala new file mode 100644 index 0000000000..940f1e091a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareCondition.scala @@ -0,0 +1,3 @@ +package com.databricks.labs.remorph.intermediate.procedures + +case class DeclareCondition(name: String, sqlstate: Option[String] = None) extends LeafStatement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareContinueHandler.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareContinueHandler.scala new file mode 100644 index 0000000000..3e6d43065b --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareContinueHandler.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +case class DeclareContinueHandler(conditionValues: Seq[Expression], handlerAction: LogicalPlan) extends Statement { + override def children: Seq[LogicalPlan] = Seq(handlerAction) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareExitHandler.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareExitHandler.scala new file mode 100644 index 0000000000..b385332c8d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareExitHandler.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +case class DeclareExitHandler(conditionValues: Seq[Expression], handlerAction: LogicalPlan) extends Statement { + override def children: Seq[LogicalPlan] = Seq(handlerAction) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareVariable.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareVariable.scala new file mode 100644 index 0000000000..d0aafd6a32 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/DeclareVariable.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{DataType, Expression} + +case class DeclareVariable(name: String, datatype: DataType, default: Option[Expression] = None) extends LeafStatement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/ElseIf.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/ElseIf.scala new file mode 100644 index 0000000000..591109e3d4 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/ElseIf.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +case class ElseIf(condition: Expression, children: Seq[LogicalPlan]) extends Statement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/ForStatement.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/ForStatement.scala new file mode 100644 index 
0000000000..ffb8a3cf97 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/ForStatement.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.LogicalPlan + +// FOR [ variable_name AS ] query ... DO ... END FOR +case class ForStatement( + variableName: Option[String], + query: LogicalPlan, + statements: Seq[LogicalPlan], + label: Option[String] = None) + extends Statement { + override def children: Seq[LogicalPlan] = Seq(query) ++ statements +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/If.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/If.scala new file mode 100644 index 0000000000..fbd69cf79c --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/If.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +case class If( + condition: Expression, + thenDo: Seq[LogicalPlan], + elseIf: Seq[ElseIf] = Seq.empty, + orElse: Seq[LogicalPlan] = Seq.empty) + extends Statement { + override def children: Seq[LogicalPlan] = thenDo ++ elseIf ++ orElse +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Iterate.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Iterate.scala new file mode 100644 index 0000000000..ab759d14f8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Iterate.scala @@ -0,0 +1,4 @@ +package com.databricks.labs.remorph.intermediate.procedures + +// Stops the current iteration of a loop and moves on to the next one. +case class Iterate(label: String) extends LeafStatement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/LeafStatement.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/LeafStatement.scala new file mode 100644 index 0000000000..4764f0c1d4 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/LeafStatement.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.LogicalPlan + +abstract class LeafStatement extends Statement { + override def children: Seq[LogicalPlan] = Seq() +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Leave.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Leave.scala new file mode 100644 index 0000000000..1744c3b3d8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Leave.scala @@ -0,0 +1,4 @@ +package com.databricks.labs.remorph.intermediate.procedures + +// Resumes execution following the specified labeled statement. +case class Leave(label: String) extends LeafStatement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Loop.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Loop.scala new file mode 100644 index 0000000000..105f2b6045 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Loop.scala @@ -0,0 +1,6 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.LogicalPlan + +// Executes a series of statements repeatedly. 
+case class Loop(children: Seq[LogicalPlan], label: Option[String] = None) extends Statement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/RepeatUntil.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/RepeatUntil.scala new file mode 100644 index 0000000000..2b2fc0f8f0 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/RepeatUntil.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +// Executes a series of statements repeatedly until the loop-terminating condition is satisfied. +case class RepeatUntil(condition: Expression, children: Seq[LogicalPlan], label: Option[String] = None) + extends Statement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Return.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Return.scala new file mode 100644 index 0000000000..b347bdfbd4 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Return.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.Expression + +case class Return(value: Expression) extends LeafStatement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/SetVariable.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/SetVariable.scala new file mode 100644 index 0000000000..15c9c953da --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/SetVariable.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate._ + +case class SetVariable(name: Id, value: Expression, dataType: Option[DataType] = None) extends LeafNode with Command diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Signal.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Signal.scala new file mode 100644 index 0000000000..6d0512a86f --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Signal.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.intermediate.procedures + +case class Signal( + conditionName: String, + messageParms: Map[String, String] = Map.empty, + messageText: Option[String], + sqlState: Option[String] = None) + extends LeafStatement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Statement.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Statement.scala new file mode 100644 index 0000000000..cf47736659 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Statement.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Command, LogicalPlan} + +abstract class Statement extends LogicalPlan with Command diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Switch.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Switch.scala new file mode 100644 index 0000000000..3a43cacf51 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/Switch.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.intermediate.procedures 
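// ---------------------------------------------------------------------------------------------
// Editorial sketch, not part of the patch: how the statement nodes in this `procedures` package
// are meant to nest when modelling a simple control-flow block. `UnresolvedExpression` (from the
// `unresolved.scala` file added later in this diff) stands in for a real parsed condition; the
// label text and the condition strings below are invented purely for illustration.
object ProcedureSketch {
  import com.databricks.labs.remorph.intermediate.procedures._
  import com.databricks.labs.remorph.intermediate.{LogicalPlan, UnresolvedExpression}

  // Roughly: myLoop: WHILE i < 10 DO IF done THEN LEAVE myLoop; END IF; END WHILE
  val condition: UnresolvedExpression = UnresolvedExpression("i < 10", "illustration only")
  val body: Seq[LogicalPlan] =
    Seq(If(condition = UnresolvedExpression("done", "illustration only"), thenDo = Seq(Leave("myLoop"))))
  val loop: Statement = While(condition, body, label = Some("myLoop"))
}
// ---------------------------------------------------------------------------------------------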
+ +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +// CASE toMatch WHEN x THEN y WHEN z THEN a ELSE b END CASE +case class Switch(toMatch: Expression, when: Seq[WhenClause], orElse: Seq[LogicalPlan] = Seq.empty) extends Statement { + override def children: Seq[LogicalPlan] = when ++ orElse +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/WhenClause.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/WhenClause.scala new file mode 100644 index 0000000000..f5a61777c6 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/WhenClause.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +case class WhenClause(condition: Expression, children: Seq[LogicalPlan]) extends Statement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/While.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/While.scala new file mode 100644 index 0000000000..83ddc1b2b1 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/procedures/While.scala @@ -0,0 +1,6 @@ +package com.databricks.labs.remorph.intermediate.procedures + +import com.databricks.labs.remorph.intermediate.{Expression, LogicalPlan} + +// Continuously executes a list of statements as long as a specified condition remains true. +case class While(condition: Expression, children: Seq[LogicalPlan], label: Option[String] = None) extends Statement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/relations.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/relations.scala new file mode 100644 index 0000000000..f44cb32b88 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/relations.scala @@ -0,0 +1,399 @@ +package com.databricks.labs.remorph.intermediate + +abstract class Relation extends LogicalPlan + +abstract class RelationCommon extends Relation {} + +case class SQL(query: String, named_arguments: Map[String, Expression], pos_arguments: Seq[Expression]) + extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +abstract class Read(is_streaming: Boolean) extends LeafNode + +// TODO: replace it with TableIdentifier with catalog and schema filled +// TODO: replace most (if not all) occurrences with UnresolvedRelation +case class NamedTable( + unparsed_identifier: String, + options: Map[String, String] = Map.empty, + is_streaming: Boolean = false) + extends Read(is_streaming) { + override def output: Seq[Attribute] = Seq.empty +} + +case class DataSource( + format: String, + schemaString: String, + options: Map[String, String], + paths: Seq[String], + predicates: Seq[String], + is_streaming: Boolean) + extends Read(is_streaming) { + override def output: Seq[Attribute] = Seq.empty +} + +case class Project(input: LogicalPlan, columns: Seq[Expression]) extends UnaryNode { + override def child: LogicalPlan = input + // TODO: add resolver for Star + override def output: Seq[Attribute] = expressions.map { + case a: Attribute => a + case Alias(child, Id(name, _)) => AttributeReference(name, child.dataType) + case Id(name, _) => AttributeReference(name, UnresolvedType) + case Column(_, Id(name, _)) => AttributeReference(name, UnresolvedType) + case expr: Expression => throw new UnsupportedOperationException(s"cannot convert to attribute: $expr") + } +} + 
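// ---------------------------------------------------------------------------------------------
// Editorial sketch, not part of the patch: how these relation nodes combine with the
// expression-transform API from `plans.scala` and the `Rule` machinery from `rules.scala`, both
// added in this diff. `CallFunction` is assumed to be the expression node added elsewhere in this
// patch with a `(name, args)` constructor; the table and column names are invented.
object RelationSketch {
  import com.databricks.labs.remorph.intermediate._

  // Roughly: SELECT upper(region) FROM orders, built directly as IR.
  val plan: LogicalPlan =
    Project(NamedTable("orders"), Seq(CallFunction("upper", Seq(UnresolvedAttribute("region")))))

  // AlwaysUpperNameForCallFunction (defined later in this diff) visits every operator via
  // transformAllExpressions and upper-cases each CallFunction name, so the generator emits UPPER(...).
  val normalized: LogicalPlan = AlwaysUpperNameForCallFunction(plan)
}
// ---------------------------------------------------------------------------------------------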
+case class Filter(input: LogicalPlan, condition: Expression) extends UnaryNode { + override def child: LogicalPlan = input + override def output: Seq[Attribute] = input.output +} + +abstract class JoinType + +abstract class SetOpType + +abstract class GroupType + +abstract class ParseFormat + +case class JoinDataType(is_left_struct: Boolean, is_right_struct: Boolean) + +case class Join( + left: LogicalPlan, + right: LogicalPlan, + join_condition: Option[Expression], + join_type: JoinType, + using_columns: Seq[String], + join_data_type: JoinDataType) + extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output +} + +case class SetOperation( + left: LogicalPlan, + right: LogicalPlan, + set_op_type: SetOpType, + is_all: Boolean, + by_name: Boolean, + allow_missing_columns: Boolean) + extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output +} + +case class Limit(child: LogicalPlan, limit: Expression) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Offset(child: LogicalPlan, offset: Expression) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Tail(child: LogicalPlan, limit: Int) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Pivot(col: Expression, values: Seq[Expression]) + +case class Aggregate( + child: LogicalPlan, + group_type: GroupType, + grouping_expressions: Seq[Expression], + pivot: Option[Pivot]) + extends UnaryNode { + override def output: Seq[Attribute] = child.output ++ grouping_expressions.map(_.asInstanceOf[Attribute]) +} + +abstract class SortDirection(val sql: String) +case object UnspecifiedSortDirection extends SortDirection("") +case object Ascending extends SortDirection("ASC") +case object Descending extends SortDirection("DESC") + +abstract class NullOrdering(val sql: String) +case object SortNullsUnspecified extends NullOrdering("") +case object NullsFirst extends NullOrdering("NULLS FIRST") +case object NullsLast extends NullOrdering("NULLS LAST") + +case class SortOrder( + expr: Expression, + direction: SortDirection = UnspecifiedSortDirection, + nullOrdering: NullOrdering = SortNullsUnspecified) + extends Unary(expr) { + override def dataType: DataType = child.dataType +} + +case class Sort(child: LogicalPlan, order: Seq[SortOrder], is_global: Boolean = false) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Drop(child: LogicalPlan, columns: Seq[Expression], column_names: Seq[String]) extends UnaryNode { + override def output: Seq[Attribute] = child.output diff columns.map(_.asInstanceOf[Attribute]) +} + +case class Deduplicate( + child: LogicalPlan, + column_names: Seq[Expression], + all_columns_as_keys: Boolean, + within_watermark: Boolean) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class LocalRelation(child: LogicalPlan, data: Array[Byte], schemaString: String) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class CachedLocalRelation(hash: String) extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case class CachedRemoteRelation(relation_id: String) extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case class Sample( + child: LogicalPlan, + lower_bound: Double, + upper_bound: Double, + with_replacement: Boolean, + seed: Long, + deterministic_order: Boolean) + extends UnaryNode { + override def output: 
Seq[Attribute] = child.output +} + +case class Range(start: Long, end: Long, step: Long, num_partitions: Int) extends LeafNode { + override def output: Seq[Attribute] = Seq(AttributeReference("id", LongType)) +} + +// TODO: most likely has to be SubqueryAlias(identifier: AliasIdentifier, child: LogicalPlan) +case class SubqueryAlias(child: LogicalPlan, alias: Id, columnNames: Seq[Id] = Seq.empty) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Repartition(child: LogicalPlan, num_partitions: Int, shuffle: Boolean) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class ShowString(child: LogicalPlan, num_rows: Int, truncate: Int, vertical: Boolean) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class HtmlString(child: LogicalPlan, num_rows: Int, truncate: Int) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class StatSummary(child: LogicalPlan, statistics: Seq[String]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class StatDescribe(child: LogicalPlan, cols: Seq[String]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class StatCrosstab(child: LogicalPlan, col1: String, col2: String) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class StatCov(child: LogicalPlan, col1: String, col2: String) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class StatCorr(child: LogicalPlan, col1: String, col2: String, method: String) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class StatApproxQuantile(child: LogicalPlan, cols: Seq[String], probabilities: Seq[Double], relative_error: Double) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class StatFreqItems(child: LogicalPlan, cols: Seq[String], support: Double) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Fraction(stratum: Literal, fraction: Double) + +case class StatSampleBy(child: LogicalPlan, col: Expression, fractions: Seq[Fraction], seed: Long) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class NAFill(child: LogicalPlan, cols: Seq[String], values: Seq[Literal]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class NADrop(child: LogicalPlan, cols: Seq[String], min_non_nulls: Int) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Replacement(old_value: Literal, new_value: Literal) + +case class NAReplace(child: LogicalPlan, cols: Seq[String], replacements: Seq[Replacement]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class ToDF(child: LogicalPlan, column_names: Seq[String]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class WithColumnsRenamed(child: LogicalPlan, rename_columns_map: Map[String, String]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class WithColumns(child: LogicalPlan, aliases: Seq[Alias]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class WithWatermark(child: LogicalPlan, event_time: String, delay_threshold: String) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Hint(child: LogicalPlan, name: String, parameters: 
Seq[Expression]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class Values(values: Seq[Seq[Expression]]) extends LeafNode { // TODO: fix it + override def output: Seq[Attribute] = Seq.empty +} + +case class Unpivot( + child: LogicalPlan, + ids: Seq[Expression], + values: Option[Values], + variable_column_name: Id, + value_column_name: Id) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class ToSchema(child: LogicalPlan, dataType: DataType) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class RepartitionByExpression(child: LogicalPlan, partition_exprs: Seq[Expression], num_partitions: Int) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class MapPartitions(child: LogicalPlan, func: CommonInlineUserDefinedTableFunction, is_barrier: Boolean) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class GroupMap( + child: LogicalPlan, + grouping_expressions: Seq[Expression], + func: CommonInlineUserDefinedFunction, + sorting_expressions: Seq[Expression], + initial_input: LogicalPlan, + initial_grouping_expressions: Seq[Expression], + is_map_groups_with_state: Boolean, + output_mode: String, + timeout_conf: String) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class CoGroupMap( + left: LogicalPlan, + input_grouping_expressions: Seq[Expression], + right: LogicalPlan, + other_grouping_expressions: Seq[Expression], + func: CommonInlineUserDefinedFunction, + input_sorting_expressions: Seq[Expression], + other_sorting_expressions: Seq[Expression]) + extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output +} + +case class ApplyInPandasWithState( + child: LogicalPlan, + grouping_expressions: Seq[Expression], + func: CommonInlineUserDefinedFunction, + output_schema: String, + state_schema: String, + output_mode: String, + timeout_conf: String) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class PythonUDTF(return_type: DataType, eval_type: Int, command: Array[Byte], python_ver: String) + +case class CommonInlineUserDefinedTableFunction( + function_name: String, + deterministic: Boolean, + arguments: Seq[Expression], + python_udtf: Option[PythonUDTF]) + extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case class CollectMetrics(child: LogicalPlan, name: String, metrics: Seq[Expression]) extends UnaryNode { + override def output: Seq[Attribute] = child.output + +} + +case class Parse(child: LogicalPlan, format: ParseFormat, dataType: DataType, options: Map[String, String]) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +case class AsOfJoin( + left: LogicalPlan, + right: LogicalPlan, + left_as_of: Expression, + right_as_of: Expression, + join_expr: Option[Expression], + using_columns: Seq[String], + join_type: String, + tolerance: Option[Expression], + allow_exact_matches: Boolean, + direction: String) + extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output +} + +case object Unknown extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} + +case object UnspecifiedJoin extends JoinType + +case object InnerJoin extends JoinType + +case object FullOuterJoin extends JoinType + +case object LeftOuterJoin extends JoinType + +case object RightOuterJoin extends JoinType + +case object LeftAntiJoin extends JoinType + 
+case object LeftSemiJoin extends JoinType + +case object CrossJoin extends JoinType + +case class NaturalJoin(joinType: JoinType) extends JoinType + +case object UnspecifiedSetOp extends SetOpType + +case object IntersectSetOp extends SetOpType + +case object UnionSetOp extends SetOpType + +case object ExceptSetOp extends SetOpType + +case object UnspecifiedGroupType extends GroupType + +case object GroupBy extends GroupType + +case object GroupByAll extends GroupType + +case object Pivot extends GroupType + +case object UnspecifiedFormat extends ParseFormat + +case object JsonFormat extends ParseFormat + +case object CsvFormat extends ParseFormat diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/rules.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/rules.scala new file mode 100644 index 0000000000..984c2e0ee8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/rules.scala @@ -0,0 +1,25 @@ +package com.databricks.labs.remorph.intermediate + +abstract class Rule[T <: TreeNode[_]] { + val ruleName: String = { + val className = getClass.getName + if (className endsWith "$") className.dropRight(1) else className + } + + def apply(tree: T): T +} + +case class Rules[T <: TreeNode[_]](rules: Rule[T]*) extends Rule[T] { + def apply(tree: T): T = { + rules.foldLeft(tree) { case (p, rule) => rule(p) } + } +} + +// We use the UPPERCASE convention to refer to function names in the codebase, +// but it is not a requirement in the transpiled code. This rule is used to +// enforce the convention. +object AlwaysUpperNameForCallFunction extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case CallFunction(name, args) => + CallFunction(name.toUpperCase(), args) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/subqueries.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/subqueries.scala new file mode 100644 index 0000000000..6232b62a08 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/subqueries.scala @@ -0,0 +1,20 @@ +package com.databricks.labs.remorph.intermediate + +abstract class SubqueryExpression(val plan: LogicalPlan) extends Expression { + override def children: Seq[Expression] = plan.expressions // TODO: not sure if this is a good idea + override def dataType: DataType = plan.schema +} + +// returns one column. TBD if we want to split between +// one row (scalar) and ListQuery (many rows), as it makes +// little difference for SQL code generation.
+// scalar: SELECT * FROM a WHERE id = (SELECT id FROM b LIMIT 1) +// list: SELECT * FROM a WHERE id IN(SELECT id FROM b) +case class ScalarSubquery(relation: LogicalPlan) extends SubqueryExpression(relation) { + override def dataType: DataType = relation.schema +} + +// checks if a row exists in a subquery given some condition +case class Exists(relation: LogicalPlan) extends SubqueryExpression(relation) { + override def dataType: DataType = relation.schema +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/trees.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/trees.scala new file mode 100644 index 0000000000..3e892792c7 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/trees.scala @@ -0,0 +1,684 @@ +package com.databricks.labs.remorph.intermediate + +import com.databricks.labs.remorph.utils.Strings.truncatedString +import com.fasterxml.jackson.annotation.JsonIgnore + +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ +private class MutableInt(var i: Int) + +case class Origin( + line: Option[Int] = None, + startPosition: Option[Int] = None, + startIndex: Option[Int] = None, + stopIndex: Option[Int] = None, + sqlText: Option[String] = None, + objectType: Option[String] = None, + objectName: Option[String] = None) + +object CurrentOrigin { + private[this] val value = new ThreadLocal[Origin]() { + override def initialValue: Origin = Origin() + } + + def get: Origin = value.get() + + def setPosition(line: Int, start: Int): Unit = { + value.set(value.get.copy(line = Some(line), startPosition = Some(start))) + } + + def withOrigin[A](o: Origin)(f: => A): A = { + set(o) + val ret = + try f + finally { reset() } + ret + } + + def set(o: Origin): Unit = value.set(o) + + def reset(): Unit = value.set(Origin()) +} + +class TreeNodeException[TreeType <: TreeNode[_]](@transient val tree: TreeType, msg: String, cause: Throwable) + extends Exception(msg, cause) { + + val treeString = tree.toString + + // Yes, this is the same as a default parameter, but... those don't seem to work with SBT + // external project dependencies for some reason. + def this(tree: TreeType, msg: String) = this(tree, msg, null) + + override def getMessage: String = { + s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree" + } +} + +// scalastyle:off +abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { + // scalastyle:on + self: BaseType => + + @JsonIgnore lazy val containsChild: Set[TreeNode[_]] = children.toSet + private lazy val _hashCode: Int = productHash(this, scala.util.hashing.MurmurHash3.productSeed) + private lazy val allChildren: Set[TreeNode[_]] = (children ++ innerChildren).toSet[TreeNode[_]] + @JsonIgnore val origin: Origin = CurrentOrigin.get + + /** + * Returns a Seq of the children of this node. Children should not change. Immutability required for containsChild + * optimization + */ + def children: Seq[BaseType] + + override def hashCode(): Int = _hashCode + + /** + * Faster version of equality which short-circuits when two treeNodes are the same instance. We don't just override + * Object.equals, as doing so prevents the scala compiler from generating case class `equals` methods + */ + def fastEquals(other: TreeNode[_]): Boolean = { + this.eq(other) || this == other + } + + /** + * Find the first [[TreeNode]] that satisfies the condition specified by `f`. 
The condition is recursively applied to + * this node and all of its children (pre-order). + */ + def find(f: BaseType => Boolean): Option[BaseType] = if (f(this)) { + Some(this) + } else { + children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) } + } + + /** + * Runs the given function on this node and then recursively on [[children]]. + * @param f + * the function to be applied to each node in the tree. + */ + def foreach(f: BaseType => Unit): Unit = { + f(this) + children.foreach(_.foreach(f)) + } + + /** + * Runs the given function recursively on [[children]] then on this node. + * @param f + * the function to be applied to each node in the tree. + */ + def foreachUp(f: BaseType => Unit): Unit = { + children.foreach(_.foreachUp(f)) + f(this) + } + + /** + * Returns a Seq containing the result of applying the given function to each node in this tree in a preorder + * traversal. + * @param f + * the function to be applied. + */ + def map[A](f: BaseType => A): Seq[A] = { + val ret = new collection.mutable.ArrayBuffer[A]() + foreach(ret += f(_)) + ret.toList.toSeq + } + + /** + * Returns a Seq by applying a function to all nodes in this tree and using the elements of the resulting collections. + */ + def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = { + val ret = new collection.mutable.ArrayBuffer[A]() + foreach(ret ++= f(_)) + ret.toList.toSeq + } + + /** + * Returns a Seq containing the result of applying a partial function to all elements in this tree on which the + * function is defined. + */ + def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = { + val ret = new collection.mutable.ArrayBuffer[B]() + val lifted = pf.lift + foreach(node => lifted(node).foreach(ret.+=)) + ret.toList.toSeq + } + + /** + * Returns a Seq containing the leaves in this tree. + */ + def collectLeaves(): Seq[BaseType] = { + this.collect { case p if p.children.isEmpty => p } + } + + /** + * Finds and returns the first [[TreeNode]] of the tree for which the given partial function is defined (pre-order), + * and applies the partial function to it. + */ + def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = { + val lifted = pf.lift + lifted(this).orElse { + children.foldLeft(Option.empty[B]) { (l, r) => l.orElse(r.collectFirst(pf)) } + } + } + + /** + * Returns a copy of this node with the children replaced. TODO: Validate somewhere (in debug mode?) that children are + * ordered correctly. + */ + def withNewChildren(newChildren: Seq[BaseType]): BaseType = { + assert(newChildren.size == children.size, "Incorrect number of children") + var changed = false + val remainingNewChildren = newChildren.toBuffer + val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + // CaseWhen Case or any tuple type + case (left, right) => (mapChild(left), mapChild(right)) + case nonChild: AnyRef => nonChild + case null => null + } + val newArgs = mapProductIterator { + case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] + // Handle Seq[TreeNode] in TreeNode parameters. 
+ case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `map.mapValues().view.force` return `Map` in Scala 2.12 but return `IndexedSeq` in Scala + // 2.13, call `toMap` method manually to compatible with Scala 2.12 and Scala 2.13 + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force.toMap + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case Some(child) => Some(mapChild(child)) + case nonChild: AnyRef => nonChild + case null => null + } + + if (changed) makeCopy(newArgs) else this + } + + /** + * Efficient alternative to `productIterator.map(f).toArray`. + */ + protected def mapProductIterator[B: ClassTag](f: Any => B): Array[B] = { + val arr = Array.ofDim[B](productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(productElement(i)) + i += 1 + } + arr + } + + /** + * Creates a copy of this type of tree node after a transformation. Must be overridden by child classes that have + * constructor arguments that are not present in the productIterator. + * @param newArgs + * the new product arguments. + */ + def makeCopy(newArgs: Array[AnyRef]): BaseType = makeCopy(newArgs, allowEmptyArgs = false) + + /** + * Creates a copy of this type of tree node after a transformation. Must be overridden by child classes that have + * constructor arguments that are not present in the productIterator. + * @param newArgs + * the new product arguments. + * @param allowEmptyArgs + * whether to allow argument list to be empty. + */ + private def makeCopy(newArgs: Array[AnyRef], allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") { + val allCtors = getClass.getConstructors + if (newArgs.isEmpty && allCtors.isEmpty) { + // This is a singleton object which doesn't have any constructor. Just return `this` as we + // can't copy it. + return this + } + + // Skip no-arg constructors that are just there for kryo. + val ctors = allCtors.filter(allowEmptyArgs || _.getParameterTypes.size != 0) + if (ctors.isEmpty) { + sys.error(s"No valid constructor for $nodeName") + } + val allArgs: Array[AnyRef] = if (otherCopyArgs.isEmpty) { + newArgs + } else { + newArgs ++ otherCopyArgs + } + val defaultCtor = ctors + .find { ctor => + if (ctor.getParameterTypes.length != allArgs.length) { + false + } else if (allArgs.contains(null)) { + // if there is a `null`, we can't figure out the class, therefore we should just fallback + // to older heuristic + false + } else { + val argsArray: Array[Class[_]] = allArgs.map(_.getClass) + isAssignable(argsArray, ctor.getParameterTypes) + } + } + .getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to older heuristic + + try { + CurrentOrigin.withOrigin(origin) { + val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] + res + } + } catch { + case e: java.lang.IllegalArgumentException => + throw new TreeNodeException( + this, + s""" + |Failed to copy node. + |Is otherCopyArgs specified correctly for $nodeName. + |Exception message: ${e.getMessage} + |ctor: $defaultCtor? + |types: ${newArgs.map(_.getClass).mkString(", ")} + |args: ${newArgs.mkString(", ")} + """.stripMargin) + } + } + + /** + * Wraps any exceptions that are thrown while executing `f` in a [[TreeNodeException]], attaching the provided `tree`. 
+ */ + def attachTree[TreeType <: TreeNode[_], A](tree: TreeType, msg: String = "")(f: => A): A = { + try f + catch { + // difference from the original code: we are not checking for SparkException + case NonFatal(e) => + throw new TreeNodeException(tree, msg, e) + } + } + + /** + * Args to the constructor that should be copied, but not transformed. These are appended to the transformed args + * automatically by makeCopy + * @return + */ + protected def otherCopyArgs: Seq[AnyRef] = Nil + + /** + * Simplified version compared to commons-lang3 + * @param classArray + * the class array + * @param toClassArray + * the class array to check against + * @return + * true if the classArray is assignable to toClassArray + */ + private def isAssignable(classArray: Array[Class[_]], toClassArray: Array[Class[_]]): Boolean = { + if (classArray.length != toClassArray.length) { + false + } else { + classArray.zip(toClassArray).forall { case (c, toC) => + c.isPrimitive match { + case true => c == toC + case false => toC.isAssignableFrom(c) + } + } + } + } + + /** + * Returns the name of this type of TreeNode. Defaults to the class name. Note that we remove the "Exec" suffix for + * physical operators here. + */ + def nodeName: String = simpleClassName.replaceAll("Exec$", "") + + private def simpleClassName: String = try { + this.getClass.getSimpleName + } catch { + case _: InternalError => + val name = this.getClass.getName + val dollar = name.lastIndexOf('$') + if (dollar == -1) name else name.substring(0, dollar) + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to the tree. When `rule` does not apply to a + * given node it is left unchanged. Users should not expect a specific directionality. If a specific directionality is + * needed, transformDown or transformUp should be used. + * + * @param rule + * the function use to transform this nodes children + */ + def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = { + transformDown(rule) + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to it and all of its children (pre-order). + * When `rule` does not apply to a given node it is left unchanged. + * + * @param rule + * the function used to transform this nodes children + */ + def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } + + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + mapChildren(_.transformDown(rule)) + } else { + // If the transform function replaces this node with a new one, carry over the tags. + afterRule.mapChildren(_.transformDown(rule)) + } + } + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its children and then itself + * (post-order). When `rule` does not apply to a given node, it is left unchanged. + * + * @param rule + * the function use to transform this nodes children + */ + def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRuleOnChildren = mapChildren(_.transformUp(rule)) + val newNode = if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } + } + // If the transform function replaces this node with a new one, carry over the tags. 
+ newNode + } + + /** + * Returns a copy of this node where `f` has been applied to all the nodes in `children`. + */ + def mapChildren(f: BaseType => BaseType): BaseType = { + if (containsChild.nonEmpty) { + mapChildren(f, forceCopy = false) + } else { + this + } + } + + override def clone(): BaseType = { + mapChildren(_.clone(), forceCopy = true) + } + + /** Returns a string representing the arguments to this node, minus any children */ + def argString(maxFields: Int): String = stringArgs + .flatMap { + case tn: TreeNode[_] if allChildren.contains(tn) => Nil + case Some(tn: TreeNode[_]) if allChildren.contains(tn) => Nil + case Some(tn: TreeNode[_]) => tn.simpleString(maxFields) :: Nil + case tn: TreeNode[_] => tn.simpleString(maxFields) :: Nil + case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil + case iter: Iterable[_] if iter.isEmpty => Nil + case seq: Seq[_] => truncatedString(seq, "[", ", ", "]", maxFields) :: Nil + case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}", maxFields) :: Nil + case array: Array[_] if array.isEmpty => Nil + case array: Array[_] => truncatedString(array, "[", ", ", "]", maxFields) :: Nil + case null => Nil + case None => Nil + case Some(null) => Nil + case Some(any) => any :: Nil + case map: Map[_, _] => + redactMapString(map, maxFields) + case table: CatalogTable => + table.identifier :: Nil + case other => other :: Nil + } + .mkString(", ") + + /** + * ONE line description of this node. + * @param maxFields + * Maximum number of fields that will be converted to strings. Any elements beyond the limit will be dropped. + */ + def simpleString(maxFields: Int): String = s"$nodeName ${argString(maxFields)}".trim + + override def toString: String = treeString.replaceAll("\n]\n", "]\n") // TODO: fix properly + + /** Returns a string representation of the nodes in this tree */ + final def treeString: String = treeString() + + final def treeString(maxFields: Int = 25): String = { + val concat = new StringBuilder() + treeString(str => concat.append(str), maxFields) + concat.toString + } + + def treeString(append: String => Unit, maxFields: Int): Unit = { + generateTreeString(0, Nil, append, "", maxFields) + } + + /** + * Returns a string representation of the nodes in this tree, where each operator is numbered. The numbers can be used + * with [[TreeNode.apply]] to easily access specific subtrees. + * + * The numbers are based on depth-first traversal of the tree (with innerChildren traversed first before children). + */ + def numberedTreeString: String = + treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") + + /** + * Returns the tree node at the specified number, used primarily for interactive debugging. Numbers for each node can + * be found in the [[numberedTreeString]]. + * + * Note that this cannot return BaseType because logical plan's plan node might return physical plan for + * innerChildren, e.g. in-memory child logical plan node has a reference to the physical plan node it is referencing. + */ + def apply(number: Int): TreeNode[_] = getNodeNumbered(new MutableInt(number)).orNull + + /** + * Returns the tree node at the specified number, used primarily for interactive debugging. Numbers for each node can + * be found in the [[numberedTreeString]]. + * + * This is a variant of [[apply]] that returns the node as BaseType (if the type matches). 
+ */ + def p(number: Int): BaseType = apply(number).asInstanceOf[BaseType] + + /** + * All the nodes that should be shown as a inner nested tree of this node. For example, this can be used to show + * sub-queries. + */ + def innerChildren: Seq[TreeNode[_]] = Seq.empty + + /** + * Returns a 'scala code' representation of this `TreeNode` and its children. Intended for use when debugging where + * the prettier toString function is obfuscating the actual structure. In the case of 'pure' `TreeNodes` that only + * contain primitives and other TreeNodes, the result can be pasted in the REPL to build an equivalent Tree. + */ + def asCode: String = pprint.apply(self).plainText + + /** + * The arguments that should be included in the arg string. Defaults to the `productIterator`. + */ + protected def stringArgs: Iterator[Any] = productIterator + + // Copied from Scala 2.13.1 + // github.com/scala/scala/blob/v2.13.1/src/library/scala/util/hashing/MurmurHash3.scala#L56-L73 + // to prevent the issue https://github.com/scala/bug/issues/10495 + // TODO(SPARK-30848): Remove this once we drop Scala 2.12. + private final def productHash(x: Product, seed: Int, ignorePrefix: Boolean = false): Int = { + val arr = x.productArity + // Case objects have the hashCode inlined directly into the + // synthetic hashCode method, but this method should still give + // a correct result if passed a case object. + if (arr == 0) { + x.productPrefix.hashCode + } else { + var h = seed + if (!ignorePrefix) h = scala.util.hashing.MurmurHash3.mix(h, x.productPrefix.hashCode) + var i = 0 + while (i < arr) { + h = scala.util.hashing.MurmurHash3.mix(h, x.productElement(i).##) + i += 1 + } + scala.util.hashing.MurmurHash3.finalizeHash(h, arr) + } + } + + /** + * Returns a copy of this node where `f` has been applied to all the nodes in `children`. + * @param f + * The transform function to be applied on applicable `TreeNode` elements. + * @param forceCopy + * Whether to force making a copy of the nodes even if no child has been changed. 
+ */ + private def mapChildren(f: BaseType => BaseType, forceCopy: Boolean): BaseType = { + var changed = false + + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (forceCopy || !(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (forceCopy || !(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + + val newArgs = mapProductIterator { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (forceCopy || !(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case Some(arg: TreeNode[_]) if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (forceCopy || !(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } + // `map.mapValues().view.force` returns `Map` in Scala 2.12 but returns `IndexedSeq` in Scala + // 2.13; call the `toMap` method manually to stay compatible with both Scala 2.12 and Scala 2.13 + case m: Map[_, _] => + m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (forceCopy || !(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + }.view + .force + .toMap // `mapValues` is lazy and we need to force it to materialize + case d: DataType => d // Avoid unpacking Structs + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Iterable[_] => args.map(mapChild) + case nonChild: AnyRef => nonChild + case null => null + } + if (forceCopy || changed) makeCopy(newArgs, forceCopy) else this + } + + private def redactMapString[K, V](map: Map[K, V], maxFields: Int): List[String] = { + // Note: no redaction is currently applied here; the map entries are passed through as-is + val redactedMap = map.toMap + // construct the map entries as strings of the format "key=value" + val keyValuePairs = redactedMap.toSeq.map { item => + item._1 + "=" + item._2 + } + truncatedString(keyValuePairs, "[", ", ", "]", maxFields) :: Nil + } + + private def getNodeNumbered(number: MutableInt): Option[TreeNode[_]] = { + if (number.i < 0) { + None + } else if (number.i == 0) { + Some(this) + } else { + number.i -= 1 + // Note that this traversal order must be the same as numberedTreeString. + innerChildren.map(_.getNodeNumbered(number)).find(_.isDefined).getOrElse { + children.map(_.getNodeNumbered(number)).find(_.isDefined).flatten + } + } + } + + /** + * Appends the string representation of this node and its children using the given `append` function. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at depth `i + 1` is the + * last child of its own parent node. The depth of the root node is 0, and `lastChildren` for the root node should be + * empty. + * + * Note that this traversal (numbering) order must be the same as [[getNodeNumbered]].
+ */ + private def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + prefix: String = "", + maxFields: Int): Unit = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + append(if (isLast) " " else ": ") + } + append(if (lastChildren.last) "+- " else ":- ") + } + + val str = simpleString(maxFields) + append(prefix) + append(str) + append("\n") + + if (innerChildren.nonEmpty) { + innerChildren.init.foreach( + _.generateTreeString(depth + 2, lastChildren :+ children.isEmpty :+ false, append, maxFields = maxFields)) + innerChildren.last.generateTreeString( + depth + 2, + lastChildren :+ children.isEmpty :+ true, + append, + maxFields = maxFields) + } + + if (children.nonEmpty) { + children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, append, prefix, maxFields)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, append, prefix, maxFields) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/unresolved.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/unresolved.scala new file mode 100644 index 0000000000..b267ed9caf --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/unresolved.scala @@ -0,0 +1,181 @@ +package com.databricks.labs.remorph.intermediate + +trait UnwantedInGeneratorInput + +trait Unresolved[T] { + def ruleText: String + def message: String + def ruleName: String + def tokenName: Option[String] + + def annotate(newRuleName: String, newTokenName: Option[String]): T +} +case class UnresolvedRelation( + ruleText: String, + message: String = "", + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends LeafNode + with UnwantedInGeneratorInput + with Unresolved[UnresolvedRelation] { + override def output: Seq[Attribute] = Seq.empty + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedRelation = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedExpression( + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends LeafExpression + with UnwantedInGeneratorInput + with Unresolved[UnresolvedExpression] { + override def dataType: DataType = UnresolvedType + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedExpression = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedAttribute( + unparsed_identifier: String, + plan_id: Long = 0, + is_metadata_column: Boolean = false, + ruleText: String = "", + message: String = "", + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends LeafExpression + with Unresolved[UnresolvedAttribute] { + override def dataType: DataType = UnresolvedType + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedAttribute = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedFunction( + function_name: String, + arguments: Seq[Expression], + is_distinct: Boolean, + is_user_defined_function: Boolean, + has_incorrect_argc: Boolean = false, + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends Expression + with Unresolved[UnresolvedFunction] { + override def children: Seq[Expression] = arguments + override def dataType: DataType = UnresolvedType + + override def annotate(newRuleName: String, 
newTokenName: Option[String]): UnresolvedFunction = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedStar( + unparsed_target: String, + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends LeafExpression + with Unresolved[UnresolvedStar] { + override def dataType: DataType = UnresolvedType + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedStar = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedRegex( + col_name: String, + plan_id: Long, + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends LeafExpression + with Unresolved[UnresolvedRegex] { + override def dataType: DataType = UnresolvedType + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedRegex = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedExtractValue( + child: Expression, + extraction: Expression, + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends Expression + with Unresolved[UnresolvedExtractValue] { + override def children: Seq[Expression] = child :: extraction :: Nil + override def dataType: DataType = UnresolvedType + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedExtractValue = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedCommand( + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends Catalog + with Command + with UnwantedInGeneratorInput + with Unresolved[UnresolvedCommand] { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedCommand = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedCatalog( + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends Catalog + with UnwantedInGeneratorInput + with Unresolved[UnresolvedCatalog] { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedCatalog = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedCTAS( + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends Catalog + with Command + with UnwantedInGeneratorInput + with Unresolved[UnresolvedCTAS] { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty + + override def annotate(newRuleName: String, newTokenName: Option[String]): UnresolvedCTAS = + copy(ruleName = newRuleName, tokenName = newTokenName) +} + +case class UnresolvedModification( + ruleText: String, + message: String, + ruleName: String = "rule name undetermined", + tokenName: Option[String] = None) + extends Modification + with UnwantedInGeneratorInput + with Unresolved[UnresolvedModification] { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty + + override def annotate(newRuleName: String, newTokenName: Option[String]): 
UnresolvedModification = + copy(ruleName = newRuleName, tokenName = newTokenName) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/JobNode.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/JobNode.scala new file mode 100644 index 0000000000..d08ea79232 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/JobNode.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows + +import com.databricks.labs.remorph.intermediate.TreeNode + +abstract class JobNode extends TreeNode[JobNode] + +abstract class LeafJobNode extends JobNode { + override def children: Seq[JobNode] = Seq() +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AutoScale.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AutoScale.scala new file mode 100644 index 0000000000..6cc8d3c5b5 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AutoScale.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class AutoScale(maxWorkers: Option[Int], minWorkers: Option[Int] = None) extends JobNode { + override def children: Seq[JobNode] = Seq() + + def toSDK: compute.AutoScale = { + val raw = new compute.AutoScale() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AwsAttributes.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AwsAttributes.scala new file mode 100644 index 0000000000..e7ae6ebc4e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AwsAttributes.scala @@ -0,0 +1,24 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute +import com.databricks.sdk.service.compute.{AwsAvailability, EbsVolumeType} + +case class AwsAttributes( + availability: Option[AwsAvailability] = None, + ebsVolumeCount: Option[Int] = None, + ebsVolumeIops: Option[Int] = None, + ebsVolumeSize: Option[Int] = None, + ebsVolumeThroughput: Option[Int] = None, + ebsVolumeType: Option[EbsVolumeType] = None, + firstOnDemand: Option[Int] = None, + instanceProfileArn: Option[String] = None, + spotBidPricePercent: Option[Int] = None, + zoneId: Option[String] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.AwsAttributes = { + val raw = new compute.AwsAttributes() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AzureAttributes.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AzureAttributes.scala new file mode 100644 index 0000000000..a91d5491f8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/AzureAttributes.scala @@ -0,0 +1,18 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute +import com.databricks.sdk.service.compute.AzureAvailability + +case class AzureAttributes( + availability: Option[AzureAvailability] = None, + firstOnDemand: Option[Int] = None, + logAnalyticsInfo: 
Option[LogAnalyticsInfo] = None, + spotBidMaxPrice: Option[Float] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.AzureAttributes = { + val raw = new compute.AzureAttributes() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClientsTypes.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClientsTypes.scala new file mode 100644 index 0000000000..5bbb02af2a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClientsTypes.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class ClientsTypes(jobs: Boolean = false, notebooks: Boolean) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.ClientsTypes = { + val raw = new compute.ClientsTypes() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClusterLogConf.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClusterLogConf.scala new file mode 100644 index 0000000000..eb12d7e0ee --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClusterLogConf.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.sources.{DbfsStorageInfo, S3StorageInfo} +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class ClusterLogConf(dbfs: Option[DbfsStorageInfo], s3: Option[S3StorageInfo] = None) extends JobNode { + override def children: Seq[JobNode] = Seq() ++ dbfs ++ s3 + def toSDK: compute.ClusterLogConf = { + val raw = new compute.ClusterLogConf() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClusterSpec.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClusterSpec.scala new file mode 100644 index 0000000000..b70c909753 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/ClusterSpec.scala @@ -0,0 +1,20 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.libraries.Library +import com.databricks.sdk.service.jobs + +case class ClusterSpec( + existingClusterId: Option[String] = None, + jobClusterKey: Option[String] = None, + libraries: Seq[Library] = Seq.empty, + newCluster: Option[NewClusterSpec] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ libraries ++ newCluster + def toSDK: jobs.ClusterSpec = new jobs.ClusterSpec() + .setExistingClusterId(existingClusterId.orNull) + .setJobClusterKey(jobClusterKey.orNull) + .setLibraries(libraries.map(_.toSDK).asJava) + .setNewCluster(newCluster.map(_.toSDK).orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/DiskSpec.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/DiskSpec.scala new file mode 100644 index 0000000000..6116736eef --- /dev/null +++ 
b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/DiskSpec.scala @@ -0,0 +1,18 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class DiskSpec( + diskCount: Option[Int] = None, + diskIops: Option[Int] = None, + diskSize: Option[Int] = None, + diskThroughput: Option[Int] = None, + diskType: Option[DiskType] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.DiskSpec = { + val raw = new compute.DiskSpec() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/DiskType.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/DiskType.scala new file mode 100644 index 0000000000..eb7b827c63 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/DiskType.scala @@ -0,0 +1,16 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute +import com.databricks.sdk.service.compute.{DiskTypeAzureDiskVolumeType, DiskTypeEbsVolumeType} + +case class DiskType( + azureDiskVolumeType: Option[DiskTypeAzureDiskVolumeType] = None, + ebsVolumeType: Option[DiskTypeEbsVolumeType] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.DiskType = { + val raw = new compute.DiskType() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/GcpAttributes.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/GcpAttributes.scala new file mode 100644 index 0000000000..7be8bb35a3 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/GcpAttributes.scala @@ -0,0 +1,20 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute +import com.databricks.sdk.service.compute.GcpAvailability + +case class GcpAttributes( + availability: Option[GcpAvailability] = None, + bootDiskSize: Option[Int] = None, + googleServiceAccount: Option[String] = None, + localSsdCount: Option[Int] = None, + usePreemptibleExecutors: Boolean = false, + zoneId: Option[String] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.GcpAttributes = { + val raw = new compute.GcpAttributes() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/InitScriptInfo.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/InitScriptInfo.scala new file mode 100644 index 0000000000..3241020e04 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/InitScriptInfo.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.sources.{Adlsgen2Info, DbfsStorageInfo, GcsStorageInfo, LocalFileInfo, S3StorageInfo, VolumesStorageInfo, WorkspaceStorageInfo} +import com.databricks.labs.remorph.intermediate.workflows._ +import com.databricks.sdk.service.compute + +case class InitScriptInfo( + abfss: Option[Adlsgen2Info] = None, + dbfs: Option[DbfsStorageInfo] = None, + file: 
Option[LocalFileInfo] = None, + gcs: Option[GcsStorageInfo] = None, + s3: Option[S3StorageInfo] = None, + volumes: Option[VolumesStorageInfo] = None, + workspace: Option[WorkspaceStorageInfo] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ abfss ++ dbfs ++ file ++ gcs ++ s3 ++ volumes ++ workspace + def toSDK: compute.InitScriptInfo = { + val raw = new compute.InitScriptInfo() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/LogAnalyticsInfo.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/LogAnalyticsInfo.scala new file mode 100644 index 0000000000..c6355caa08 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/LogAnalyticsInfo.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class LogAnalyticsInfo(logAnalyticsPrimaryKey: Option[String], logAnalyticsWorkspaceId: Option[String] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.LogAnalyticsInfo = { + val raw = new compute.LogAnalyticsInfo() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/NewClusterSpec.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/NewClusterSpec.scala new file mode 100644 index 0000000000..d45437455d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/NewClusterSpec.scala @@ -0,0 +1,68 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.libraries.DockerImage +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute +import com.databricks.sdk.service.compute.{DataSecurityMode, RuntimeEngine} + +case class NewClusterSpec( + applyPolicyDefaultValues: Boolean = false, + autoscale: Option[AutoScale] = None, + autoterminationMinutes: Option[Long] = None, + awsAttributes: Option[AwsAttributes] = None, + azureAttributes: Option[AzureAttributes] = None, + clusterLogConf: Option[ClusterLogConf] = None, + clusterName: Option[String] = None, + customTags: Map[String, String] = Map.empty, + dataSecurityMode: Option[DataSecurityMode] = None, + dockerImage: Option[DockerImage] = None, + driverInstancePoolId: Option[String] = None, + driverNodeTypeId: Option[String] = None, + enableElasticDisk: Boolean = false, + enableLocalDiskEncryption: Boolean = false, + gcpAttributes: Option[GcpAttributes] = None, + initScripts: Seq[InitScriptInfo] = Seq.empty, + instancePoolId: Option[String] = None, + nodeTypeId: Option[String] = None, + numWorkers: Option[Long] = None, + policyId: Option[String] = None, + runtimeEngine: Option[RuntimeEngine] = None, + singleUserName: Option[String] = None, + sparkConf: Map[String, String] = Map.empty, + sparkEnvVars: Map[String, String] = Map.empty, + sparkVersion: Option[String] = None, + sshPublicKeys: Seq[String] = Seq.empty, + workloadType: Option[WorkloadType] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ autoscale ++ awsAttributes ++ azureAttributes ++ + clusterLogConf ++ gcpAttributes ++ workloadType ++ dockerImage ++ initScripts + def toSDK: compute.ClusterSpec = new 
compute.ClusterSpec() + .setApplyPolicyDefaultValues(applyPolicyDefaultValues) + .setAutoscale(autoscale.map(_.toSDK).orNull) + // .setAutoterminationMinutes(autoterminationMinutes.getOrElse(null)) + .setAwsAttributes(awsAttributes.map(_.toSDK).orNull) + .setAzureAttributes(azureAttributes.map(_.toSDK).orNull) + .setGcpAttributes(gcpAttributes.map(_.toSDK).orNull) + .setClusterLogConf(clusterLogConf.map(_.toSDK).orNull) + .setClusterName(clusterName.orNull) + .setCustomTags(customTags.asJava) + .setDataSecurityMode(dataSecurityMode.orNull) + .setDockerImage(dockerImage.map(_.toSDK).orNull) + .setInstancePoolId(instancePoolId.orNull) + .setDriverInstancePoolId(driverInstancePoolId.orNull) + .setDriverNodeTypeId(driverNodeTypeId.orNull) + .setEnableElasticDisk(enableElasticDisk) + .setEnableLocalDiskEncryption(enableLocalDiskEncryption) + .setInitScripts(initScripts.map(_.toSDK).asJava) + .setNodeTypeId(nodeTypeId.orNull) + // .setNumWorkers(numWorkers.orNull) + .setPolicyId(policyId.orNull) + .setRuntimeEngine(runtimeEngine.orNull) + .setSingleUserName(singleUserName.orNull) + .setSparkConf(sparkConf.asJava) + .setSparkEnvVars(sparkEnvVars.asJava) + .setSparkVersion(sparkVersion.orNull) + .setSshPublicKeys(sshPublicKeys.asJava) + .setWorkloadType(workloadType.map(_.toSDK).orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/WorkloadType.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/WorkloadType.scala new file mode 100644 index 0000000000..3d42e049ea --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/clusters/WorkloadType.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.clusters + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class WorkloadType(clients: ClientsTypes) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.WorkloadType = { + val raw = new compute.WorkloadType() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobCluster.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobCluster.scala new file mode 100644 index 0000000000..39a551f4da --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobCluster.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.clusters.NewClusterSpec +import com.databricks.sdk.service.jobs + +case class JobCluster(jobClusterKey: String, newCluster: NewClusterSpec) extends JobNode { + override def children: Seq[JobNode] = Seq(newCluster) + def toSDK: jobs.JobCluster = new jobs.JobCluster() + .setJobClusterKey(jobClusterKey) + .setNewCluster(newCluster.toSDK) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobEmailNotifications.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobEmailNotifications.scala new file mode 100644 index 0000000000..d488b9e2ef --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobEmailNotifications.scala @@ -0,0 +1,22 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import scala.collection.JavaConverters._ +import 
com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class JobEmailNotifications( + noAlertForSkippedRuns: Boolean = false, + onDurationWarningThresholdExceeded: Seq[String] = Seq.empty, + onFailure: Seq[String] = Seq.empty, + onStart: Seq[String] = Seq.empty, + onStreamingBacklogExceeded: Seq[String] = Seq.empty, + onSuccess: Seq[String] = Seq.empty) + extends LeafJobNode { + def toSDK: jobs.JobEmailNotifications = new jobs.JobEmailNotifications() + .setNoAlertForSkippedRuns(noAlertForSkippedRuns) + .setOnDurationWarningThresholdExceeded(onDurationWarningThresholdExceeded.asJava) + .setOnFailure(onFailure.asJava) + .setOnStart(onStart.asJava) + .setOnStreamingBacklogExceeded(onStreamingBacklogExceeded.asJava) + .setOnSuccess(onSuccess.asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobEnvironment.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobEnvironment.scala new file mode 100644 index 0000000000..58e6096e1d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobEnvironment.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.libraries.Environment +import com.databricks.sdk.service.jobs + +case class JobEnvironment(environmentKey: String, spec: Option[Environment] = None) extends JobNode { + override def children: Seq[JobNode] = Seq() ++ spec + def toSDK: jobs.JobEnvironment = new jobs.JobEnvironment() + .setEnvironmentKey(environmentKey) + .setSpec(spec.map(_.toSDK).orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobNotificationSettings.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobNotificationSettings.scala new file mode 100644 index 0000000000..55c610c7ea --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobNotificationSettings.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class JobNotificationSettings(noAlertForCanceledRuns: Boolean = false, noAlertForSkippedRuns: Boolean) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.JobNotificationSettings = new jobs.JobNotificationSettings() + .setNoAlertForCanceledRuns(noAlertForCanceledRuns) + .setNoAlertForSkippedRuns(noAlertForSkippedRuns) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobParameter.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobParameter.scala new file mode 100644 index 0000000000..d53a21385d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobParameter.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class JobParameter(default: Option[String], name: Option[String], value: Option[String] = None) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.JobParameter = { + val raw = new jobs.JobParameter() + raw + } +} diff --git 
a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobParameterDefinition.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobParameterDefinition.scala new file mode 100644 index 0000000000..650dc911c5 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobParameterDefinition.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class JobParameterDefinition(name: String, default: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.JobParameterDefinition = new jobs.JobParameterDefinition().setName(name).setDefault(default) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobRunAs.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobRunAs.scala new file mode 100644 index 0000000000..2cd004e7eb --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobRunAs.scala @@ -0,0 +1,11 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class JobRunAs(servicePrincipalName: Option[String], userName: Option[String] = None) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.JobRunAs = new jobs.JobRunAs() + .setServicePrincipalName(servicePrincipalName.orNull) + .setUserName(userName.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobSettings.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobSettings.scala new file mode 100644 index 0000000000..521f8e3f9e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobSettings.scala @@ -0,0 +1,73 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.schedules.{Continuous, CronSchedule, TriggerSettings} +import com.databricks.labs.remorph.intermediate.workflows.tasks.Task +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.labs.remorph.intermediate.workflows.webhooks.WebhookNotifications +import com.databricks.sdk.service.jobs + +import java.util.Locale + +case class JobSettings( + name: String, + tasks: Seq[Task], + tags: Map[String, String] = Map.empty, + description: Option[String] = None, + parameters: Seq[JobParameterDefinition] = Seq.empty, + jobClusters: Seq[JobCluster] = Seq.empty, + continuous: Option[Continuous] = None, + schedule: Option[CronSchedule] = None, + trigger: Option[TriggerSettings] = None, + environments: Seq[JobEnvironment] = Seq.empty, + health: Option[JobsHealthRules] = None, + timeoutSeconds: Option[Long] = None, + maxConcurrentRuns: Option[Long] = None, + runAs: Option[JobRunAs] = None, + emailNotifications: Option[JobEmailNotifications] = None, + notificationSettings: Option[JobNotificationSettings] = None, + webhookNotifications: Option[WebhookNotifications] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ continuous ++ emailNotifications ++ + health ++ notificationSettings ++ runAs ++ schedule ++ trigger ++ webhookNotifications + + def resourceName: String = 
name.toLowerCase(Locale.ROOT).replaceAll("[^A-Za-z0-9]", "_") + + def toUpdate: jobs.JobSettings = new jobs.JobSettings() + .setContinuous(continuous.map(_.toSDK).orNull) + .setDescription(description.orNull) + .setEmailNotifications(emailNotifications.map(_.toSDK).orNull) + .setEnvironments(environments.map(_.toSDK).asJava) + .setHealth(health.map(_.toSDK).orNull) + .setJobClusters(jobClusters.map(_.toSDK).asJava) + // .setMaxConcurrentRuns(maxConcurrentRuns.orNull) + .setName(name) + .setNotificationSettings(notificationSettings.map(_.toSDK).orNull) + .setParameters(parameters.map(_.toSDK).asJava) + .setRunAs(runAs.map(_.toSDK).orNull) + .setSchedule(schedule.map(_.toSDK).orNull) + .setTags(tags.asJava) + .setTasks(tasks.map(_.toSDK).asJava) + // .setTimeoutSeconds(timeoutSeconds.orNull) + .setTrigger(trigger.map(_.toSDK).orNull) + .setWebhookNotifications(webhookNotifications.map(_.toSDK).orNull) + + def toCreate: jobs.CreateJob = new jobs.CreateJob() + .setContinuous(continuous.map(_.toSDK).orNull) + .setDescription(description.orNull) + .setEmailNotifications(emailNotifications.map(_.toSDK).orNull) + .setEnvironments(environments.map(_.toSDK).asJava) + .setHealth(health.map(_.toSDK).orNull) + .setJobClusters(jobClusters.map(_.toSDK).asJava) + // .setMaxConcurrentRuns(maxConcurrentRuns.orNull) + .setName(name) + .setNotificationSettings(notificationSettings.map(_.toSDK).orNull) + .setParameters(parameters.map(_.toSDK).asJava) + .setRunAs(runAs.map(_.toSDK).orNull) + .setSchedule(schedule.map(_.toSDK).orNull) + .setTags(tags.asJava) + .setTasks(tasks.map(_.toSDK).asJava) + // .setTimeoutSeconds(timeoutSeconds.orNull) + .setTrigger(trigger.map(_.toSDK).orNull) + .setWebhookNotifications(webhookNotifications.map(_.toSDK).orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobSource.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobSource.scala new file mode 100644 index 0000000000..e3c213fefa --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobSource.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.JobSourceDirtyState + +case class JobSource(jobConfigPath: String, importFromGitBranch: String, dirtyState: Option[JobSourceDirtyState] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.JobSource = { + val raw = new jobs.JobSource() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobsHealthRule.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobsHealthRule.scala new file mode 100644 index 0000000000..db23add0e1 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobsHealthRule.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.{JobsHealthMetric, JobsHealthOperator} + +case class JobsHealthRule(metric: JobsHealthMetric, op: JobsHealthOperator, value: Int) extends LeafJobNode { + def toSDK: jobs.JobsHealthRule = new jobs.JobsHealthRule().setMetric(metric).setOp(op).setValue(value) +} diff --git 
a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobsHealthRules.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobsHealthRules.scala new file mode 100644 index 0000000000..0744f865fe --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/jobs/JobsHealthRules.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.remorph.intermediate.workflows.jobs + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class JobsHealthRules(rules: Seq[JobsHealthRule]) extends JobNode { + override def children: Seq[JobNode] = rules + def toSDK: jobs.JobsHealthRules = new jobs.JobsHealthRules().setRules(rules.map(_.toSDK).asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/DockerBasicAuth.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/DockerBasicAuth.scala new file mode 100644 index 0000000000..22e2cf3a08 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/DockerBasicAuth.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.libraries + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class DockerBasicAuth(password: Option[String], username: Option[String] = None) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.DockerBasicAuth = { + val raw = new compute.DockerBasicAuth() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/DockerImage.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/DockerImage.scala new file mode 100644 index 0000000000..96811698df --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/DockerImage.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.libraries + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class DockerImage(basicAuth: Option[DockerBasicAuth], url: Option[String] = None) extends JobNode { + override def children: Seq[JobNode] = Seq() ++ basicAuth + def toSDK: compute.DockerImage = { + val raw = new compute.DockerImage() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/Environment.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/Environment.scala new file mode 100644 index 0000000000..177e056003 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/Environment.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.libraries + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.compute + +case class Environment(client: String, dependencies: Seq[String] = Seq.empty) extends LeafJobNode { + def toSDK: compute.Environment = new compute.Environment().setClient(client).setDependencies(dependencies.asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/Library.scala 
b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/Library.scala new file mode 100644 index 0000000000..c343969e54 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/Library.scala @@ -0,0 +1,24 @@ +package com.databricks.labs.remorph.intermediate.workflows.libraries + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class Library( + cran: Option[RCranLibrary] = None, + egg: Option[String] = None, + jar: Option[String] = None, + maven: Option[MavenLibrary] = None, + pypi: Option[PythonPyPiLibrary] = None, + requirements: Option[String] = None, + whl: Option[String] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ cran ++ maven ++ pypi + def toSDK: compute.Library = new compute.Library() + .setCran(cran.map(_.toSDK).orNull) + .setEgg(egg.orNull) + .setJar(jar.orNull) + .setMaven(maven.map(_.toSDK).orNull) + .setPypi(pypi.map(_.toSDK).orNull) + .setRequirements(requirements.orNull) + .setWhl(whl.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/MavenLibrary.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/MavenLibrary.scala new file mode 100644 index 0000000000..f6f72d5ee7 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/MavenLibrary.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.intermediate.workflows.libraries + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.compute + +case class MavenLibrary(coordinates: String, exclusions: Seq[String] = Seq.empty, repo: Option[String] = None) + extends LeafJobNode { + def toSDK: compute.MavenLibrary = new compute.MavenLibrary() + .setCoordinates(coordinates) + .setExclusions(exclusions.asJava) + .setRepo(repo.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/PythonPyPiLibrary.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/PythonPyPiLibrary.scala new file mode 100644 index 0000000000..ec50bf4f7d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/PythonPyPiLibrary.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.intermediate.workflows.libraries + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.compute + +case class PythonPyPiLibrary(spec: String, repo: Option[String] = None) extends LeafJobNode { + def toSDK: compute.PythonPyPiLibrary = new compute.PythonPyPiLibrary().setPackage(spec).setRepo(repo.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/RCranLibrary.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/RCranLibrary.scala new file mode 100644 index 0000000000..5850e3c7f2 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/libraries/RCranLibrary.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.intermediate.workflows.libraries + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.compute + +case class RCranLibrary(spec: String, repo: Option[String] = None) extends LeafJobNode { + def toSDK: compute.RCranLibrary = new 
compute.RCranLibrary().setPackage(spec).setRepo(repo.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/Continuous.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/Continuous.scala new file mode 100644 index 0000000000..68c186774c --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/Continuous.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.schedules + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.PauseStatus + +case class Continuous(pauseStatus: Option[PauseStatus] = None) extends LeafJobNode { + def toSDK: jobs.Continuous = new jobs.Continuous().setPauseStatus(pauseStatus.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/CronSchedule.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/CronSchedule.scala new file mode 100644 index 0000000000..ce94d37037 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/CronSchedule.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.intermediate.workflows.schedules + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.PauseStatus + +case class CronSchedule(quartzCronExpression: String, timezoneId: String, pauseStatus: Option[PauseStatus] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.CronSchedule = new jobs.CronSchedule() + .setQuartzCronExpression(quartzCronExpression) + .setTimezoneId(timezoneId) + .setPauseStatus(pauseStatus.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/FileArrivalTriggerConfiguration.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/FileArrivalTriggerConfiguration.scala new file mode 100644 index 0000000000..3025901ab0 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/FileArrivalTriggerConfiguration.scala @@ -0,0 +1,16 @@ +package com.databricks.labs.remorph.intermediate.workflows.schedules + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class FileArrivalTriggerConfiguration( + url: String, + minTimeBetweenTriggersSeconds: Option[Int] = None, + waitAfterLastChangeSeconds: Option[Int] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.FileArrivalTriggerConfiguration = { + val raw = new jobs.FileArrivalTriggerConfiguration() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/PeriodicTriggerConfiguration.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/PeriodicTriggerConfiguration.scala new file mode 100644 index 0000000000..798a7193a0 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/PeriodicTriggerConfiguration.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.intermediate.workflows.schedules + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs +import 
com.databricks.sdk.service.jobs.PeriodicTriggerConfigurationTimeUnit + +case class PeriodicTriggerConfiguration(interval: Int, unit: PeriodicTriggerConfigurationTimeUnit) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.PeriodicTriggerConfiguration = { + val raw = new jobs.PeriodicTriggerConfiguration() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/TableUpdateTriggerConfiguration.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/TableUpdateTriggerConfiguration.scala new file mode 100644 index 0000000000..3b044e1728 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/TableUpdateTriggerConfiguration.scala @@ -0,0 +1,18 @@ +package com.databricks.labs.remorph.intermediate.workflows.schedules + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.Condition + +case class TableUpdateTriggerConfiguration( + condition: Option[Condition] = None, + minTimeBetweenTriggersSeconds: Option[Int] = None, + tableNames: Seq[String] = Seq.empty, + waitAfterLastChangeSeconds: Option[Int] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.TableUpdateTriggerConfiguration = { + val raw = new jobs.TableUpdateTriggerConfiguration() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/TriggerSettings.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/TriggerSettings.scala new file mode 100644 index 0000000000..1c4a003a38 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/schedules/TriggerSettings.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.intermediate.workflows.schedules + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.PauseStatus + +case class TriggerSettings( + fileArrival: Option[FileArrivalTriggerConfiguration] = None, + pauseStatus: Option[PauseStatus] = None, + periodic: Option[PeriodicTriggerConfiguration] = None, + table: Option[TableUpdateTriggerConfiguration] = None, + tableUpdate: Option[TableUpdateTriggerConfiguration] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ fileArrival ++ periodic ++ table ++ tableUpdate + def toSDK: jobs.TriggerSettings = new jobs.TriggerSettings() + .setFileArrival(fileArrival.map(_.toSDK).orNull) + .setPauseStatus(pauseStatus.orNull) + .setPeriodic(periodic.map(_.toSDK).orNull) + .setTable(table.map(_.toSDK).orNull) + .setTableUpdate(tableUpdate.map(_.toSDK).orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/Adlsgen2Info.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/Adlsgen2Info.scala new file mode 100644 index 0000000000..1fff8d7da7 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/Adlsgen2Info.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.sources + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class Adlsgen2Info(destination: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.Adlsgen2Info = new 
compute.Adlsgen2Info().setDestination(destination) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/DbfsStorageInfo.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/DbfsStorageInfo.scala new file mode 100644 index 0000000000..346f3eba15 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/DbfsStorageInfo.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.sources + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class DbfsStorageInfo(destination: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.DbfsStorageInfo = new compute.DbfsStorageInfo().setDestination(destination) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/GcsStorageInfo.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/GcsStorageInfo.scala new file mode 100644 index 0000000000..03d1b2a593 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/GcsStorageInfo.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.sources + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class GcsStorageInfo(destination: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.GcsStorageInfo = new compute.GcsStorageInfo().setDestination(destination) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/LocalFileInfo.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/LocalFileInfo.scala new file mode 100644 index 0000000000..75e20f2a04 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/LocalFileInfo.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.sources + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class LocalFileInfo(destination: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.LocalFileInfo = new compute.LocalFileInfo().setDestination(destination) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/S3StorageInfo.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/S3StorageInfo.scala new file mode 100644 index 0000000000..b175ed00d7 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/S3StorageInfo.scala @@ -0,0 +1,20 @@ +package com.databricks.labs.remorph.intermediate.workflows.sources + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class S3StorageInfo( + destination: String, + cannedAcl: Option[String] = None, + enableEncryption: Boolean = false, + encryptionType: Option[String] = None, + endpoint: Option[String] = None, + kmsKey: Option[String] = None, + region: Option[String] = None) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.S3StorageInfo = { + val raw = new compute.S3StorageInfo() + raw + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/VolumesStorageInfo.scala 
b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/VolumesStorageInfo.scala new file mode 100644 index 0000000000..719db146c8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/VolumesStorageInfo.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.sources + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class VolumesStorageInfo(destination: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.VolumesStorageInfo = new compute.VolumesStorageInfo().setDestination(destination) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/WorkspaceStorageInfo.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/WorkspaceStorageInfo.scala new file mode 100644 index 0000000000..7ff8c7a5b1 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sources/WorkspaceStorageInfo.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.sources + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.compute + +case class WorkspaceStorageInfo(destination: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: compute.WorkspaceStorageInfo = new compute.WorkspaceStorageInfo().setDestination(destination) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskAlert.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskAlert.scala new file mode 100644 index 0000000000..c792243ad0 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskAlert.scala @@ -0,0 +1,16 @@ +package com.databricks.labs.remorph.intermediate.workflows.sql + +import scala.jdk.CollectionConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class SqlTaskAlert( + alertId: String, + pauseSubscriptions: Boolean = false, + subscriptions: Seq[SqlTaskSubscription] = Seq.empty) + extends LeafJobNode { + def toSDK: jobs.SqlTaskAlert = new jobs.SqlTaskAlert() + .setAlertId(alertId) + .setPauseSubscriptions(pauseSubscriptions) + .setSubscriptions(subscriptions.map(_.toSDK).asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskDashboard.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskDashboard.scala new file mode 100644 index 0000000000..742414df80 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskDashboard.scala @@ -0,0 +1,19 @@ +package com.databricks.labs.remorph.intermediate.workflows.sql + +import scala.jdk.CollectionConverters._ +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class SqlTaskDashboard( + dashboardId: String, + customSubject: Option[String] = None, + pauseSubscriptions: Boolean = false, + subscriptions: Seq[SqlTaskSubscription] = Seq.empty) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ subscriptions + def toSDK: jobs.SqlTaskDashboard = new jobs.SqlTaskDashboard() + .setDashboardId(dashboardId) + .setCustomSubject(customSubject.orNull) + .setPauseSubscriptions(pauseSubscriptions) + 
.setSubscriptions(subscriptions.map(_.toSDK).asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskFile.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskFile.scala new file mode 100644 index 0000000000..b2b6c469ce --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskFile.scala @@ -0,0 +1,11 @@ +package com.databricks.labs.remorph.intermediate.workflows.sql + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.Source + +case class SqlTaskFile(path: String, source: Option[Source] = None) extends LeafJobNode { + def toSDK: jobs.SqlTaskFile = new jobs.SqlTaskFile() + .setPath(path) + .setSource(source.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskQuery.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskQuery.scala new file mode 100644 index 0000000000..ef785b2d8a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskQuery.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.intermediate.workflows.sql + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class SqlTaskQuery(queryId: String) extends LeafJobNode { + def toSDK: jobs.SqlTaskQuery = new jobs.SqlTaskQuery().setQueryId(queryId) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskSubscription.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskSubscription.scala new file mode 100644 index 0000000000..aa23715f48 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/sql/SqlTaskSubscription.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.remorph.intermediate.workflows.sql + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class SqlTaskSubscription(destinationId: Option[String], userName: Option[String] = None) extends LeafJobNode { + def toSDK: jobs.SqlTaskSubscription = new jobs.SqlTaskSubscription() + .setUserName(userName.orNull) + .setDestinationId(destinationId.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/CodeAsset.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/CodeAsset.scala new file mode 100644 index 0000000000..04347f436a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/CodeAsset.scala @@ -0,0 +1,3 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +trait CodeAsset diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/ConditionTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/ConditionTask.scala new file mode 100644 index 0000000000..e4ca5f3c02 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/ConditionTask.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.ConditionTaskOp + +case class ConditionTask(op: ConditionTaskOp, 
left: String, right: String) extends LeafJobNode { + def toSDK: jobs.ConditionTask = new jobs.ConditionTask() + .setOp(op) + .setLeft(left) + .setRight(right) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/DbtTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/DbtTask.scala new file mode 100644 index 0000000000..6b14314cd0 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/DbtTask.scala @@ -0,0 +1,26 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.jdk.CollectionConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.Source + +case class DbtTask( + commands: Seq[String], + catalog: Option[String] = None, + profilesDirectory: Option[String] = None, + projectDirectory: Option[String] = None, + schema: Option[String] = None, + source: Option[Source] = None, + warehouseId: Option[String] = None) + extends LeafJobNode + with NeedsWarehouse { + def toSDK: jobs.DbtTask = new jobs.DbtTask() + .setCommands(commands.asJava) + .setCatalog(catalog.orNull) + .setProfilesDirectory(profilesDirectory.orNull) + .setProjectDirectory(projectDirectory.orNull) + .setSchema(schema.orNull) + .setSource(source.orNull) + .setWarehouseId(warehouseId.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/ForEachTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/ForEachTask.scala new file mode 100644 index 0000000000..e377653e4e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/ForEachTask.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class ForEachTask(inputs: String, task: Task, concurrency: Long = 20) extends JobNode { + override def children: Seq[JobNode] = Seq(task) + def toSDK: jobs.ForEachTask = new jobs.ForEachTask() + .setTask(task.toSDK) + .setInputs(inputs) + .setConcurrency(concurrency) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/NeedsWarehouse.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/NeedsWarehouse.scala new file mode 100644 index 0000000000..75173c1942 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/NeedsWarehouse.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +trait NeedsWarehouse { + final val DEFAULT_WAREHOUSE_ID = sys.env.getOrElse("DATABRICKS_WAREHOUSE_ID", "__DEFAULT_WAREHOUSE_ID__") +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/NotebookTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/NotebookTask.scala new file mode 100644 index 0000000000..c00b99ea50 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/NotebookTask.scala @@ -0,0 +1,18 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import com.databricks.labs.remorph.intermediate.workflows.{JobNode, LeafJobNode} +import com.databricks.sdk.service.jobs + +import scala.jdk.CollectionConverters._ + +case class NotebookTask( + notebookPath: String, + 
baseParameters: Option[Map[String, String]] = None, + warehouseId: Option[String] = None) + extends LeafJobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.NotebookTask = new jobs.NotebookTask() + .setNotebookPath(notebookPath) + .setBaseParameters(baseParameters.map(_.asJava).orNull) + .setWarehouseId(warehouseId.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/PipelineTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/PipelineTask.scala new file mode 100644 index 0000000000..fd20598bb8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/PipelineTask.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class PipelineTask(pipelineId: String, fullRefresh: Boolean) extends LeafJobNode with CodeAsset { + def toSDK: jobs.PipelineTask = new jobs.PipelineTask() + .setPipelineId(pipelineId) + .setFullRefresh(fullRefresh) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/PythonWheelTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/PythonWheelTask.scala new file mode 100644 index 0000000000..916138d4c6 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/PythonWheelTask.scala @@ -0,0 +1,19 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class PythonWheelTask( + packageName: String, + entryPoint: String, + namedParameters: Option[Map[String, String]] = None, + parameters: Seq[String] = Seq.empty) + extends LeafJobNode + with CodeAsset { + def toSDK: jobs.PythonWheelTask = new jobs.PythonWheelTask() + .setPackageName(packageName) + .setEntryPoint(entryPoint) + .setNamedParameters(namedParameters.getOrElse(Map.empty).asJava) + .setParameters(parameters.asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/RunConditionTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/RunConditionTask.scala new file mode 100644 index 0000000000..01f16ae47d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/RunConditionTask.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.ConditionTaskOp + +case class RunConditionTask(op: ConditionTaskOp, left: String, right: String, outcome: Option[String] = None) + extends LeafJobNode { + def toSDK: jobs.RunConditionTask = new jobs.RunConditionTask() + .setOp(op) + .setLeft(left) + .setRight(right) + .setOutcome(outcome.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/RunJobTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/RunJobTask.scala new file mode 100644 index 0000000000..f4aa4da724 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/RunJobTask.scala @@ -0,0 +1,31 @@ +package 
com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.PipelineParams + +case class RunJobTask( + jobId: Long, + jobParams: Map[String, String] = Map.empty, + notebookParams: Map[String, String] = Map.empty, + pythonNamedParams: Map[String, String] = Map.empty, + sqlParams: Map[String, String] = Map.empty, + dbtArgs: Seq[String] = Seq.empty, + jarParams: Seq[String] = Seq.empty, + pythonArgs: Seq[String] = Seq.empty, + sparkSubmitArgs: Seq[String] = Seq.empty, + fullPipelineRefresh: Boolean = false) + extends LeafJobNode { + def toSDK: jobs.RunJobTask = new jobs.RunJobTask() + .setJobId(jobId) + .setDbtCommands(dbtArgs.asJava) + .setJarParams(jarParams.asJava) + .setJobParameters(jobParams.asJava) + .setNotebookParams(notebookParams.asJava) + .setPipelineParams(if (fullPipelineRefresh) new PipelineParams().setFullRefresh(true) else null) + .setPythonNamedParams(pythonNamedParams.asJava) + .setPythonParams(pythonArgs.asJava) + .setSparkSubmitParams(sparkSubmitArgs.asJava) + .setSqlParams(sqlParams.asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkJarTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkJarTask.scala new file mode 100644 index 0000000000..b66ad6b16d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkJarTask.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class SparkJarTask(jarUri: Option[String], mainClassName: Option[String], parameters: Seq[String] = Seq.empty) + extends LeafJobNode { + def toSDK: jobs.SparkJarTask = new jobs.SparkJarTask() + .setJarUri(jarUri.orNull) + .setMainClassName(mainClassName.orNull) + .setParameters(parameters.asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkPythonTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkPythonTask.scala new file mode 100644 index 0000000000..fac7e006d6 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkPythonTask.scala @@ -0,0 +1,15 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs +import com.databricks.sdk.service.jobs.Source + +case class SparkPythonTask(pythonFile: String, parameters: Seq[String] = Seq.empty, source: Option[Source] = None) + extends LeafJobNode + with CodeAsset { + def toSDK: jobs.SparkPythonTask = new jobs.SparkPythonTask() + .setPythonFile(pythonFile) + .setParameters(parameters.asJava) + .setSource(source.orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkSubmitTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkSubmitTask.scala new file mode 100644 index 0000000000..11f6fd8c8a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SparkSubmitTask.scala @@ -0,0 +1,9 @@ +package 
com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class SparkSubmitTask(parameters: Seq[String] = Seq.empty) extends LeafJobNode { + def toSDK: jobs.SparkSubmitTask = new jobs.SparkSubmitTask().setParameters(parameters.asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SqlTask.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SqlTask.scala new file mode 100644 index 0000000000..f163bc67d3 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/SqlTask.scala @@ -0,0 +1,25 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.jdk.CollectionConverters._ +import com.databricks.labs.remorph.intermediate.workflows._ +import com.databricks.labs.remorph.intermediate.workflows.sql.{SqlTaskAlert, SqlTaskDashboard, SqlTaskFile, SqlTaskQuery} +import com.databricks.sdk.service.jobs + +case class SqlTask( + warehouseId: String, + alert: Option[SqlTaskAlert] = None, + dashboard: Option[SqlTaskDashboard] = None, + file: Option[SqlTaskFile] = None, + parameters: Option[Map[String, String]] = None, + query: Option[SqlTaskQuery] = None) + extends JobNode + with NeedsWarehouse { + override def children: Seq[JobNode] = Seq() ++ alert ++ dashboard ++ file ++ query + def toSDK: jobs.SqlTask = new jobs.SqlTask() + .setWarehouseId(warehouseId) + .setAlert(alert.map(_.toSDK).orNull) + .setDashboard(dashboard.map(_.toSDK).orNull) + .setFile(file.map(_.toSDK).orNull) + .setParameters(parameters.map(_.asJava).orNull) + .setQuery(query.map(_.toSDK).orNull) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/Task.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/Task.scala new file mode 100644 index 0000000000..efe2d58cd5 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/Task.scala @@ -0,0 +1,95 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.clusters.NewClusterSpec +import com.databricks.labs.remorph.intermediate.workflows.jobs.JobsHealthRules +import com.databricks.labs.remorph.intermediate.workflows.libraries.Library +import com.databricks.labs.remorph.intermediate.workflows._ +import com.databricks.labs.remorph.intermediate.workflows.webhooks.WebhookNotifications +import com.databricks.sdk.service.jobs + +case class Task( + taskKey: String, + description: Option[String] = None, + dependsOn: Seq[TaskDependency] = Seq.empty, + dbtTask: Option[DbtTask] = None, + conditionTask: Option[ConditionTask] = None, + forEachTask: Option[ForEachTask] = None, + notebookTask: Option[NotebookTask] = None, + pipelineTask: Option[PipelineTask] = None, + pythonWheelTask: Option[PythonWheelTask] = None, + runJobTask: Option[RunJobTask] = None, + sparkJarTask: Option[SparkJarTask] = None, + sparkPythonTask: Option[SparkPythonTask] = None, + sparkSubmitTask: Option[SparkSubmitTask] = None, + sqlTask: Option[SqlTask] = None, + libraries: Seq[Library] = Seq.empty, + newCluster: Option[NewClusterSpec] = None, + existingClusterId: Option[String] = None, + jobClusterKey: Option[String] = None, + runIf: Option[jobs.RunIf] = None, + disableAutoOptimization: Boolean = false, + 
environmentKey: Option[String] = None, + maxRetries: Option[Int] = None, + minRetryIntervalMillis: Option[Int] = None, + health: Option[JobsHealthRules] = None, + retryOnTimeout: Boolean = false, + timeoutSeconds: Option[Int] = None, + notificationSettings: Option[TaskNotificationSettings] = None, + emailNotifications: Option[TaskEmailNotifications] = None, + webhookNotifications: Option[WebhookNotifications] = None) + extends JobNode { + + override def children: Seq[JobNode] = Seq() ++ + conditionTask ++ + dbtTask ++ + dependsOn ++ + emailNotifications ++ + forEachTask ++ + health ++ + libraries ++ + newCluster ++ + notebookTask ++ + notificationSettings ++ + pipelineTask ++ + pythonWheelTask ++ + runJobTask ++ + sparkJarTask ++ + sparkPythonTask ++ + sparkSubmitTask ++ + sqlTask ++ + webhookNotifications + + def dependOn(task: Task): Task = copy(dependsOn = dependsOn :+ TaskDependency(task.taskKey)) + + def toSDK: jobs.Task = new jobs.Task() + .setTaskKey(taskKey) + .setConditionTask(conditionTask.map(_.toSDK).orNull) + .setDbtTask(dbtTask.map(_.toSDK).orNull) + .setDependsOn(dependsOn.map(_.toSDK).asJava) + .setDescription(description.orNull) + .setDisableAutoOptimization(disableAutoOptimization) + .setEmailNotifications(emailNotifications.map(_.toSDK).orNull) + .setEnvironmentKey(environmentKey.orNull) + .setExistingClusterId(existingClusterId.orNull) + .setForEachTask(forEachTask.map(_.toSDK).orNull) + .setHealth(health.map(_.toSDK).orNull) + .setJobClusterKey(jobClusterKey.orNull) + .setLibraries(libraries.map(_.toSDK).asJava) + .setMaxRetries(maxRetries.map(n => java.lang.Long.valueOf(n.toLong)).orNull) + .setMinRetryIntervalMillis(minRetryIntervalMillis.map(n => java.lang.Long.valueOf(n.toLong)).orNull) + .setNewCluster(newCluster.map(_.toSDK).orNull) + .setNotebookTask(notebookTask.map(_.toSDK).orNull) + .setNotificationSettings(notificationSettings.map(_.toSDK).orNull) + .setPipelineTask(pipelineTask.map(_.toSDK).orNull) + .setPythonWheelTask(pythonWheelTask.map(_.toSDK).orNull) + .setRetryOnTimeout(retryOnTimeout) + .setRunIf(runIf.orNull) + .setRunJobTask(runJobTask.map(_.toSDK).orNull) + .setSparkJarTask(sparkJarTask.map(_.toSDK).orNull) + .setSparkPythonTask(sparkPythonTask.map(_.toSDK).orNull) + .setSparkSubmitTask(sparkSubmitTask.map(_.toSDK).orNull) + .setSqlTask(sqlTask.map(_.toSDK).orNull) + .setTimeoutSeconds(timeoutSeconds.map(n => java.lang.Long.valueOf(n.toLong)).orNull) + .setWebhookNotifications(webhookNotifications.map(_.toSDK).orNull) +}
diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskDependency.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskDependency.scala new file mode 100644 index 0000000000..894cbe6a78 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskDependency.scala @@ -0,0 +1,8 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class TaskDependency(taskKey: String, outcome: Option[String] = None) extends LeafJobNode { + def toSDK: jobs.TaskDependency = new jobs.TaskDependency().setTaskKey(taskKey).setOutcome(outcome.orNull) +}
diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskEmailNotifications.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskEmailNotifications.scala new file mode 100644 index 0000000000..ee3454731b ---
/dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskEmailNotifications.scala @@ -0,0 +1,22 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.LeafJobNode +import com.databricks.sdk.service.jobs + +case class TaskEmailNotifications( + noAlertForSkippedRuns: Boolean = false, + onDurationWarningThresholdExceeded: Seq[String] = Seq.empty, + onFailure: Seq[String] = Seq.empty, + onStart: Seq[String] = Seq.empty, + onStreamingBacklogExceeded: Seq[String] = Seq.empty, + onSuccess: Seq[String] = Seq.empty) + extends LeafJobNode { + def toSDK: jobs.TaskEmailNotifications = new jobs.TaskEmailNotifications() + .setNoAlertForSkippedRuns(noAlertForSkippedRuns) + .setOnDurationWarningThresholdExceeded(onDurationWarningThresholdExceeded.asJava) + .setOnFailure(onFailure.asJava) + .setOnStart(onStart.asJava) + .setOnStreamingBacklogExceeded(onStreamingBacklogExceeded.asJava) + .setOnSuccess(onSuccess.asJava) +}
diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskNotificationSettings.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskNotificationSettings.scala new file mode 100644 index 0000000000..17c6ee5277 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/tasks/TaskNotificationSettings.scala @@ -0,0 +1,16 @@ +package com.databricks.labs.remorph.intermediate.workflows.tasks + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class TaskNotificationSettings( + alertOnLastAttempt: Boolean = false, + noAlertForCanceledRuns: Boolean = false, + noAlertForSkippedRuns: Boolean = false) + extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.TaskNotificationSettings = new jobs.TaskNotificationSettings() + .setAlertOnLastAttempt(alertOnLastAttempt) + .setNoAlertForCanceledRuns(noAlertForCanceledRuns) + .setNoAlertForSkippedRuns(noAlertForSkippedRuns) +}
diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/webhooks/Webhook.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/webhooks/Webhook.scala new file mode 100644 index 0000000000..2f7310dbc9 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/webhooks/Webhook.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.intermediate.workflows.webhooks + +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class Webhook(id: String) extends JobNode { + override def children: Seq[JobNode] = Seq() + def toSDK: jobs.Webhook = new jobs.Webhook().setId(id) +}
diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/webhooks/WebhookNotifications.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/webhooks/WebhookNotifications.scala new file mode 100644 index 0000000000..9468b7b752 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/workflows/webhooks/WebhookNotifications.scala @@ -0,0 +1,22 @@ +package com.databricks.labs.remorph.intermediate.workflows.webhooks + +import scala.collection.JavaConverters._ +import com.databricks.labs.remorph.intermediate.workflows.JobNode +import com.databricks.sdk.service.jobs + +case class WebhookNotifications( + onDurationWarningThresholdExceeded: Seq[Webhook] = Seq.empty, + onFailure: Seq[Webhook] = Seq.empty, + onStart: Seq[Webhook] = Seq.empty, +
onStreamingBacklogExceeded: Seq[Webhook] = Seq.empty, + onSuccess: Seq[Webhook] = Seq.empty) + extends JobNode { + override def children: Seq[JobNode] = Seq() ++ onDurationWarningThresholdExceeded ++ onFailure ++ + onStart ++ onStreamingBacklogExceeded ++ onSuccess + def toSDK: jobs.WebhookNotifications = new jobs.WebhookNotifications() + .setOnDurationWarningThresholdExceeded(onDurationWarningThresholdExceeded.map(_.toSDK).asJava) + .setOnFailure(onFailure.map(_.toSDK).asJava) + .setOnStart(onStart.map(_.toSDK).asJava) + .setOnStreamingBacklogExceeded(onStreamingBacklogExceeded.map(_.toSDK).asJava) + .setOnSuccess(onSuccess.map(_.toSDK).asJava) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/ConversionStrategy.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/ConversionStrategy.scala new file mode 100644 index 0000000000..f003ba1b85 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/ConversionStrategy.scala @@ -0,0 +1,17 @@ +package com.databricks.labs.remorph.parsers +import com.databricks.labs.remorph.{intermediate => ir} + +import java.util.Locale + +trait ConversionStrategy { + def convert(irName: String, args: Seq[ir.Expression]): ir.Expression +} + +trait StringConverter { + // Preserves case if the original name was all lower case. Otherwise, converts to upper case. + // All bets are off if the original name was mixed case, but that is rarely seen in SQL and we are + // just making reasonable efforts here. + def convertString(irName: String, newName: String): String = { + if (irName.forall(c => c.isLower || c == '_')) newName.toLowerCase(Locale.ROOT) else newName + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/ErrorCollector.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/ErrorCollector.scala new file mode 100644 index 0000000000..3599835d03 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/ErrorCollector.scala @@ -0,0 +1,129 @@ +package com.databricks.labs.remorph.parsers + +import com.databricks.labs.remorph.coverage.ErrorEncoders +import com.databricks.labs.remorph.intermediate.{ParsingError, RemorphError} +import org.antlr.v4.runtime._ +import org.apache.logging.log4j.{LogManager, Logger} + +import scala.collection.mutable.ListBuffer +import io.circe.syntax._ + +sealed trait ErrorCollector extends BaseErrorListener { + def logErrors(): Unit = {} + def errorsAsJson: String = "{}" + def errorCount: Int = 0 + private[remorph] def formatErrors: Seq[String] = Seq() + def reset(): Unit = {} +} + +class EmptyErrorCollector extends ErrorCollector + +class DefaultErrorCollector extends ErrorCollector { + + var count: Int = 0 + private[this] val antlrErr: ConsoleErrorListener = new ConsoleErrorListener() + + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + antlrErr.syntaxError(recognizer, offendingSymbol, line, charPositionInLine, msg, e) + count += 1 + } + + override def errorCount: Int = count + override def reset(): Unit = count = 0 +} + +class ProductionErrorCollector(sourceCode: String, fileName: String) extends ErrorCollector with ErrorEncoders { + val errors: ListBuffer[ParsingError] = ListBuffer() + val logger: Logger = LogManager.getLogger(classOf[ErrorCollector]) + + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: 
RecognitionException): Unit = { + val errorDetail = offendingSymbol match { + case t: Token => + val width = t.getStopIndex - t.getStartIndex + 1 + ParsingError(line, charPositionInLine, msg, width, t.getText, tokenName(recognizer, t), ruleName(recognizer, e)) + case _ => ParsingError(line, charPositionInLine, msg, 0, "", "missing", ruleName(recognizer, e)) + } + errors += errorDetail + } + + override private[remorph] def formatErrors: Seq[String] = { + val lines = sourceCode.split("\n") + errors.map { error => + val errorLine = lines(error.line - 1) + val errorText = formatError(errorLine, error.charPositionInLine, error.offendingTokenWidth) + s"${error.msg}\nFile: $fileName, Line: ${error.line}, Token: ${error.offendingTokenText}\n$errorText" + } + } + + private[parsers] def tokenName(recognizer: Recognizer[_, _], token: Token): String = token match { + case t: Token => + Option(recognizer) + .map(_.getVocabulary.getSymbolicName(t.getType)) + .getOrElse("unresolved token name") + case _ => "missing" + } + + private[parsers] def ruleName(recognizer: Recognizer[_, _], e: RecognitionException): String = + (Option(recognizer), Option(e)) match { + case (Some(rec), Some(exc)) => rec.getRuleNames()(exc.getCtx.getRuleIndex) + case _ => "unresolved rule name" + } + + private[parsers] def formatError( + errorLine: String, + errorPosition: Int, + errorWidth: Int, + windowWidth: Int = 80): String = { + val roomForContext = (windowWidth - errorWidth) / 2 + val clipLeft = errorLine.length > windowWidth && errorPosition >= roomForContext + val clipRight = + errorLine.length > windowWidth && + errorLine.length - errorPosition - errorWidth >= roomForContext + val clipMark = "..." + val (markerStart, clippedLine) = (clipLeft, clipRight) match { + case (false, false) => (errorPosition, errorLine) + case (true, false) => + ( + windowWidth - (errorLine.length - errorPosition), + clipMark + errorLine.substring(errorLine.length - windowWidth + clipMark.length)) + case (false, true) => + (errorPosition, errorLine.take(windowWidth - clipMark.length) + clipMark) + case (true, true) => + val start = errorPosition - roomForContext + val clippedLineWithoutClipMarks = + errorLine.substring(start, Math.min(start + windowWidth, errorLine.length - 1)) + ( + roomForContext, + clipMark + clippedLineWithoutClipMarks.substring( + clipMark.length, + clipMark.length + windowWidth - 2 * clipMark.length) + clipMark) + + } + clippedLine + "\n" + " " * markerStart + "^" * errorWidth + } + + override def logErrors(): Unit = { + val formattedErrors = formatErrors + if (formattedErrors.nonEmpty) { + formattedErrors.foreach(error => logger.error(error)) + } + } + + override def errorsAsJson: String = errors.toList.map(_.asInstanceOf[RemorphError]).asJson.noSpaces + + override def errorCount: Int = errors.size + + override def reset(): Unit = errors.clear() +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/FunctionBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/FunctionBuilder.scala new file mode 100644 index 0000000000..1f6c80842c --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/FunctionBuilder.scala @@ -0,0 +1,397 @@ +package com.databricks.labs.remorph.parsers + +import com.databricks.labs.remorph.parsers.snowflake.NamedArgumentExpression +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeFunctionConverters.SynonymOf +import com.databricks.labs.remorph.{intermediate => ir} + +sealed trait FunctionType +case object StandardFunction extends 
FunctionType +case object XmlFunction extends FunctionType +case object NotConvertibleFunction extends FunctionType +case object UnknownFunction extends FunctionType + +sealed trait FunctionArity + +// For functions with a fixed number of arguments (ie. all arguments are required) +case class FixedArity(arity: Int) extends FunctionArity +// For functions with a varying number of arguments (ie. some arguments are optional) +case class VariableArity(argMin: Int, argMax: Int) extends FunctionArity +// For functions with named arguments (ie. some arguments are optional and arguments may be provided in any order) +case class SymbolicArity(requiredArguments: Set[String], optionalArguments: Set[String]) extends FunctionArity + +object FunctionArity { + def verifyArguments(arity: FunctionArity, args: Seq[ir.Expression]): Boolean = arity match { + case FixedArity(n) => args.size == n + case VariableArity(argMin, argMax) => argMin <= args.size && args.size <= argMax + case SymbolicArity(required, optional) => + val namedArguments = args.collect { case n: NamedArgumentExpression => n } + // all provided arguments are named + if (namedArguments.size == args.size) { + // all required arguments are present + required.forall(r => namedArguments.exists(_.key.toUpperCase() == r.toUpperCase())) && + // no unexpected argument was provided + namedArguments.forall(na => (required ++ optional).map(_.toUpperCase()).contains(na.key.toUpperCase())) + } else if (namedArguments.isEmpty) { + // arguments were provided positionally + args.size >= required.size && args.size <= required.size + optional.size + } else { + // a mix of named and positional arguments were provided, which isn't supported + false + } + } +} + +case class FunctionDefinition( + arity: FunctionArity, + functionType: FunctionType, + conversionStrategy: Option[ConversionStrategy] = None) { + def withConversionStrategy(strategy: ConversionStrategy): FunctionDefinition = + copy(conversionStrategy = Some(strategy)) +} + +object FunctionDefinition { + def standard(fixedArgNumber: Int): FunctionDefinition = + FunctionDefinition(FixedArity(fixedArgNumber), StandardFunction) + def standard(minArg: Int, maxArg: Int): FunctionDefinition = + FunctionDefinition(VariableArity(minArg, maxArg), StandardFunction) + + def symbolic(required: Set[String], optional: Set[String]): FunctionDefinition = + FunctionDefinition(SymbolicArity(required, optional), StandardFunction) + + def xml(fixedArgNumber: Int): FunctionDefinition = + FunctionDefinition(FixedArity(fixedArgNumber), XmlFunction) + + def notConvertible(fixedArgNumber: Int): FunctionDefinition = + FunctionDefinition(FixedArity(fixedArgNumber), NotConvertibleFunction) + def notConvertible(minArg: Int, maxArg: Int): FunctionDefinition = + FunctionDefinition(VariableArity(minArg, maxArg), NotConvertibleFunction) +} + +abstract class FunctionBuilder { + + protected val commonFunctionsPf: PartialFunction[String, FunctionDefinition] = { + case "ABS" => FunctionDefinition.standard(1) + case "ACOS" => FunctionDefinition.standard(1) + case "APP_NAME" => FunctionDefinition.standard(0) + case "APPLOCK_MODE" => FunctionDefinition.standard(3) + case "APPLOCK_TEST" => FunctionDefinition.standard(4) + case "APPROX_COUNT_DISTINCT" => FunctionDefinition.standard(1) + case "APPROX_PERCENTILE" => FunctionDefinition.standard(2) + case "APPROX_PERCENTILE_CONT" => FunctionDefinition.standard(1) + case "APPROX_PERCENTILE_DISC" => FunctionDefinition.standard(1) + case "ARRAYAGG" => FunctionDefinition.standard(1) + case "ASCII" => 
FunctionDefinition.standard(1) + case "ASIN" => FunctionDefinition.standard(1) + case "ASSEMBLYPROPERTY" => FunctionDefinition.standard(2) + case "ATAN" => FunctionDefinition.standard(1) + case "ATN2" => FunctionDefinition.standard(2) + case "AVG" => FunctionDefinition.standard(1) + case "BINARY_CHECKSUM" => FunctionDefinition.standard(1, Int.MaxValue) + case "BIT_COUNT" => FunctionDefinition.standard(1) + case "CEILING" => FunctionDefinition.standard(1) + case "CERT_ID" => FunctionDefinition.standard(1) + case "CERTENCODED" => FunctionDefinition.standard(1) + case "CERTPRIVATEKEY" => FunctionDefinition.standard(2, 3) + case "CHAR" => FunctionDefinition.standard(1) + case "CHARINDEX" => FunctionDefinition.standard(2, 3) + case "CHECKSUM" => FunctionDefinition.standard(2, Int.MaxValue) + case "CHECKSUM_AGG" => FunctionDefinition.standard(1) + case "COALESCE" => FunctionDefinition.standard(1, Int.MaxValue) + case "COL_LENGTH" => FunctionDefinition.standard(2) + case "COL_NAME" => FunctionDefinition.standard(2) + case "COLUMNPROPERTY" => FunctionDefinition.standard(3) + case "COMPRESS" => FunctionDefinition.standard(1) + case "CONCAT" => FunctionDefinition.standard(2, Int.MaxValue) + case "CONCAT_WS" => FunctionDefinition.standard(3, Int.MaxValue) + case "CONNECTIONPROPERTY" => FunctionDefinition.notConvertible(1) + case "CONTEXT_INFO" => FunctionDefinition.standard(0) + case "CONVERT" => FunctionDefinition.standard(2, 3) + case "COS" => FunctionDefinition.standard(1) + case "COT" => FunctionDefinition.standard(1) + case "COUNT" => FunctionDefinition.standard(1) + case "COUNT_BIG" => FunctionDefinition.standard(1) + case "CUME_DIST" => FunctionDefinition.standard(0) + case "CURRENT_DATE" => FunctionDefinition.standard(0) + case "CURRENT_REQUEST_ID" => FunctionDefinition.standard(0) + case "CURRENT_TIMESTAMP" => FunctionDefinition.standard(0) + case "CURRENT_TIMEZONE" => FunctionDefinition.standard(0) + case "CURRENT_TIMEZONE_ID" => FunctionDefinition.standard(0) + case "CURRENT_TRANSACTION_ID" => FunctionDefinition.standard(0) + case "CURRENT_USER" => FunctionDefinition.standard(0) + case "CURSOR_ROWS" => FunctionDefinition.standard(0) + case "CURSOR_STATUS" => FunctionDefinition.standard(2) + case "DATABASE_PRINCIPAL_ID" => FunctionDefinition.standard(0, 1) + case "DATABASEPROPERTY" => FunctionDefinition.standard(2) + case "DATABASEPROPERTYEX" => FunctionDefinition.standard(2) + case "DATALENGTH" => FunctionDefinition.standard(1) + case "DATE_BUCKET" => FunctionDefinition.standard(3, 4) + case "DATE_DIFF_BIG" => FunctionDefinition.standard(3) + case "DATEADD" => FunctionDefinition.standard(3) + case "DATEDIFF" => FunctionDefinition.standard(3) + case "DATEFROMPARTS" => FunctionDefinition.standard(3) + case "DATE_FORMAT" => FunctionDefinition.standard(2) + case "DATENAME" => FunctionDefinition.standard(2) + case "DATEPART" => FunctionDefinition.standard(2) + case "DATETIME2FROMPARTS" => FunctionDefinition.standard(8) + case "DATETIMEFROMPARTS" => FunctionDefinition.standard(7) + case "DATETIMEOFFSETFROMPARTS" => FunctionDefinition.standard(10) + case "DATETRUNC" => FunctionDefinition.standard(2) + case "DAY" => FunctionDefinition.standard(1) + case "DB_ID" => FunctionDefinition.standard(0, 1) + case "DB_NAME" => FunctionDefinition.standard(0, 1) + case "DECOMPRESS" => FunctionDefinition.standard(1) + case "DEGREES" => FunctionDefinition.standard(1) + case "DENSE_RANK" => FunctionDefinition.standard(0) + case "DIFFERENCE" => FunctionDefinition.standard(2) + case "EOMONTH" => 
FunctionDefinition.standard(1, 2) + case "ERROR_LINE" => FunctionDefinition.standard(0) + case "ERROR_MESSAGE" => FunctionDefinition.standard(0) + case "ERROR_NUMBER" => FunctionDefinition.standard(0) + case "ERROR_PROCEDURE" => FunctionDefinition.standard(0) + case "ERROR_SEVERITY" => FunctionDefinition.standard(0) + case "ERROR_STATE" => FunctionDefinition.standard(0) + case "EXIST" => FunctionDefinition.xml(1) + case "EXP" => FunctionDefinition.standard(1) + case "FILE_ID" => FunctionDefinition.standard(1) + case "FILE_IDEX" => FunctionDefinition.standard(1) + case "FILE_NAME" => FunctionDefinition.standard(1) + case "FILEGROUP_ID" => FunctionDefinition.standard(1) + case "FILEGROUP_NAME" => FunctionDefinition.standard(1) + case "FILEGROUPPROPERTY" => FunctionDefinition.standard(2) + case "FILEPROPERTY" => FunctionDefinition.standard(2) + case "FILEPROPERTYEX" => FunctionDefinition.standard(2) + case "FIRST_VALUE" => FunctionDefinition.standard(1) + case "FLOOR" => FunctionDefinition.standard(1) + case "FORMAT" => FunctionDefinition.standard(2, 3) + case "FORMATMESSAGE" => FunctionDefinition.standard(2, Int.MaxValue) + case "FULLTEXTCATALOGPROPERTY" => FunctionDefinition.standard(2) + case "FULLTEXTSERVICEPROPERTY" => FunctionDefinition.standard(1) + case "GET_FILESTREAM_TRANSACTION_CONTEXT" => FunctionDefinition.standard(0) + case "GETANCESTGOR" => FunctionDefinition.standard(1) + case "GETANSINULL" => FunctionDefinition.standard(0, 1) + case "GETDATE" => FunctionDefinition.standard(0) + case "GETDESCENDANT" => FunctionDefinition.standard(2) + case "GETLEVEL" => FunctionDefinition.standard(0) + case "GETREPARENTEDVALUE" => FunctionDefinition.standard(2) + case "GETUTCDATE" => FunctionDefinition.standard(0) + case "GREATEST" => FunctionDefinition.standard(1, Int.MaxValue) + case "GROUPING" => FunctionDefinition.standard(1) + case "GROUPING_ID" => FunctionDefinition.standard(0, Int.MaxValue) + case "HAS_DBACCESS" => FunctionDefinition.standard(1) + case "HAS_PERMS_BY_NAME" => FunctionDefinition.standard(4, 5) + case "HOST_ID" => FunctionDefinition.standard(0) + case "HOST_NAME" => FunctionDefinition.standard(0) + case "IDENT_CURRENT" => FunctionDefinition.standard(1) + case "IDENT_INCR" => FunctionDefinition.standard(1) + case "IDENT_SEED" => FunctionDefinition.standard(1) + case "IDENTITY" => FunctionDefinition.standard(1, 3) + case "IFF" => FunctionDefinition.standard(3) + case "INDEX_COL" => FunctionDefinition.standard(3) + case "INDEXKEY_PROPERTY" => FunctionDefinition.standard(3) + case "INDEXPROPERTY" => FunctionDefinition.standard(3) + case "IS_MEMBER" => FunctionDefinition.standard(1) + case "IS_ROLEMEMBER" => FunctionDefinition.standard(1, 2) + case "IS_SRVROLEMEMBER" => FunctionDefinition.standard(1, 2) + case "ISDATE" => FunctionDefinition.standard(1) + case "ISDESCENDANTOF" => FunctionDefinition.standard(1) + case "ISJSON" => FunctionDefinition.standard(1, 2) + case "ISNUMERIC" => FunctionDefinition.standard(1) + case "JSON_MODIFY" => FunctionDefinition.standard(3) + case "JSON_PATH_EXISTS" => FunctionDefinition.standard(2) + case "JSON_QUERY" => FunctionDefinition.standard(2) + case "JSON_VALUE" => FunctionDefinition.standard(2) + case "LAG" => FunctionDefinition.standard(1, 3) + case "LAST_VALUE" => FunctionDefinition.standard(1) + case "LEAD" => FunctionDefinition.standard(1, 3) + case "LEAST" => FunctionDefinition.standard(1, Int.MaxValue) + case "LEFT" => FunctionDefinition.standard(2) + case "LEN" => FunctionDefinition.standard(1) + case "LISTAGG" => 
FunctionDefinition.standard(1, 2) + case "LN" => FunctionDefinition.standard(1) + case "LOG" => FunctionDefinition.standard(1, 2) + case "LOG10" => FunctionDefinition.standard(1) + case "LOGINPROPERTY" => FunctionDefinition.standard(2) + case "LOWER" => FunctionDefinition.standard(1) + case "LTRIM" => FunctionDefinition.standard(1) + case "MAX" => FunctionDefinition.standard(1) + case "MIN" => FunctionDefinition.standard(1) + case "MIN_ACTIVE_ROWVERSION" => FunctionDefinition.standard(0) + case "MONTH" => FunctionDefinition.standard(1) + case "NCHAR" => FunctionDefinition.standard(1) + case "NEWID" => FunctionDefinition.standard(0) + case "NEWSEQUENTIALID" => FunctionDefinition.standard(0) + case "NODES" => FunctionDefinition.xml(1) + case "NTILE" => FunctionDefinition.standard(1) + case "NULLIF" => FunctionDefinition.standard(2) + case "OBJECT_DEFINITION" => FunctionDefinition.standard(1) + case "OBJECT_ID" => FunctionDefinition.standard(1, 2) + case "OBJECT_NAME" => FunctionDefinition.standard(1, 2) + case "OBJECT_SCHEMA_NAME" => FunctionDefinition.standard(1, 2) + case "OBJECTPROPERTY" => FunctionDefinition.standard(2) + case "OBJECTPROPERTYEX" => FunctionDefinition.standard(2) + case "ORIGINAL_DB_NAME" => FunctionDefinition.standard(0) + case "ORIGINAL_LOGIN" => FunctionDefinition.standard(0) + case "PARSE" => FunctionDefinition.notConvertible(2, 3) // Not in DBSQL + case "PARSENAME" => FunctionDefinition.standard(2) + case "PATINDEX" => FunctionDefinition.standard(2) + case "PERCENT_RANK" => FunctionDefinition.standard(0) + case "PERCENTILE_CONT" => FunctionDefinition.standard(1) + case "PERCENTILE_DISC" => FunctionDefinition.standard(1) + case "PERMISSIONS" => FunctionDefinition.notConvertible(0, 2) // not in DBSQL + case "PI" => FunctionDefinition.standard(0) + case "POWER" => FunctionDefinition.standard(2) + case "PWDCOMPARE" => FunctionDefinition.standard(2, 3) + case "PWDENCRYPT" => FunctionDefinition.standard(1) + case "QUERY" => FunctionDefinition.xml(1) + case "QUOTENAME" => FunctionDefinition.standard(1, 2) + case "RADIANS" => FunctionDefinition.standard(1) + case "RAND" => FunctionDefinition.standard(0, 1) + case "RANK" => FunctionDefinition.standard(0) + case "REPLACE" => FunctionDefinition.standard(3) + case "REPLICATE" => FunctionDefinition.standard(2) + case "REVERSE" => FunctionDefinition.standard(1) + case "RIGHT" => FunctionDefinition.standard(2) + case "ROUND" => FunctionDefinition.standard(1, 3) + case "ROW_NUMBER" => FunctionDefinition.standard(0) + case "ROWCOUNT_BIG" => FunctionDefinition.standard(0) + case "RTRIM" => FunctionDefinition.standard(1) + case "SCHEMA_ID" => FunctionDefinition.standard(0, 1) + case "SCHEMA_NAME" => FunctionDefinition.standard(0, 1) + case "SCOPE_IDENTITY" => FunctionDefinition.standard(0) + case "SERVERPROPERTY" => FunctionDefinition.standard(1) + case "SESSION_CONTEXT" => FunctionDefinition.standard(1, 2) + case "SESSION_USER" => FunctionDefinition.standard(0) + case "SESSIONPROPERTY" => FunctionDefinition.standard(1) + case "SIGN" => FunctionDefinition.standard(1) + case "SIN" => FunctionDefinition.standard(1) + case "SMALLDATETIMEFROMPARTS" => FunctionDefinition.standard(5) + case "SOUNDEX" => FunctionDefinition.standard(1) + case "SPACE" => FunctionDefinition.standard(1) + case "SQL_VARIANT_PROPERTY" => FunctionDefinition.standard(2) + case "SQRT" => FunctionDefinition.standard(1) + case "SQUARE" => FunctionDefinition.standard(1) + case "STATS_DATE" => FunctionDefinition.standard(2) + case "STDEV" => FunctionDefinition.standard(1) 
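+ // Illustrative note, not part of the generated table: each entry here is an arity contract that
+ // FunctionArity.verifyArguments checks before conversion. For example, given the definitions in this file:
+ //   FunctionArity.verifyArguments(FixedArity(1), args)       // true only when args.size == 1
+ //   FunctionArity.verifyArguments(VariableArity(1, 3), args) // true only when 1 <= args.size <= 3
+ // SymbolicArity entries additionally accept named arguments, in any order, via NamedArgumentExpression.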
+ case "STDEVP" => FunctionDefinition.standard(1) + case "STR" => FunctionDefinition.standard(1, 3) + case "STRING_AGG" => FunctionDefinition.standard(2, 3) + case "STRING_ESCAPE" => FunctionDefinition.standard(2) + case "STUFF" => FunctionDefinition.standard(4) + case "SUBSTR" => FunctionDefinition.standard(2, 3) + case "SUBSTRING" => FunctionDefinition.standard(2, 3).withConversionStrategy(SynonymOf("SUBSTR")) + case "SUM" => FunctionDefinition.standard(1) + case "SUSER_ID" => FunctionDefinition.standard(0, 1) + case "SUSER_NAME" => FunctionDefinition.standard(0, 1) + case "SUSER_SID" => FunctionDefinition.standard(0, 2) + case "SUSER_SNAME" => FunctionDefinition.standard(0, 1) + case "SWITCHOFFSET" => FunctionDefinition.standard(2) + case "SYSDATETIME" => FunctionDefinition.standard(0) + case "SYSDATETIMEOFFSET" => FunctionDefinition.standard(0) + case "SYSTEM_USER" => FunctionDefinition.standard(0) + case "SYSUTCDATETIME" => FunctionDefinition.standard(0) + case "TAN" => FunctionDefinition.standard(1) + case "TIMEFROMPARTS" => FunctionDefinition.standard(5) + case "TODATETIMEOFFSET" => FunctionDefinition.standard(2) + case "TOSTRING" => FunctionDefinition.standard(0) + case "TRANSLATE" => FunctionDefinition.standard(3) + case "TRIM" => FunctionDefinition.standard(1, 2) + case "TYPE_ID" => FunctionDefinition.standard(1) + case "TYPE_NAME" => FunctionDefinition.standard(1) + case "TYPEPROPERTY" => FunctionDefinition.standard(2) + case "UNICODE" => FunctionDefinition.standard(1) + case "UPPER" => FunctionDefinition.standard(1) + case "USER" => FunctionDefinition.standard(0) + case "USER_ID" => FunctionDefinition.standard(0, 1) + case "USER_NAME" => FunctionDefinition.standard(0, 1) + case "VALUE" => FunctionDefinition.xml(2) + case "VAR" => FunctionDefinition.standard(1) + case "VARP" => FunctionDefinition.standard(1) + case "XACT_STATE" => FunctionDefinition.standard(0) + case "YEAR" => FunctionDefinition.standard(1) + } + + def functionDefinition(name: String): Option[FunctionDefinition] = + commonFunctionsPf.lift(name.toUpperCase()) + + def functionType(name: String): FunctionType = { + functionDefinition(name).map(_.functionType).getOrElse(UnknownFunction) + } + + def buildFunction(id: ir.Id, args: Seq[ir.Expression]): ir.Expression = { + val name = if (id.caseSensitive) id.id else id.id.toUpperCase() + buildFunction(name, args) + } + + def buildFunction(name: String, args: Seq[ir.Expression]): ir.Expression = { + val irName = removeQuotesAndBrackets(name) + val defnOption = functionDefinition(irName) + + defnOption match { + case Some(functionDef) if functionDef.functionType == NotConvertibleFunction => + ir.UnresolvedFunction( + name, + args, + is_distinct = false, + is_user_defined_function = false, + ruleText = s"$irName(...)", + message = s"Function $irName is not convertible to Databricks SQL", + ruleName = "N/A", + tokenName = Some("N/A")) + + case Some(funDef) if FunctionArity.verifyArguments(funDef.arity, args) => + applyConversionStrategy(funDef, args, irName) + + // Found the function but the arg count is incorrect + case Some(_) => + ir.UnresolvedFunction( + irName, + args, + is_distinct = false, + is_user_defined_function = false, + has_incorrect_argc = true, + ruleText = s"$irName(...)", + message = s"Invocation of $irName has incorrect argument count", + ruleName = "N/A", + tokenName = Some("N/A")) + + // Unsupported function + case None => + ir.UnresolvedFunction( + irName, + args, + is_distinct = false, + is_user_defined_function = false, + ruleText = s"$irName(...)", 
+ message = s"Function $irName is not convertible to Databricks SQL", + ruleName = "N/A", + tokenName = Some("N/A")) + } + } + + def applyConversionStrategy( + functionArity: FunctionDefinition, + args: Seq[ir.Expression], + irName: String): ir.Expression + + /** + * Functions can be called even if they are quoted or bracketed. This function removes the quotes and brackets. + * @param str + * the possibly quoted function name + * @return + * function name for use in lookup/matching + */ + private def removeQuotesAndBrackets(str: String): String = { + val quotations = Map('\'' -> "'", '"' -> "\"", '[' -> "]", '\\' -> "\\") + str match { + case s if s.length < 2 => s + case s => + quotations.get(s.head).fold(s) { closingQuote => + if (s.endsWith(closingQuote)) { + s.substring(1, s.length - 1) + } else { + s + } + } + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/ParseException.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/ParseException.scala new file mode 100644 index 0000000000..8760b215a4 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/ParseException.scala @@ -0,0 +1,3 @@ +package com.databricks.labs.remorph.parsers + +case class ParseException(msg: String) extends RuntimeException(msg) diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/ParserCommon.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/ParserCommon.scala new file mode 100644 index 0000000000..617fafe82f --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/ParserCommon.scala @@ -0,0 +1,152 @@ +package com.databricks.labs.remorph.parsers + +import com.databricks.labs.remorph.{intermediate => ir} +import com.typesafe.scalalogging.LazyLogging +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree._ +import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token} + +import scala.collection.JavaConverters._ + +trait ParserCommon[A] extends ParseTreeVisitor[A] with LazyLogging { self: AbstractParseTreeVisitor[A] => + + val vc: VisitorCoordinator + + protected def occursBefore(a: ParseTree, b: ParseTree): Boolean = { + a != null && b != null && a.getSourceInterval.startsBeforeDisjoint(b.getSourceInterval) + } + + def visitMany[R <: RuleContext](contexts: java.lang.Iterable[R]): Seq[A] = contexts.asScala.map(_.accept(self)).toSeq + + /** + *

+ * An implementation of this should return some type of ir.UnresolvedXYZ object that represents the + * unresolved input that we have no visitor for. This is used in the default visitor to wrap the + * unresolved input. + *
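+ * For example (illustrative; this mirrors the Snowflake implementation that appears later in this change):
+ * {{{
+ *   protected override def unresolved(ruleText: String, message: String): ir.LogicalPlan =
+ *     ir.UnresolvedRelation(ruleText = ruleText, message = message)
+ * }}}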

+ * @param ruleText Which piece of source code the unresolved object represents + * @param message What message the unresolved object should contain, such as missing visitor + * @return An instance of the type returned by the implementing visitor + */ + protected def unresolved(ruleText: String, message: String): A + + protected override def defaultResult(): A = { + unresolved(contextText(currentNode.getRuleContext), s"Unimplemented visitor $caller in class $implementor") + } + + /** + *

+ * Returns the original source text covered by the given ANTLR RuleContext. + *

+ *

+ * Note that this should exactly reflect the original input text as bounded by the source interval + * recorded by the parser. + *

+ */ + def contextText(ctx: RuleContext): String = try { + ctx match { + case ctx: ParserRuleContext => + ctx.getStart.getInputStream.getText(new Interval(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) + case _ => "Unsupported RuleContext type - cannot generate source string" + } + } catch { + // Anything that does this will have been mocked and the mockery will be huge to get the above code to work + case _: Throwable => "Mocked string" + } + + /** + *

+ * Returns the rule name that a particular context represents. + *

+ *

+ * We can do this by referencing the parser rule names stored in the visitor coordinator. + *

+ * @param ctx The context for which we want the rule name + * @return the rule name for the context + */ + def contextRuleName(ctx: ParserRuleContext): String = + vc.ruleName(ctx) + + /** + * Given a token, return its symbolic name (what it is called in the lexer) + */ + def tokenName(tok: Token): String = + vc.tokenName(tok) + + /** + *

+ * The default visitor needs some way to aggregate the results of visiting all children and so calls this method. In + * fact, we should never rely on this as there is no way to know exactly what to do in all circumstances. + *

+ *

+ * Note that while we have unimplemented visitors, some parts of the IR building will 'work by accident' as + * this method will just produce the first and only result in agg. But we should implement the missing visitor that + * explicitly returns the required result as it is flaky to rely on the default here + *

+ *

+ * We do not try to resolve what the input should actually be, as that is the job of a concrete + * visitor. + *

+ * + * @param agg The current result as seen by the default visitor + * @param next The next result that should somehow be aggregated to form a single result + * @return The aggregated result from the two supplied results (accumulate error strings) + */ + override def aggregateResult(agg: A, next: A): A = + Option(next).getOrElse(agg) + + protected var currentNode: RuleNode = _ + protected var caller: String = _ + protected var implementor: String = _ + + /** + * Overrides the default visitChildren to report that there is an unimplemented + * visitor def for the given RuleNode, then call the ANTLR default visitor + * @param node the RuleNode that was not visited + * @return T an instance of the type returned by the implementing visitor + */ + abstract override def visitChildren(node: RuleNode): A = { + caller = Thread.currentThread().getStackTrace()(4).getMethodName + implementor = this.getClass.getSimpleName + logger.warn( + s"Unimplemented visitor for method: $caller in class: $implementor" + + s" for: ${contextText(node.getRuleContext)}") + currentNode = node + val result = super.visitChildren(node) + result match { + case c: ir.Unresolved[A] => + node match { + case ctx: ParserRuleContext => + c.annotate(contextRuleName(ctx), Some(tokenName(ctx.getStart))) + } + case _ => + result + } + } + + /** + * If the parser recognizes a syntax error, then it generates an ErrorNode. The ErrorNode represents unparsable text + * and contains a manufactured token that encapsulates all the text that the parser ignored when it recovered + * from the error. Note that if the error recovery strategy inserts a token rather than deletes one, then an + * error node will not be created; those errors will only be reported via the ErrorCollector + * TODO: It may be reasonable to add a check for inserted tokens and generate an error node in that case + * + * @param ctx the context to check for error nodes + * @return The unresolved object representing the error and containing the text that was skipped + */ + def errorCheck(ctx: ParserRuleContext): Option[A] = { + val unparsedText = Option(ctx.children) + .map(_.asScala) + .getOrElse(Seq.empty) + .collect { case e: ErrorNode => + s"Unparsable text: ${e.getSymbol.getText}" + } + .mkString("\n") + + if (unparsedText.nonEmpty) { + Some(unresolved(unparsedText, "Unparsed input - ErrorNode encountered")) + } else { + None + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/PlanParser.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/PlanParser.scala new file mode 100644 index 0000000000..d758008859 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/PlanParser.scala @@ -0,0 +1,90 @@ +package com.databricks.labs.remorph.parsers + +import com.databricks.labs.remorph.intermediate.{ParsingErrors, PlanGenerationFailure, TranspileFailure} +import com.databricks.labs.remorph.{BuildingAst, KoResult, OkResult, Optimizing, Parsing, PartialResult, Transformation, TransformationConstructors, WorkflowStage, intermediate => ir} +import org.antlr.v4.runtime._ +import org.json4s.jackson.Serialization +import org.json4s.{Formats, NoTypeHints} + +import scala.util.control.NonFatal + +trait PlanParser[P <: Parser] extends TransformationConstructors { + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + protected def createLexer(input: CharStream): Lexer + protected def createParser(stream: TokenStream): P + protected def createTree(parser: P): ParserRuleContext + protected def createPlan(tree: 
ParserRuleContext): ir.LogicalPlan + protected def addErrorStrategy(parser: P): Unit + def dialect: String + + // TODO: This is probably not where the optimizer should be as this is a Plan "Parser" - it is here for now + protected def createOptimizer: ir.Rules[ir.LogicalPlan] + + /** + * Parse the input source code into a Parse tree + * @return Returns a parse tree on success otherwise a description of the errors + */ + def parse: Transformation[ParserRuleContext] = { + + getCurrentPhase.flatMap { + case Parsing(source, filename, _, _) => + val inputString = CharStreams.fromString(source) + val lexer = createLexer(inputString) + val tokenStream = new CommonTokenStream(lexer) + val parser = createParser(tokenStream) + addErrorStrategy(parser) + val errListener = new ProductionErrorCollector(source, filename) + parser.removeErrorListeners() + parser.addErrorListener(errListener) + val tree = createTree(parser) + if (errListener.errorCount > 0) { + lift(PartialResult(tree, ParsingErrors(errListener.errors))) + } else { + lift(OkResult(tree)) + } + case other => ko(WorkflowStage.PARSE, ir.IncoherentState(other, classOf[Parsing])) + } + } + + /** + * Visit the parse tree and create a logical plan + * @param tree The parse tree + * @return Returns a logical plan on success otherwise a description of the errors + */ + def visit(tree: ParserRuleContext): Transformation[ir.LogicalPlan] = { + updatePhase { + case p: Parsing => BuildingAst(tree, Some(p)) + case _ => BuildingAst(tree) + }.flatMap { _ => + try { + ok(createPlan(tree)) + } catch { + case NonFatal(e) => + lift(KoResult(stage = WorkflowStage.PLAN, PlanGenerationFailure(e))) + } + } + } + + // TODO: This is probably not where the optimizer should be as this is a Plan "Parser" - it is here for now + /** + * Optimize the logical plan + * + * @param logicalPlan The logical plan + * @return Returns an optimized logical plan on success otherwise a description of the errors + */ + def optimize(logicalPlan: ir.LogicalPlan): Transformation[ir.LogicalPlan] = { + updatePhase { + case b: BuildingAst => Optimizing(logicalPlan, Some(b)) + case _ => Optimizing(logicalPlan) + }.flatMap { _ => + try { + ok(createOptimizer.apply(logicalPlan)) + } catch { + case NonFatal(e) => + lift(KoResult(stage = WorkflowStage.OPTIMIZE, TranspileFailure(e))) + } + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/SqlErrorStrategy.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/SqlErrorStrategy.scala new file mode 100644 index 0000000000..bf4deabce5 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/SqlErrorStrategy.scala @@ -0,0 +1,128 @@ +package com.databricks.labs.remorph.parsers + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.misc.{Interval, IntervalSet, Pair} +import org.antlr.v4.runtime.tree.ErrorNodeImpl + +/** + * Custom error strategy for SQL parsing

While we do not do anything super special here, we override a couple of the message-generating methods, as well as the token insert and delete reporting, which do not create an exception + * and so give us no chance to build an error message in context. This also gives us a single place to implement i18n, should that ever + * become necessary.

+ * + *

At the moment, we require valid SQL as input to the conversion process, but if that strategy ever changes, + * we can implement custom, context-aware recovery steps here; we make no attempt to improve on the default sync() + * method.
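+ * A minimal sketch of a concrete subclass (illustrative only; the class name and message wording are
+ * hypothetical). A subclass only needs to supply the two abstract helpers declared at the end of this class:
+ * {{{
+ *   class MyDialectErrorStrategy extends SqlErrorStrategy {
+ *     override protected def generateMessage(recognizer: Parser, e: RecognitionException): String =
+ *       s"while parsing rule '${recognizer.getRuleNames()(recognizer.getContext.getRuleIndex)}'"
+ *     override protected def buildExpectedMessage(recognizer: Parser, expected: IntervalSet): String =
+ *       expected.toString(recognizer.getVocabulary)
+ *   }
+ * }}}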

+ */ +abstract class SqlErrorStrategy extends DefaultErrorStrategy { + + private def createErrorNode(token: Token, text: String): ErrorNodeImpl = { + val errorToken = new CommonToken( + new Pair(token.getTokenSource, token.getInputStream), + Token.INVALID_TYPE, + Token.DEFAULT_CHANNEL, + token.getStartIndex, + token.getStopIndex) + errorToken.setText(text) + errorToken.setLine(token.getLine) + errorToken.setCharPositionInLine(token.getCharPositionInLine) + new ErrorNodeImpl(errorToken) + } + + override def recover(recognizer: Parser, e: RecognitionException): Unit = { + val tokens: TokenStream = recognizer.getInputStream + val startIndex: Int = tokens.index + val first = tokens.LT(1) + super.recover(recognizer, e) + val endIndex: Int = tokens.index + if (startIndex < endIndex) { + val interval = new Interval(startIndex, endIndex) + val errorText = s"parser recovered by ignoring: ${tokens.getText(interval)}" + val errorNode = createErrorNode(first, errorText) + + // Here we add the error node to the current context so that we can report it in the correct place + recognizer.getContext.addErrorNode(errorNode) + } + } + + override protected def reportNoViableAlternative(recognizer: Parser, e: NoViableAltException): Unit = { + val tokens = recognizer.getInputStream + val input = if (tokens != null) { + if (e.getStartToken.getType == Token.EOF) "" + else tokens.getText(e.getStartToken, e.getOffendingToken) + } else "" + + val msg = new StringBuilder() + msg.append("input is not parsable ") + msg.append(escapeWSAndQuote(input)) + recognizer.notifyErrorListeners(e.getOffendingToken, msg.toString(), e) + + // Here we add the error node to the current context so that we can report it in the correct place + val errorNode = createErrorNode(e.getStartToken, input) + recognizer.getContext.addErrorNode(errorNode) + } + + override protected def reportInputMismatch(recognizer: Parser, e: InputMismatchException): Unit = { + val msg = new StringBuilder() + msg.append(getTokenErrorDisplay(e.getOffendingToken)) + msg.append(" was unexpected ") + msg.append(generateMessage(recognizer, e)) + msg.append("\nexpecting one of: ") + msg.append(buildExpectedMessage(recognizer, e.getExpectedTokens)) + recognizer.notifyErrorListeners(e.getOffendingToken, msg.toString(), e) + + // Here we add the error node to the current context so that we can report it in the correct place + val errorNode = createErrorNode(e.getOffendingToken, msg.toString()) + recognizer.getContext.addErrorNode(errorNode) + } + + override protected def reportUnwantedToken(recognizer: Parser): Unit = { + if (inErrorRecoveryMode(recognizer)) return + beginErrorCondition(recognizer) + val t = recognizer.getCurrentToken + val tokenName = getTokenErrorDisplay(t) + val expecting = getExpectedTokens(recognizer) + val msg = new StringBuilder() + msg.append("unexpected extra input ") + msg.append(tokenName) + msg.append(' ') + msg.append(generateMessage(recognizer, new InputMismatchException(recognizer))) + msg.append("\nexpecting one of: ") + msg.append(buildExpectedMessage(recognizer, expecting)) + recognizer.notifyErrorListeners(t, msg.toString(), null) + + // Here we add the error node to the current context so that we can report it in the correct place + val errorNode = createErrorNode(t, msg.toString()) + recognizer.getContext.addErrorNode(errorNode) + } + + override protected def reportMissingToken(recognizer: Parser): Unit = { + if (inErrorRecoveryMode(recognizer)) return + + beginErrorCondition(recognizer) + val t = recognizer.getCurrentToken + val 
expecting = getExpectedTokens(recognizer) + val msg = new StringBuilder() + msg.append("missing ") + msg.append(buildExpectedMessage(recognizer, expecting)) + msg.append(" at ") + msg.append(getTokenErrorDisplay(t)) + msg.append('\n') + msg.append(generateMessage(recognizer, new InputMismatchException(recognizer))) + recognizer.notifyErrorListeners(t, msg.toString(), null) + + // Here we add the error node to the current context so that we can report it in the correct place + val errorNode = createErrorNode(t, msg.toString()) + recognizer.getContext.addErrorNode(errorNode) + } + + val capitalizedSort: Ordering[String] = Ordering.fromLessThan((a, b) => + (a.exists(_.isLower), b.exists(_.isLower)) match { + case (true, false) => true + case (false, true) => false + case _ => a.compareTo(b) < 0 + }) + + protected def generateMessage(recognizer: Parser, e: RecognitionException): String + protected def buildExpectedMessage(recognizer: Parser, expected: IntervalSet): String +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/VisitorCoordinator.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/VisitorCoordinator.scala new file mode 100644 index 0000000000..18f18ea270 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/VisitorCoordinator.scala @@ -0,0 +1,34 @@ +package com.databricks.labs.remorph.parsers +import com.databricks.labs.remorph.parsers.tsql.DataTypeBuilder +import org.antlr.v4.runtime.{ParserRuleContext, Token, Vocabulary} +import org.antlr.v4.runtime.tree.ParseTreeVisitor + +/** + *

+ * An implementation of this class should provide an instance of each of the ParseTreeVisitors + * required to build IR. Each visitor should be initialized with a reference to the implementation + * of this class so that the visitors can call each other without needing to know each other's concrete + * classes and without creating circular dependencies between them. + *

+ *

+ * Implementations can also supply other shared resources, builders, etc. via the same mechanism + *
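+ * For example (illustrative; it mirrors how the Snowflake AST builder later in this change annotates
+ * unresolved nodes), a visitor holding a coordinator reference `vc` can look up grammar metadata:
+ * {{{
+ *   ir.UnresolvedCommand(
+ *     ruleText = contextText(ctx),
+ *     ruleName = vc.ruleName(ctx),
+ *     tokenName = Some(vc.tokenName(ctx.getStart)),
+ *     message = "Unknown command")
+ * }}}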

+ */ +abstract class VisitorCoordinator(val parserVocab: Vocabulary, val ruleNames: Array[String]) { + + // Parse tree visitors + val astBuilder: ParseTreeVisitor[_] + val relationBuilder: ParseTreeVisitor[_] + val expressionBuilder: ParseTreeVisitor[_] + val dmlBuilder: ParseTreeVisitor[_] + val ddlBuilder: ParseTreeVisitor[_] + + // A function builder that can be used to build function calls for a particular dialect + val functionBuilder: FunctionBuilder + + // Common builders that are used across all parsers, but can still be overridden + val dataTypeBuilder = new DataTypeBuilder + + def ruleName(ctx: ParserRuleContext): String = ruleNames(ctx.getRuleIndex) + def tokenName(tok: Token): String = parserVocab.getSymbolicName(tok.getType) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilder.scala new file mode 100644 index 0000000000..d9e14ca399 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilder.scala @@ -0,0 +1,152 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{StringContext => _, _} +import com.databricks.labs.remorph.{intermediate => ir} + +import scala.collection.JavaConverters._ + +/** + * @see + * org.apache.spark.sql.catalyst.parser.AstBuilder + */ +class SnowflakeAstBuilder(override val vc: SnowflakeVisitorCoordinator) + extends SnowflakeParserBaseVisitor[ir.LogicalPlan] + with ParserCommon[ir.LogicalPlan] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.LogicalPlan = + ir.UnresolvedRelation(ruleText = ruleText, message = message) + + // Concrete visitors + + override def visitSnowflakeFile(ctx: SnowflakeFileContext): ir.LogicalPlan = { + // This very top level visitor does not ignore any valid statements for the batch, instead + // we prepend any errors to the batch plan, so they are generated first in the output. + val errors = errorCheck(ctx) + val batchPlan = Option(ctx.batch()).map(buildBatch).getOrElse(Seq.empty) + errors match { + case Some(errorResult) => ir.Batch(errorResult +: batchPlan) + case None => ir.Batch(batchPlan) + } + } + + private def buildBatch(ctx: BatchContext): Seq[ir.LogicalPlan] = { + // This very top level visitor does not ignore any valid statements for the batch, instead + // we prepend any errors to the batch plan, so they are generated first in the output. 
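+ // For example, if the parser attached any error nodes to this batch context (it does so when it has to
+ // skip unparsable text), errorCheck below wraps their text in a single unresolved plan (an
+ // ir.UnresolvedRelation for this visitor) and the result is errorResult +: statements; otherwise the
+ // per-clause plans are returned unchanged.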
+ val errors = errorCheck(ctx) + val statements = visitMany(ctx.sqlClauses()) + errors match { + case Some(errorResult) => errorResult +: statements + case None => statements + } + } + + override def visitSqlClauses(ctx: SqlClausesContext): ir.LogicalPlan = { + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.ddlCommand() != null => c.ddlCommand().accept(this) + case c if c.dmlCommand() != null => c.dmlCommand().accept(this) + case c if c.showCommand() != null => c.showCommand().accept(this) + case c if c.useCommand() != null => c.useCommand().accept(this) + case c if c.describeCommand() != null => c.describeCommand().accept(this) + case c if c.otherCommand() != null => c.otherCommand().accept(this) + case c if c.snowSqlCommand() != null => c.snowSqlCommand().accept(this) + case _ => + ir.UnresolvedCommand( + ruleText = contextText(ctx), + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart)), + message = "Unknown command in SnowflakeAstBuilder.visitSqlCommand") + } + } + } + + // TODO: Sort out where to visitSubquery + override def visitQueryStatement(ctx: QueryStatementContext): ir.LogicalPlan = { + errorCheck(ctx).getOrElse { + val query = ctx.queryExpression().accept(this) + Option(ctx.withExpression()).foldRight(query)(buildCTE) + } + } + + override def visitQueryInParenthesis(ctx: QueryInParenthesisContext): ir.LogicalPlan = { + errorCheck(ctx).getOrElse(ctx.queryExpression().accept(this)) + } + + override def visitQueryIntersect(ctx: QueryIntersectContext): ir.LogicalPlan = { + errorCheck(ctx).getOrElse { + val Seq(lhs, rhs) = ctx.queryExpression().asScala.map(_.accept(this)) + ir.SetOperation(lhs, rhs, ir.IntersectSetOp, is_all = false, by_name = false, allow_missing_columns = false) + } + } + + override def visitQueryUnion(ctx: QueryUnionContext): ir.LogicalPlan = { + errorCheck(ctx).getOrElse { + val Seq(lhs, rhs) = ctx.queryExpression().asScala.map(_.accept(this)) + val setOp = ctx match { + case u if u.UNION() != null => ir.UnionSetOp + case e if e.EXCEPT() != null || e.MINUS_() != null => ir.ExceptSetOp + } + val isAll = ctx.ALL() != null + ir.SetOperation(lhs, rhs, setOp, is_all = isAll, by_name = false, allow_missing_columns = false) + } + } + + override def visitQuerySimple(ctx: QuerySimpleContext): ir.LogicalPlan = { + errorCheck(ctx).getOrElse(ctx.selectStatement().accept(vc.relationBuilder)) + } + + override def visitDdlCommand(ctx: DdlCommandContext): ir.LogicalPlan = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.accept(vc.ddlBuilder) + } + + private def buildCTE(ctx: WithExpressionContext, relation: ir.LogicalPlan): ir.LogicalPlan = { + if (ctx == null) { + return relation + } + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + if (ctx.RECURSIVE() == null) { + val ctes = vc.relationBuilder.visitMany(ctx.commonTableExpression()) + ir.WithCTE(ctes, relation) + } else { + // TODO With Recursive CTE are not support by default, will require a custom implementation IR to be redefined + val ctes = vc.relationBuilder.visitMany(ctx.commonTableExpression()) + ir.WithRecursiveCTE(ctes, relation) + } + } + } + + override def visitDmlCommand(ctx: DmlCommandContext): ir.LogicalPlan = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.queryStatement() != null => c.queryStatement().accept(this) + case c => c.accept(vc.dmlBuilder) + } + } + + override def visitOtherCommand(ctx: 
OtherCommandContext): ir.LogicalPlan = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.accept(vc.commandBuilder) + } + + override def visitSnowSqlCommand(ctx: SnowSqlCommandContext): ir.LogicalPlan = { + ir.UnresolvedCommand( + ruleText = contextText(ctx), + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart)), + message = "Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand") + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeCommandBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeCommandBuilder.scala new file mode 100644 index 0000000000..e9e51c7757 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeCommandBuilder.scala @@ -0,0 +1,110 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate.procedures.SetVariable +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{StringContext => _, _} +import com.databricks.labs.remorph.{intermediate => ir} + +class SnowflakeCommandBuilder(override val vc: SnowflakeVisitorCoordinator) + extends SnowflakeParserBaseVisitor[ir.Command] + with ParserCommon[ir.Command] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.Command = + ir.UnresolvedCommand(ruleText, message) + + // Concrete visitors + + // TODO: Implement Cursor and Exception for Declare Statements. + // TODO: Implement Cursor for Let Statements. 
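+  // Illustrative mapping only; the SQL fragments below are examples, not a contract:
+  //   DECLARE v_total NUMBER DEFAULT 0;  -- visitDeclareWithDefault -> ir.CreateVariable with the declared
+  //                                         type and Some(default value)
+  //   LET v_total := 42;                 -- visitLetVariableAssignment -> SetVariable, with the data type
+  //                                         taken from a scalar subquery value or the optional declared type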
+ + override def visitDeclareWithDefault(ctx: DeclareWithDefaultContext): ir.Command = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val variableName = ctx.id().accept(vc.expressionBuilder).asInstanceOf[ir.Id] + val dataType = vc.typeBuilder.buildDataType(ctx.dataType()) + val variableValue = ctx.expr().accept(vc.expressionBuilder) + ir.CreateVariable(variableName, dataType, Some(variableValue), replace = false) + } + + override def visitDeclareSimple(ctx: DeclareSimpleContext): ir.Command = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val variableName = ctx.id().accept(vc.expressionBuilder).asInstanceOf[ir.Id] + val dataType = vc.typeBuilder.buildDataType(ctx.dataType()) + ir.CreateVariable(variableName, dataType, None, replace = false) + } + + override def visitDeclareResultSet(ctx: DeclareResultSetContext): ir.Command = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val variableName = ctx.id().accept(vc.expressionBuilder).asInstanceOf[ir.Id] + val variableValue = ctx.expr() match { + case null => None + case stmt => Some(stmt.accept(vc.expressionBuilder)) + } + val dataType = ir.StructType(Seq()) + ir.CreateVariable(variableName, dataType, variableValue, replace = false) + } + + override def visitLetVariableAssignment(ctx: LetVariableAssignmentContext): ir.Command = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val variableName = ctx.id().accept(vc.expressionBuilder).asInstanceOf[ir.Id] + val variableValue = ctx.expr().accept(vc.expressionBuilder) + + val variableDataType = variableValue match { + case s: ir.ScalarSubquery => Some(s.dataType) + case _ => Option(ctx.dataType()).flatMap(dt => Some(vc.typeBuilder.buildDataType(dt))) + } + SetVariable(variableName, variableValue, variableDataType) + } + + override def visitExecuteTask(ctx: ExecuteTaskContext): ir.Command = { + ir.UnresolvedCommand( + ruleText = contextText(ctx), + message = "Execute Task is not yet supported", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + + override def visitOtherCommand(ctx: OtherCommandContext): ir.Command = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.copyIntoTable != null => c.copyIntoTable.accept(this) + case c if c.copyIntoLocation != null => c.copyIntoLocation.accept(this) + case c if c.comment != null => c.comment.accept(this) + case c if c.commit != null => c.commit.accept(this) + case e if e.executeImmediate != null => e.executeImmediate.accept(this) + case e if e.executeTask != null => e.executeTask.accept(this) + case e if e.explain != null => e.explain.accept(this) + case g if g.getDml != null => g.getDml.accept(this) + case g if g.grantOwnership != null => g.grantOwnership.accept(this) + case g if g.grantToRole != null => g.grantToRole.accept(this) + case g if g.grantToShare != null => g.grantToShare.accept(this) + case g if g.grantRole != null => g.grantRole.accept(this) + case l if l.list != null => l.list.accept(this) + case p if p.put != null => p.put.accept(this) + case r if r.remove != null => r.remove.accept(this) + case r if r.revokeFromRole != null => r.revokeFromRole.accept(this) + case r if r.revokeFromShare != null => r.revokeFromShare.accept(this) + case r if r.revokeRole != null => r.revokeRole.accept(this) + case r if r.rollback != null => r.rollback.accept(this) + case s if s.set != null => s.set.accept(this) + case t if 
t.truncateMaterializedView != null => t.truncateMaterializedView.accept(this) + case t if t.truncateTable != null => t.truncateTable.accept(this) + case u if u.unset != null => u.unset.accept(this) + case c if c.call != null => c.call.accept(this) + case b if b.beginTxn != null => b.beginTxn.accept(this) + case d if d.declareCommand != null => d.declareCommand.accept(this) + case l if l.let != null => l.let.accept(this) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDDLBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDDLBuilder.scala new file mode 100644 index 0000000000..fa33832840 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDDLBuilder.scala @@ -0,0 +1,406 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{StringContext => StrContext, _} +import com.databricks.labs.remorph.{intermediate => ir} + +import java.util.Locale +import scala.collection.JavaConverters._ +class SnowflakeDDLBuilder(override val vc: SnowflakeVisitorCoordinator) + extends SnowflakeParserBaseVisitor[ir.Catalog] + with ParserCommon[ir.Catalog] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.Catalog = + ir.UnresolvedCatalog(ruleText = ruleText, message = message) + + // Concrete visitors + + private def extractString(ctx: StrContext): String = + ctx.accept(vc.expressionBuilder) match { + case ir.StringLiteral(s) => s + // TODO: Do not throw an error here - we need to generate an UnresolvedCatalog for it + // However, it is likely that we will neveer see this in the wild + case e => throw new IllegalArgumentException(s"Expected a string literal, got $e") + } + + override def visitDdlCommand(ctx: DdlCommandContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case a if a.alterCommand() != null => a.alterCommand().accept(this) + case c if c.createCommand() != null => c.createCommand().accept(this) + case d if d.dropCommand() != null => d.dropCommand().accept(this) + case u if u.undropCommand() != null => u.undropCommand().accept(this) + } + } + + override def visitCreateCommand(ctx: CreateCommandContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.createAccount() != null => c.createAccount().accept(this) + case c if c.createAlert() != null => c.createAlert().accept(this) + case c if c.createApiIntegration() != null => c.createApiIntegration().accept(this) + case c if c.createObjectClone() != null => c.createObjectClone().accept(this) + case c if c.createConnection() != null => c.createConnection().accept(this) + case c if c.createDatabase() != null => c.createDatabase().accept(this) + case c if c.createDynamicTable() != null => c.createDynamicTable().accept(this) + case c if c.createEventTable() != null => c.createEventTable().accept(this) + case c if c.createExternalFunction() != null => c.createExternalFunction().accept(this) + case c if c.createExternalTable() != null => c.createExternalTable().accept(this) + case c if c.createFailoverGroup() != null => c.createFailoverGroup().accept(this) + case c if 
c.createFileFormat() != null => c.createFileFormat().accept(this) + case c if c.createFunction() != null => c.createFunction().accept(this) + case c if c.createManagedAccount() != null => c.createManagedAccount().accept(this) + case c if c.createMaskingPolicy() != null => c.createMaskingPolicy().accept(this) + case c if c.createMaterializedView() != null => c.createMaterializedView().accept(this) + case c if c.createNetworkPolicy() != null => c.createNetworkPolicy().accept(this) + case c if c.createNotificationIntegration() != null => c.createNotificationIntegration().accept(this) + case c if c.createPipe() != null => c.createPipe().accept(this) + case c if c.createProcedure() != null => c.createProcedure().accept(this) + case c if c.createReplicationGroup() != null => c.createReplicationGroup().accept(this) + case c if c.createResourceMonitor() != null => c.createResourceMonitor().accept(this) + case c if c.createRole() != null => c.createRole().accept(this) + case c if c.createRowAccessPolicy() != null => c.createRowAccessPolicy().accept(this) + case c if c.createSchema() != null => c.createSchema().accept(this) + case c if c.createSecurityIntegrationExternalOauth() != null => + c.createSecurityIntegrationExternalOauth().accept(this) + case c if c.createSecurityIntegrationSnowflakeOauth() != null => + c.createSecurityIntegrationSnowflakeOauth().accept(this) + case c if c.createSecurityIntegrationSaml2() != null => c.createSecurityIntegrationSaml2().accept(this) + case c if c.createSecurityIntegrationScim() != null => c.createSecurityIntegrationScim().accept(this) + case c if c.createSequence() != null => c.createSequence().accept(this) + case c if c.createSessionPolicy() != null => c.createSessionPolicy().accept(this) + case c if c.createShare() != null => c.createShare().accept(this) + case c if c.createStage() != null => c.createStage().accept(this) + case c if c.createStorageIntegration() != null => c.createStorageIntegration().accept(this) + case c if c.createStream() != null => c.createStream().accept(this) + case c if c.createTable() != null => c.createTable().accept(this) + case c if c.createTableAsSelect() != null => c.createTableAsSelect().accept(this) + case c if c.createTableLike() != null => c.createTableLike().accept(this) + case c if c.createTag() != null => c.createTag().accept(this) + case c if c.createTask() != null => c.createTask().accept(this) + case c if c.createUser() != null => c.createUser().accept(this) + case c if c.createView() != null => c.createView().accept(this) + case c if c.createWarehouse() != null => c.createWarehouse().accept(this) + case _ => + ir.UnresolvedCatalog(ruleText = contextText(ctx), "Unknown CREATE XXX command", ruleName = "createCommand") + } + } + + override def visitCreateFunction(ctx: CreateFunctionContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val runtime = Option(ctx.id()).map(_.getText.toLowerCase(Locale.ROOT)).getOrElse("sql") + val runtimeInfo = + runtime match { + case r if r == "java" => buildJavaUDF(ctx) + case r if r == "python" => buildPythonUDF(ctx) + case r if r == "javascript" => ir.JavaScriptRuntimeInfo + case r if r == "scala" => buildScalaUDF(ctx) + case _ => ir.SQLRuntimeInfo(ctx.MEMOIZABLE() != null) + } + val name = ctx.dotIdentifier().getText + val returnType = vc.typeBuilder.buildDataType(ctx.dataType()) + val parameters = ctx.argDecl().asScala.map(buildParameter) + val acceptsNullParameters = ctx.CALLED() != null + val body = 
buildFunctionBody(ctx.functionDefinition()) + val comment = Option(ctx.com).map(extractString) + ir.CreateInlineUDF(name, returnType, parameters, runtimeInfo, acceptsNullParameters, comment, body) + } + + private def buildParameter(ctx: ArgDeclContext): ir.FunctionParameter = + ir.FunctionParameter( + name = ctx.id().getText, + dataType = vc.typeBuilder.buildDataType(ctx.dataType()), + defaultValue = Option(ctx.expr()).map(_.accept(vc.expressionBuilder))) + + private def buildFunctionBody(ctx: FunctionDefinitionContext): String = extractString(ctx.string()).trim + + private def buildJavaUDF(ctx: CreateFunctionContext): ir.RuntimeInfo = buildJVMUDF(ctx)(ir.JavaRuntimeInfo.apply) + private def buildScalaUDF(ctx: CreateFunctionContext): ir.RuntimeInfo = buildJVMUDF(ctx)(ir.ScalaRuntimeInfo.apply) + + private def buildJVMUDF(ctx: CreateFunctionContext)( + ctr: (Option[String], Seq[String], String) => ir.RuntimeInfo): ir.RuntimeInfo = { + val imports = + ctx + .stringList() + .asScala + .find(occursBefore(ctx.IMPORTS(), _)) + .map(_.string().asScala.map(extractString)) + .getOrElse(Seq()) + ctr(extractRuntimeVersion(ctx), imports, extractHandler(ctx)) + } + private def extractRuntimeVersion(ctx: CreateFunctionContext): Option[String] = ctx.string().asScala.collectFirst { + case c if occursBefore(ctx.RUNTIME_VERSION(), c) => extractString(c) + } + + private def extractHandler(ctx: CreateFunctionContext): String = + Option(ctx.HANDLER()).flatMap(h => ctx.string().asScala.find(occursBefore(h, _))).map(extractString).get + + private def buildPythonUDF(ctx: CreateFunctionContext): ir.PythonRuntimeInfo = { + val packages = + ctx + .stringList() + .asScala + .find(occursBefore(ctx.PACKAGES(0), _)) + .map(_.string().asScala.map(extractString)) + .getOrElse(Seq()) + ir.PythonRuntimeInfo(extractRuntimeVersion(ctx), packages, extractHandler(ctx)) + } + + override def visitCreateTable(ctx: CreateTableContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableName = ctx.dotIdentifier().getText + val columns = buildColumnDeclarations( + ctx + .createTableClause() + .columnDeclItemListParen() + .columnDeclItemList() + .columnDeclItem() + .asScala) + if (ctx.REPLACE() != null) { + ir.ReplaceTableCommand(tableName, columns, true) + } else { + ir.CreateTableCommand(tableName, columns) + } + + } + + override def visitCreateTableAsSelect(ctx: CreateTableAsSelectContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableName = ctx.dotIdentifier().getText + val selectStatement = ctx.queryStatement().accept(vc.relationBuilder) + // Currently TableType is not used in the IR and Databricks doesn't support Temporary Tables + val create = if (ctx.REPLACE() != null) { + ir.ReplaceTableAsSelect(tableName, selectStatement, Map.empty[String, String], true, false) + } else { + ir.CreateTableAsSelect(tableName, selectStatement, None, None, None) + } + // Wrapping the CreateTableAsSelect in a CreateTableParams to maintain implementation consistency + // TODO Capture other Table Properties + val colConstraints = Map.empty[String, Seq[ir.Constraint]] + val colOptions = Map.empty[String, Seq[ir.GenericOption]] + val constraints = Seq.empty[ir.Constraint] + val indices = Seq.empty[ir.Constraint] + val partition = None + val options = None + ir.CreateTableParams(create, colConstraints, colOptions, constraints, indices, partition, options) + } + + override def visitCreateStream(ctx: CreateStreamContext): ir.Catalog = + 
ir.UnresolvedCommand( + ruleText = contextText(ctx), + "CREATE STREAM UNSUPPORTED", + ruleName = contextRuleName(ctx), + tokenName = Some("STREAM")) + + override def visitCreateTask(ctx: CreateTaskContext): ir.Catalog = { + ir.UnresolvedCommand( + ruleText = contextText(ctx), + "CREATE TASK UNSUPPORTED", + ruleName = "createTask", + tokenName = Some("TASK")) + } + + private def buildColumnDeclarations(ctx: Seq[ColumnDeclItemContext]): Seq[ir.ColumnDeclaration] = { + // According to the grammar, either ctx.fullColDecl or ctx.outOfLineConstraint is non-null. + val columns = ctx.collect { + case c if c.fullColDecl() != null => buildColumnDeclaration(c.fullColDecl()) + } + // An out-of-line constraint may apply to one or many columns + // When an out-of-line constraint applies to multiple columns, + // we record a column-name -> constraint mapping for each. + val outOfLineConstraints: Seq[(String, ir.Constraint)] = ctx.collect { + case c if c.outOfLineConstraint() != null => buildOutOfLineConstraints(c.outOfLineConstraint()) + }.flatten + + // Finally, for every column, we "inject" the relevant out-of-line constraints + columns.map { col => + val additionalConstraints = outOfLineConstraints.collect { + case (columnName, constraint) if columnName == col.name => constraint + } + col.copy(constraints = col.constraints ++ additionalConstraints) + } + } + + private def buildColumnDeclaration(ctx: FullColDeclContext): ir.ColumnDeclaration = { + val name = ctx.colDecl().columnName().getText + val dataType = vc.typeBuilder.buildDataType(ctx.colDecl().dataType()) + val constraints = ctx.inlineConstraint().asScala.map(buildInlineConstraint) + val identityConstraints = if (ctx.defaultValue() != null) { + ctx.defaultValue().asScala.map(buildDefaultValue) + } else { + Seq() + } + val nullability = if (ctx.NULL().isEmpty) { + Seq() + } else { + Seq(ir.Nullability(ctx.NOT() == null)) + } + ir.ColumnDeclaration( + name, + dataType, + virtualColumnDeclaration = None, + nullability ++ constraints ++ identityConstraints) + } + + private def buildDefaultValue(ctx: DefaultValueContext): ir.Constraint = { + ctx match { + case c if c.DEFAULT() != null => ir.DefaultValueConstraint(c.expr().accept(vc.expressionBuilder)) + case c if c.AUTOINCREMENT() != null => ir.IdentityConstraint(None, None, always = true) + case c if c.IDENTITY() != null => + ir.IdentityConstraint(Some(ctx.startWith().getText), Some(ctx.incrementBy().getText), false, true) + } + } + + private[snowflake] def buildOutOfLineConstraints(ctx: OutOfLineConstraintContext): Seq[(String, ir.Constraint)] = { + val columnNames = ctx.columnListInParentheses(0).columnList().columnName().asScala.map(_.getText) + val repeatForEveryColumnName = List.fill[ir.UnnamedConstraint](columnNames.size)(_) + val unnamedConstraints = ctx match { + case c if c.UNIQUE() != null => repeatForEveryColumnName(ir.Unique(Seq.empty)) + case c if c.primaryKey() != null => repeatForEveryColumnName(ir.PrimaryKey(Seq.empty)) + case c if c.foreignKey() != null => + val referencedObject = c.dotIdentifier().getText + val references = + c.columnListInParentheses(1).columnList().columnName().asScala.map(referencedObject + "." 
+ _.getText) + references.map(ref => ir.ForeignKey("", ref, "", Seq.empty)) + case c => repeatForEveryColumnName(ir.UnresolvedConstraint(c.getText)) + } + val constraintNameOpt = Option(ctx.id()).map(_.getText) + val constraints = constraintNameOpt.fold[Seq[ir.Constraint]](unnamedConstraints) { name => + unnamedConstraints.map(ir.NamedConstraint(name, _)) + } + columnNames.zip(constraints) + } + + private[snowflake] def buildInlineConstraint(ctx: InlineConstraintContext): ir.Constraint = ctx match { + case c if c.UNIQUE() != null => ir.Unique() + case c if c.primaryKey() != null => ir.PrimaryKey() + case c if c.foreignKey() != null => + val references = c.dotIdentifier().getText + Option(ctx.columnName()).map("." + _.getText).getOrElse("") + ir.ForeignKey("", references, "", Seq.empty) + case c => ir.UnresolvedConstraint(c.getText) + } + + override def visitCreateUser(ctx: CreateUserContext): ir.Catalog = + ir.UnresolvedCommand( + ruleText = contextText(ctx), + message = "CREATE USER UNSUPPORTED", + ruleName = "createUser", + tokenName = Some("USER")) + + override def visitAlterCommand(ctx: AlterCommandContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.alterTable() != null => c.alterTable().accept(this) + case _ => + ir.UnresolvedCommand( + ruleText = contextText(ctx), + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart)), + message = s"Unknown ALTER command variant") + } + } + + override def visitAlterTable(ctx: AlterTableContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableName = ctx.dotIdentifier(0).getText + ctx match { + case c if c.tableColumnAction() != null => + ir.AlterTableCommand(tableName, buildColumnActions(c.tableColumnAction())) + case c if c.constraintAction() != null => + ir.AlterTableCommand(tableName, buildConstraintActions(c.constraintAction())) + case _ => + ir.UnresolvedCommand( + ruleText = contextText(ctx), + message = "Unknown ALTER TABLE variant", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + } + + private[snowflake] def buildColumnActions(ctx: TableColumnActionContext): Seq[ir.TableAlteration] = ctx match { + case c if c.ADD() != null => + Seq(ir.AddColumn(c.fullColDecl().asScala.map(buildColumnDeclaration))) + case c if !c.alterColumnClause().isEmpty => + c.alterColumnClause().asScala.map(buildColumnAlterations) + case c if c.DROP() != null => + Seq(ir.DropColumns(c.columnList().columnName().asScala.map(_.getText))) + case c if c.RENAME() != null => + Seq(ir.RenameColumn(c.columnName(0).getText, c.columnName(1).getText)) + case _ => + Seq( + ir.UnresolvedTableAlteration( + ruleText = contextText(ctx), + message = "Unknown COLUMN action variant", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart)))) + } + + private[snowflake] def buildColumnAlterations(ctx: AlterColumnClauseContext): ir.TableAlteration = { + val columnName = ctx.columnName().getText + ctx match { + case c if c.dataType() != null => + ir.ChangeColumnDataType(columnName, vc.typeBuilder.buildDataType(c.dataType())) + case c if c.DROP() != null && c.NULL() != null => + ir.DropConstraint(Some(columnName), ir.Nullability(c.NOT() == null)) + case c if c.NULL() != null => + ir.AddConstraint(columnName, ir.Nullability(c.NOT() == null)) + case _ => + ir.UnresolvedTableAlteration( + ruleText = contextText(ctx), + message = "Unknown ALTER COLUMN variant", + ruleName = vc.ruleName(ctx), + 
tokenName = Some(tokenName(ctx.getStart))) + } + } + + private[snowflake] def buildConstraintActions(ctx: ConstraintActionContext): Seq[ir.TableAlteration] = ctx match { + case c if c.ADD() != null => + buildOutOfLineConstraints(c.outOfLineConstraint()).map(ir.AddConstraint.tupled) + case c if c.DROP() != null => + buildDropConstraints(c) + case c if c.RENAME() != null => + Seq(ir.RenameConstraint(c.id(0).getText, c.id(1).getText)) + case c => + Seq( + ir.UnresolvedTableAlteration( + ruleText = contextText(c), + message = "Unknown CONSTRAINT variant", + ruleName = vc.ruleName(c), + tokenName = Some(tokenName(ctx.getStart)))) + } + + private[snowflake] def buildDropConstraints(ctx: ConstraintActionContext): Seq[ir.TableAlteration] = { + val columnListOpt = Option(ctx.columnListInParentheses()) + val affectedColumns = columnListOpt.map(_.columnList().columnName().asScala.map(_.getText)).getOrElse(Seq()) + ctx match { + case c if c.primaryKey() != null => dropConstraints(affectedColumns, ir.PrimaryKey()) + case c if c.UNIQUE() != null => dropConstraints(affectedColumns, ir.Unique()) + case c if c.id.size() > 0 => Seq(ir.DropConstraintByName(c.id(0).getText)) + case _ => + Seq( + ir.UnresolvedTableAlteration( + ruleText = contextText(ctx), + message = "Unknown DROP constraint variant", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart)))) + } + } + + private def dropConstraints(affectedColumns: Seq[String], constraint: ir.Constraint): Seq[ir.TableAlteration] = { + if (affectedColumns.isEmpty) { + Seq(ir.DropConstraint(None, constraint)) + } else { + affectedColumns.map(col => ir.DropConstraint(Some(col), constraint)) + } + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDMLBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDMLBuilder.scala new file mode 100644 index 0000000000..1baf13b851 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDMLBuilder.scala @@ -0,0 +1,152 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate.IRHelpers +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser._ +import com.databricks.labs.remorph.{intermediate => ir} + +import scala.collection.JavaConverters._ + +class SnowflakeDMLBuilder(override val vc: SnowflakeVisitorCoordinator) + extends SnowflakeParserBaseVisitor[ir.Modification] + with ParserCommon[ir.Modification] + with IRHelpers { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. 
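+  // For example (illustrative): any DML construct with no dedicated visitor below is returned as an
+  // ir.UnresolvedModification carrying the original rule text and a diagnostic message, so the
+  // untranslated fragment is preserved in the output rather than silently dropped.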
+ protected override def unresolved(ruleText: String, message: String): ir.Modification = + ir.UnresolvedModification(ruleText = ruleText, message = message) + + // Concrete visitors + + override def visitDmlCommand(ctx: DmlCommandContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case q if q.queryStatement() != null => q.queryStatement().accept(this) + case i if i.insertStatement() != null => i.insertStatement().accept(this) + case i if i.insertMultiTableStatement() != null => i.insertMultiTableStatement().accept(this) + case u if u.updateStatement() != null => u.updateStatement().accept(this) + case d if d.deleteStatement() != null => d.deleteStatement().accept(this) + case m if m.mergeStatement() != null => m.mergeStatement().accept(this) + case _ => unresolved("dmlCommand", "everything is null") + } + } + + override def visitInsertStatement(ctx: InsertStatementContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val table = ctx.dotIdentifier().accept(vc.relationBuilder) + val columns = Option(ctx.ids).map(_.asScala).filter(_.nonEmpty).map(_.map(vc.expressionBuilder.buildId)) + val values = ctx match { + case c if c.queryStatement() != null => c.queryStatement().accept(vc.relationBuilder) + case c if c.valuesTableBody() != null => c.valuesTableBody().accept(vc.relationBuilder) + } + val overwrite = ctx.OVERWRITE() != null + ir.InsertIntoTable(table, columns, values, None, None, overwrite) + } + + override def visitDeleteStatement(ctx: DeleteStatementContext): ir.Modification = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.tableRef().accept(vc.relationBuilder) + val where = Option(ctx.searchCondition()).map(_.accept(vc.expressionBuilder)) + Option(ctx.tablesOrQueries()) match { + case Some(value) => + val relation = vc.relationBuilder.visit(value) + ir.MergeIntoTable(target, relation, where.getOrElse(ir.Noop), matchedActions = Seq(ir.DeleteAction(None))) + case None => ir.DeleteFromTable(target, where = where) + } + } + + override def visitUpdateStatement(ctx: UpdateStatementContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.tableRef().accept(vc.relationBuilder) + val set = vc.expressionBuilder.visitMany(ctx.setColumnValue()) + val sources = + Option(ctx.tableSources()).map(t => vc.relationBuilder.visitMany(t.tableSource()).foldLeft(target)(crossJoin)) + val where = Option(ctx.searchCondition()).map(_.accept(vc.expressionBuilder)) + ir.UpdateTable(target, sources, set, where, None, None) + } + + override def visitMergeStatement(ctx: MergeStatementContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.tableRef().accept(vc.relationBuilder) + val relation = ctx.tableSource().accept(vc.relationBuilder) + val predicate = ctx.searchCondition().accept(vc.expressionBuilder) + val matchedActions = ctx + .mergeCond() + .mergeCondMatch() + .asScala + .map(buildMatchAction) + + val notMatchedActions = ctx + .mergeCond() + .mergeCondNotMatch() + .asScala + .map(buildNotMatchAction) + + ir.MergeIntoTable( + target, + relation, + predicate, + matchedActions = matchedActions, + notMatchedActions = notMatchedActions) + } + + private def buildMatchAction(ctx: MergeCondMatchContext): ir.MergeAction = { + val condition = ctx match { + case c if c.searchCondition() != null => 
Some(c.searchCondition().accept(vc.expressionBuilder)) + case _ => None + } + + ctx match { + case d if d.mergeUpdateDelete().DELETE() != null => + ir.DeleteAction(condition) + case u if u.mergeUpdateDelete().UPDATE() != null => + val assign = u + .mergeUpdateDelete() + .setColumnValue() + .asScala + .map(vc.expressionBuilder.visitSetColumnValue) + .map { case a: ir.Assign => + a + } + ir.UpdateAction(condition, assign) + } + + } + + private def buildNotMatchAction(ctx: MergeCondNotMatchContext): ir.MergeAction = { + val condition = ctx match { + case c if c.searchCondition() != null => Some(c.searchCondition().accept(vc.expressionBuilder)) + case _ => None + } + ctx match { + case c if c.mergeInsert().columnList() != null => + val assignment = c + .mergeInsert() + .columnList() + .columnName() + .asScala + .map(_.accept(vc.expressionBuilder)) + .zip( + c + .mergeInsert() + .exprList() + .expr() + .asScala + .map(_.accept(vc.expressionBuilder))) + .map { case (col, value) => + ir.Assign(col, value) + } + + ir.InsertAction(condition, assignment) + + case _ => ir.InsertAction(condition, Seq.empty[ir.Assign]) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeErrorStrategy.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeErrorStrategy.scala new file mode 100644 index 0000000000..9cc1c223f3 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeErrorStrategy.scala @@ -0,0 +1,319 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.SqlErrorStrategy +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser._ +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.misc.IntervalSet + +import scala.collection.JavaConverters._ + +/** + * Custom error strategy for SQL parsing

While we do not do anything particularly special here, we override a couple of the message-generating + * methods, as well as the token insert and delete messages, because the defaults do not create an exception + * and so give us no way to build an error message in context. Additionally, we can now implement i18n, should that ever + * become necessary.

+ * + *

At the moment, we require valid SQL as input to the conversion process, but if we ever change that strategy, then + * we can implement custom recovery steps here based upon context, though there is no improvement on the sync() + * method.

+ */ +class SnowflakeErrorStrategy extends SqlErrorStrategy { + + /** + * Generate a message for the error. + * + * The exception contains a stack trace, from which we can construct a more informative error message than just + * mismatched child and a huge list of things we were looking for. + * + * @param e + * the RecognitionException + * @return + * the error message + */ + override protected def generateMessage(recognizer: Parser, e: RecognitionException): String = { + // We build the messages by looking at the stack trace of the exception, but if the + // rule translation is not found, or it is the same as the previous message, we skip it, + // to avoid repeating the same message multiple times. This is because a recognition error + // could be found in a parent rule or a child rule but there is no extra information + // provided by being more specific about the rule location. ALos, in some productions + // we may be embedded very deeply in the stack trace, so we want to avoid too many contexts + // in a message. + val stack = e.getStackTrace + val messages = stack.foldLeft(Seq.empty[String]) { case (messageChunks, traceElement) => + val methodName = traceElement.getMethodName + val translatedMessageOpt = SnowflakeErrorStrategy.ruleTranslation.get(methodName) + translatedMessageOpt.fold(messageChunks) { translatedMessage => + if (messageChunks.isEmpty || messageChunks.last != translatedMessage) { + messageChunks :+ translatedMessage + } else { + messageChunks + } + } + } + + if (messages.isEmpty) { + "" + } else { + messages.mkString("while parsing a ", " in a ", "") + } + } + + private[this] val MaxExpectedTokensInErrorMessage = 12 + + /** + * When building the list of expected tokens, we do some custom manipulation so that we do not produce a list of 750 + * possible tokens because there are so many keywords that can be used as id/column names. If ID is a valid expected + * token, then we remove all the keywords that are there because they can be an ID. + * @param expected + * the set of valid tokens at this point in the parse, where the error was found + * @return + * the expected string with tokens renamed in more human friendly form + */ + override protected def buildExpectedMessage(recognizer: Parser, expected: IntervalSet): String = { + val expect = if (expected.contains(ID)) { + removeIdKeywords(expected) + } else { + expected + } + + val uniqueExpectedTokens = expect.toList.asScala.map { tokenId => + // Check if the token ID has a custom translation + SnowflakeErrorStrategy.tokenTranslation.get(tokenId) match { + case Some(translatedName) => translatedName + case None => recognizer.getVocabulary.getDisplayName(tokenId) + } + }.toSet + + val overflowMark = if (uniqueExpectedTokens.size > MaxExpectedTokensInErrorMessage) { + "..." + } else { + "" + } + uniqueExpectedTokens.toSeq + .sorted(capitalizedSort) + .take(MaxExpectedTokensInErrorMessage) + .mkString("", ", ", overflowMark) + } + + /** + * Runs through the given interval and removes all the keywords that are in the set. + * @param set + * The interval from whence to remove keywords that can be Identifiers + */ + private def removeIdKeywords(set: IntervalSet): IntervalSet = { + set.subtract(SnowflakeErrorStrategy.keywordIDs) + } +} + +object SnowflakeErrorStrategy { + + // A map that will override the default display name for tokens that represent text with + // pattern matches like IDENTIFIER, STRING, etc. 
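+  // Illustrative effect (the exact list depends on the parser state at the error site): rather than raw
+  // token names such as DOUBLE_QUOTE_ID or LOCAL_ID, buildExpectedMessage reports the translated forms
+  // below, de-duplicated and capped at MaxExpectedTokensInErrorMessage entries, e.g.
+  //   "$Identifier, Identifier, Operator, Select Statement, ..."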
+ private[SnowflakeErrorStrategy] val tokenTranslation: Map[Int, String] = Map( + DOUBLE_QUOTE_ID -> "Identifier", + FLOAT -> "Float", + INT -> "Integer", + ID -> "Identifier", + LOCAL_ID -> "$Identifier", + REAL -> "Real", + STRING_START -> "'String'", + VAR_SIMPLE -> "&Variable reference", + VAR_COMPLEX -> "&{Variable} reference", + STRING_CONTENT -> "'String'", + STRING_END -> "'String'", + -1 -> "End of batch", + JINJA_REF -> "Jinja Template Element", + + // When the next thing we expect can be every statement, we just say "statement" + ALTER -> "Statement", + BEGIN -> "Statement", + COMMIT -> "Statement", + CONTINUE -> "Statement", + COPY -> "Statement", + CREATE -> "Statement", + DELETE -> "Statement", + DESCRIBE -> "Statement", + DROP -> "Statement", + END -> "Statement", + EXECUTE -> "Statement", + EXPLAIN -> "Statement", + FETCH -> "Statement", + GRANT -> "Statement", + IF -> "Statement", + INSERT -> "Statement", + LIST -> "Statement", + MERGE -> "Statement", + PUT -> "Statement", + REMOVE -> "Statement", + REVOKE -> "Statement", + ROLLBACK -> "Statement", + SELECT -> "Select Statement", + SET -> "Statement", + SHOW -> "Statement", + TRUNCATE -> "Statement", + UNDROP -> "Statement", + UNSET -> "Statement", + UPDATE -> "Statement", + USE -> "Statement", + WITHIN -> "Statement", + + // No need to distinguish between operators + + PIPE_PIPE -> "Operator", + EQ -> "Operator", + GT -> "Operator", + GE -> "Operator", + LT -> "Operator", + LTGT -> "Operator", + LE -> "Operator", + STAR -> "Operator", + DIVIDE -> "Operator", + TILDA -> "Operator", + NE -> "Operator", + MINUS -> "Operator", + PLUS -> "Operator") + + private[SnowflakeErrorStrategy] val ruleTranslation: Map[String, String] = Map( + "alterCommand" -> "ALTER command", + "batch" -> "Snowflake batch", + "beginTxn" -> "BEGIN WORK | TRANSACTION statement", + "copyIntoTable" -> "COPY statement", + "ddlObject" -> "TABLE object", + "executeImmediate" -> "EXECUTE command", + "explain" -> "EXPLAIN command", + "groupByClause" -> "GROUP BY clause", + "havingClause" -> "HAVING clause", + "insertMultiTableStatement" -> "INSERT statement", + "insertStatement" -> "INSERT statement", + "joinClause" -> "JOIN clause", + "limitClause" -> "LIMIT clause", + "mergeStatement" -> "MERGE statement", + "objectRef" -> "Object reference", + "offsetClause" -> "OFFSET clause", + "orderByClause" -> "ORDER BY clause", + "otherCommand" -> "SQL command", + "outputClause" -> "OUTPUT clause", + "selectList" -> "SELECT list", + "selectStatement" -> "SELECT statement", + "snowflakeFile" -> "Snowflake batch", + "topClause" -> "TOP clause", + "update" -> "UPDATE statement", + "updateElem" -> "UPDATE element specification", + "updateStatement" -> "UPDATE statement", + "updateStatement" -> "UPDATE statement", + "updateWhereClause" -> "WHERE clause", + "whereClause" -> "WHERE clause", + "withTableHints" -> "WITH table hints", + // Etc + + "tableSource" -> "table source", + "tableSourceItem" -> "table source") + + private[SnowflakeErrorStrategy] val keywordIDs: IntervalSet = new IntervalSet( + ACCOUNTADMIN, + ACTION, + ACTION, + AES, + ALERT, + ARRAY, + ARRAY_AGG, + AT_KEYWORD, + CHECKSUM, + CLUSTER, + COLLATE, + COLLECTION, + COMMENT, + CONDITION, + CONFIGURATION, + COPY_OPTIONS_, + DATA, + DATE_FORMAT, + DEFINITION, + DELTA, + DENSE_RANK, + DIRECTION, + DOWNSTREAM, + DUMMY, + DYNAMIC, + EDITION, + END, + EMAIL, + EVENT, + EXCHANGE, + EXPIRY_DATE, + FIRST, + FIRST_NAME, + FLATTEN, + FLOOR, + FUNCTION, + GET, + GLOBAL, + IDENTIFIER, + IDENTITY, + IF, + INDEX, + INPUT, + 
INTERVAL, + KEY, + KEYS, + LANGUAGE, + LAST_NAME, + LAST_QUERY_ID, + LEAD, + LENGTH, + LOCAL, + MAX_CONCURRENCY_LEVEL, + MODE, + NAME, + NETWORK, + NOORDER, + OFFSET, + OPTION, + ORDER, + ORGADMIN, + OUTBOUND, + OUTER, + PARTITION, + PATH, + PATTERN, + PORT, + PROCEDURE_NAME, + PROPERTY, + PROVIDER, + PUBLIC, + RANK, + RECURSIVE, + REGION, + REPLACE, + RESOURCE, + RESOURCES, + RESPECT, + RESTRICT, + RESULT, + RLIKE, + ROLE, + SECURITYADMIN, + SHARES, + SOURCE, + STAGE, + START, + STATE, + STATS, + SYSADMIN, + TABLE, + TAG, + TAGS, + TARGET_LAG, + TEMP, + TIMESTAMP, + TIMEZONE, + TYPE, + URL, + USER, + USERADMIN, + VALUE, + VALUES, + VERSION, + WAREHOUSE, + WAREHOUSE_TYPE) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExpressionBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExpressionBuilder.scala new file mode 100644 index 0000000000..6b7fd1efdf --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExpressionBuilder.scala @@ -0,0 +1,847 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{StringContext => _, _} +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime.Token + +import java.time.LocalDateTime +import java.time.format.DateTimeFormatter +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +class SnowflakeExpressionBuilder(override val vc: SnowflakeVisitorCoordinator) + extends SnowflakeParserBaseVisitor[ir.Expression] + with ParserCommon[ir.Expression] + with ir.IRHelpers { + + private[this] val functionBuilder = new SnowflakeFunctionBuilder + private[this] val typeBuilder = new SnowflakeTypeBuilder + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.Expression = + ir.UnresolvedExpression(ruleText = ruleText, message = message) + + // Concrete visitors.. 
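+  // Descriptive note: every visitor below follows the same shape: errorCheck(ctx) is consulted first and,
+  // when the context carries a syntax error, its error expression is returned unchanged; only on the None
+  // branch is the Snowflake construct lowered to IR (e.g. visitExprAnd produces ir.And(left, right)).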
+ + override def visitFunctionCall(ctx: FunctionCallContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case b if b.builtinFunction() != null => b.builtinFunction().accept(this) + case s if s.standardFunction() != null => s.standardFunction().accept(this) + case a if a.aggregateFunction() != null => a.aggregateFunction().accept(this) + case r if r.rankingWindowedFunction() != null => r.rankingWindowedFunction().accept(this) + } + } + + override def visitValuesTable(ctx: ValuesTableContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.valuesTableBody().accept(this) + } + + override def visitGroupByElem(ctx: GroupByElemContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.columnElem() != null => c.columnElem().accept(this) + case n if n.INT() != null => ir.NumericLiteral(n.INT().getText) + case e if e.expressionElem() != null => e.expressionElem().accept(this) + } + } + + override def visitId(ctx: IdContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildId(ctx) + } + + private[snowflake] def buildId(ctx: IdContext): ir.Id = ctx match { + case c if c.DOUBLE_QUOTE_ID() != null => + val idValue = c.getText.trim.stripPrefix("\"").stripSuffix("\"").replaceAll("\"\"", "\"") + ir.Id(idValue, caseSensitive = true) + case v if v.AMP() != null => + // Note that there is nothing special about &id other than they become $id in Databricks + // Many places in the builder concatenate the output of visitId with other strings and so we + // lose the ir.Dot(ir.Variable, ir.Id) that we could pick up and therefore propagate ir.Variable if + // we wanted to leave the translation to generate phase. I think we probably do want to do that, but + // a lot of code has bypassed accept() and called visitId directly, and expects ir.Id, + // then uses fields from it. + // + // To rework that is quite a big job. So, for now, we translate &id to $id here. 
+ // It is not wrong for the id rule to hold the AMP ID alt, but ideally it would produce + // an ir.Variable and we would process that at generation time instead of concatenating into strings :( + ir.Id(s"$$${v.ID().getText}") + case d if d.LOCAL_ID() != null => + ir.Id(s"$$${d.LOCAL_ID().getText.drop(1)}") + case id => ir.Id(id.getText) + } + + override def visitSelectListElem(ctx: SelectListElemContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val rawExpression = ctx match { + case c if c.columnElem() != null => c.columnElem().accept(this) + case c if c.expressionElem() != null => c.expressionElem().accept(this) + case c if c.columnElemStar() != null => c.columnElemStar().accept(this) + } + buildAlias(ctx.asAlias(), rawExpression) + } + + override def visitExpressionElem(ctx: ExpressionElemContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case e if e.expr() != null => e.expr().accept(this) + case p if p.searchCondition() != null => p.searchCondition().accept(this) + } + } + + override def visitColumnElem(ctx: ColumnElemContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val objectNameIds: Seq[ir.NameOrPosition] = Option(ctx.dotIdentifier()).toSeq.flatMap(_.id().asScala.map(buildId)) + val columnIds: Seq[ir.NameOrPosition] = ctx match { + case c if c.columnName() != null => c.columnName().id().asScala.map(buildId) + case c if c.columnPosition() != null => Seq(visitColumnPosition(c.columnPosition())) + } + val fqn = objectNameIds ++ columnIds + val objectRefIds = fqn.take(fqn.size - 1) + val objectRef = if (objectRefIds.isEmpty) { + None + } else { + Some(ir.ObjectReference(objectRefIds.head, objectRefIds.tail: _*)) + } + ir.Column(objectRef, fqn.last) + } + + override def visitDotIdentifier(ctx: DotIdentifierContext): ir.ObjectReference = { + val ids = ctx.id().asScala.map(buildId) + ir.ObjectReference(ids.head, ids.tail: _*) + } + + override def visitColumnPosition(ctx: ColumnPositionContext): ir.Position = { + ir.Position(ctx.INT().getText.toInt) + } + + override def visitColumnElemStar(ctx: ColumnElemStarContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Star(Option(ctx.dotIdentifier()).map { on => + val objectNameIds = on.id().asScala.map(buildId) + ir.ObjectReference(objectNameIds.head, objectNameIds.tail: _*) + }) + } + + private def buildAlias(ctx: AsAliasContext, input: ir.Expression): ir.Expression = + Option(ctx).fold(input) { c => + val alias = buildId(c.alias().id()) + ir.Alias(input, alias) + } + + override def visitColumnName(ctx: ColumnNameContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.id().asScala match { + case Seq(columnName) => ir.Column(None, buildId(columnName)) + case Seq(tableNameOrAlias, columnName) => + ir.Column(Some(ir.ObjectReference(buildId(tableNameOrAlias))), buildId(columnName)) + } + } + + override def visitOrderItem(ctx: OrderItemContext): ir.SortOrder = { + val direction = if (ctx.DESC() != null) ir.Descending else ir.Ascending + val nullOrdering = if (direction == ir.Descending) { + if (ctx.LAST() != null) { + ir.NullsLast + } else { + ir.NullsFirst + } + } else { + if (ctx.FIRST() != null) { + ir.NullsFirst + } else { + ir.NullsLast + } + } + ir.SortOrder(ctx.expr().accept(this), direction, nullOrdering) + } + + override def visitLiteral(ctx: 
LiteralContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val sign = Option(ctx.sign()).map(_ => "-").getOrElse("") + ctx match { + case c if Option(c.id()).exists(_.getText.toLowerCase(Locale.ROOT) == "date") => + val dateStr = c.string().getText.stripPrefix("'").stripSuffix("'") + Try(java.time.LocalDate.parse(dateStr)) + .map(ir.Literal(_)) + .getOrElse(ir.Literal.Null) + case c if c.TIMESTAMP() != null => + val timestampStr = c.string.getText.stripPrefix("'").stripSuffix("'") + val format = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss") + Try(LocalDateTime.parse(timestampStr, format)) + .map(ir.Literal(_)) + .getOrElse(ir.Literal.Null) + case c if c.string() != null => c.string.accept(this) + case c if c.INT() != null => ir.NumericLiteral(sign + c.INT().getText) + case c if c.FLOAT() != null => ir.NumericLiteral(sign + c.FLOAT().getText) + case c if c.REAL() != null => ir.NumericLiteral(sign + c.REAL().getText) + case c if c.NULL() != null => ir.Literal.Null + case c if c.trueFalse() != null => visitTrueFalse(c.trueFalse()) + case c if c.jsonLiteral() != null => visitJsonLiteral(c.jsonLiteral()) + case c if c.arrayLiteral() != null => visitArrayLiteral(c.arrayLiteral()) + case _ => ir.Literal.Null + } + } + + /** + * Reconstruct a string literal from its composite parts, translating variable + * references on the fly. + *

+ * A string literal is a sequence of tokens identifying either a variable reference + * or a piece of normal text. At this point, we simply re-assemble the pieces into an + * ir.StringLiteral, translating the variable references into the Databricks + * SQL equivalent, which is $id. + *

+ *

+ * Note however that we really should be generating something like ir.CompositeString(Seq[something]) + * and then anywhere our ir currently uses a String or ir.StringLiteral, we should be using ir.CompositeString, + * which will then be correctly translated at generation time. We will get there in increments, however; for + * now, this hack will correctly translate variable references in string literals. + *

+ * + * @param ctx the parse tree + */ + override def visitString(ctx: SnowflakeParser.StringContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + + // $$string$$ means interpret the string as raw string with no variable substitution, escape sequences, etc. + // TODO: Do we need a raw flag in the ir.StringLiteral so that we generate r'sdfsdfsdsfds' for Databricks SQL? + // or is r'string' a separate Ir in Spark? + case _ if ctx.DOLLAR_STRING() != null => + val str = ctx.DOLLAR_STRING().getText.stripPrefix("$$").stripSuffix("$$") + ir.StringLiteral(str) + + // Else we must have composite string literal + case _ => + val str = if (ctx.stringPart() == null) { + "" + } else { + ctx + .stringPart() + .asScala + .map { + case p if p.VAR_SIMPLE() != null => s"$${${p.VAR_SIMPLE().getText.drop(1)}}" // &var => ${var} (soon) + case p if p.VAR_COMPLEX() != null => s"$$${p.VAR_COMPLEX().getText.drop(1)}" // &{var} => ${var} + case p if p.STRING_AMPAMP() != null => "&" // && => & + case p if p.STRING_CONTENT() != null => p.STRING_CONTENT().getText + case p if p.STRING_ESCAPE() != null => p.STRING_ESCAPE().getText + case p if p.STRING_SQUOTE() != null => "''" // Escaped single quote + case p if p.STRING_UNICODE() != null => p.STRING_UNICODE().getText + case _ => removeQuotes(ctx.getText) + } + .mkString + } + ir.StringLiteral(str) + } + } + + private def removeQuotes(str: String): String = { + str.stripPrefix("'").stripSuffix("'") + } + + override def visitTrueFalse(ctx: TrueFalseContext): ir.Literal = + ctx.TRUE() match { + case null => ir.Literal.False + case _ => ir.Literal.True + } + + override def visitExprNot(ctx: ExprNotContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.NOT().asScala.foldLeft(ctx.expr().accept(this)) { case (e, _) => ir.Not(e) } + } + + override def visitExprAnd(ctx: ExprAndContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expr(0).accept(this) + val right = ctx.expr(1).accept(this) + ir.And(left, right) + } + + override def visitExprOr(ctx: ExprOrContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expr(0).accept(this) + val right = ctx.expr(1).accept(this) + ir.Or(left, right) + } + + override def visitNonLogicalExpression(ctx: NonLogicalExpressionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.expression().accept(this) + } + + override def visitExprPrecedence(ctx: ExprPrecedenceContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.expression().accept(this) + } + + override def visitExprNextval(ctx: ExprNextvalContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + NextValue(ctx.dotIdentifier().getText) + } + + override def visitExprDot(ctx: ExprDotContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val lhs = ctx.expression(0).accept(this) + val rhs = ctx.expression(1).accept(this) + ir.Dot(lhs, rhs) + } + + override def visitExprColon(ctx: ExprColonContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val lhs = ctx.expression(0).accept(this) + val rhs = ctx.expression(1).accept(this) + ir.JsonAccess(lhs, rhs) + } + + override def 
visitExprCollate(ctx: ExprCollateContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Collate(ctx.expression().accept(this), removeQuotes(ctx.string().getText)) + } + + override def visitExprCase(ctx: ExprCaseContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.caseExpression().accept(this) + } + + override def visitExprIff(ctx: ExprIffContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.iffExpr().accept(this) + } + + override def visitExprComparison(ctx: ExprComparisonContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val right = ctx.expression(1).accept(this) + buildComparisonExpression(ctx.comparisonOperator(), left, right) + } + + override def visitExprDistinct(ctx: ExprDistinctContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Distinct(ctx.expression().accept(this)) + } + + override def visitExprWithinGroup(ctx: ExprWithinGroupContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val expr = ctx.expression().accept(this) + val sortOrders = buildSortOrder(ctx.withinGroup().orderByClause()) + ir.WithinGroup(expr, sortOrders) + } + + override def visitExprOver(ctx: ExprOverContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildWindow(ctx.overClause(), ctx.expression().accept(this)) + } + + override def visitExprCast(ctx: ExprCastContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.castExpr().accept(this) + } + + override def visitExprAscribe(ctx: ExprAscribeContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Cast(ctx.expression().accept(this), typeBuilder.buildDataType(ctx.dataType())) + } + + override def visitExprSign(ctx: ExprSignContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.sign() match { + case c if c.PLUS() != null => ir.UPlus(ctx.expression().accept(this)) + case c if c.MINUS() != null => ir.UMinus(ctx.expression().accept(this)) + } + } + + override def visitExprPrecedence0(ctx: ExprPrecedence0Context): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildBinaryOperation(ctx.op, ctx.expression(0).accept(this), ctx.expression(1).accept(this)) + } + + override def visitExprPrecedence1(ctx: ExprPrecedence1Context): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildBinaryOperation(ctx.op, ctx.expression(0).accept(this), ctx.expression(1).accept(this)) + } + + override def visitExprPrimitive(ctx: ExprPrimitiveContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.primitiveExpression().accept(this) + } + + override def visitExprFuncCall(ctx: ExprFuncCallContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.functionCall().accept(this) + } + + override def visitJsonLiteral(ctx: JsonLiteralContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val fields = ctx.kvPair().asScala.map { kv => + val fieldName = 
removeQuotes(kv.key.getText) + val fieldValue = visitLiteral(kv.literal()) + ir.Alias(fieldValue, ir.Id(fieldName)) + } + ir.StructExpr(fields) + } + + override def visitArrayLiteral(ctx: ArrayLiteralContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val elements = ctx.expr().asScala.map(_.accept(this)).toList + // TODO: The current type determination may be too naive + // but this does not affect code generation as the generator does not use it. + // Here we determine the type of the array by inspecting the first expression in the array literal, + // but when an array literal contains a double or a cast and the first value appears to be an integer, + // then the array literal type should probably be typed as DoubleType and not IntegerType, which means + // we need a function that types all the expressions and types it as the most general type. + val dataType = elements.headOption.map(_.dataType).getOrElse(ir.UnresolvedType) + ir.ArrayExpr(elements, dataType) + } + + override def visitPrimArrayAccess(ctx: PrimArrayAccessContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.ArrayAccess(ctx.id().accept(this), ir.NumericLiteral(ctx.INT().getText)) + } + + override def visitPrimExprColumn(ctx: PrimExprColumnContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.id().accept(this) + } + + override def visitPrimObjectAccess(ctx: PrimObjectAccessContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.JsonAccess(ctx.id().accept(this), ir.Id(removeQuotes(ctx.string().getText))) + } + + override def visitPrimExprLiteral(ctx: PrimExprLiteralContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.literal().accept(this) + } + + private def buildBinaryOperation(operator: Token, left: ir.Expression, right: ir.Expression): ir.Expression = + operator.getType match { + case STAR => ir.Multiply(left, right) + case DIVIDE => ir.Divide(left, right) + case PLUS => ir.Add(left, right) + case MINUS => ir.Subtract(left, right) + case MODULE => ir.Mod(left, right) + case PIPE_PIPE => ir.Concat(Seq(left, right)) + } + + private[snowflake] def buildComparisonExpression( + op: ComparisonOperatorContext, + left: ir.Expression, + right: ir.Expression): ir.Expression = { + if (op.EQ() != null) { + ir.Equals(left, right) + } else if (op.NE() != null || op.LTGT() != null) { + ir.NotEquals(left, right) + } else if (op.GT() != null) { + ir.GreaterThan(left, right) + } else if (op.LT() != null) { + ir.LessThan(left, right) + } else if (op.GE() != null) { + ir.GreaterThanOrEqual(left, right) + } else if (op.LE() != null) { + ir.LessThanOrEqual(left, right) + } else { + ir.UnresolvedExpression( + ruleText = contextText(op), + message = + s"Unknown comparison operator ${contextText(op)} in SnowflakeExpressionBuilder.buildComparisonExpression", + ruleName = vc.ruleName(op), + tokenName = Some(tokenName(op.getStart))) + } + } + + override def visitIffExpr(ctx: IffExprContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val condition = ctx.searchCondition().accept(this) + val thenBranch = ctx.expr(0).accept(this) + val elseBranch = ctx.expr(1).accept(this) + ir.If(condition, thenBranch, elseBranch) + } + + override def visitCastExpr(ctx: CastExprContext): ir.Expression = errorCheck(ctx) match { + case 
Some(errorResult) => errorResult + case None => + ctx match { + case c if c.castOp != null => + val expression = c.expr().accept(this) + val dataType = typeBuilder.buildDataType(c.dataType()) + ctx.castOp.getType match { + case CAST => ir.Cast(expression, dataType) + case TRY_CAST => ir.TryCast(expression, dataType) + } + + case c if c.INTERVAL() != null => + ir.Cast(c.expr().accept(this), ir.IntervalType) + } + } + + override def visitRankingWindowedFunction(ctx: RankingWindowedFunctionContext): ir.Expression = errorCheck( + ctx) match { + case Some(errorResult) => errorResult + case None => + val ignore_nulls = if (ctx.ignoreOrRepectNulls() != null) { + ctx.ignoreOrRepectNulls().getText.equalsIgnoreCase("IGNORENULLS") + } else false + buildWindow(ctx.overClause(), ctx.standardFunction().accept(this), ignore_nulls) + } + + private def buildWindow( + ctx: OverClauseContext, + windowFunction: ir.Expression, + ignore_nulls: Boolean = false): ir.Expression = { + val partitionSpec = visitMany(ctx.expr()) + val sortOrder = + Option(ctx.windowOrderingAndFrame()).map(c => buildSortOrder(c.orderByClause())).getOrElse(Seq()) + + val frameSpec = + Option(ctx.windowOrderingAndFrame()) + .flatMap(c => Option(c.rowOrRangeClause())) + .map(buildWindowFrame) + .orElse(snowflakeDefaultFrameSpec(windowFunction)) + + ir.Window( + window_function = windowFunction, + partition_spec = partitionSpec, + sort_order = sortOrder, + frame_spec = frameSpec, + ignore_nulls = ignore_nulls) + } + + // see: https://docs.snowflake.com/en/sql-reference/functions-analytic#list-of-window-functions + // default frameSpec(UNBOUNDED FOLLOWING) is not supported for: + // "LAG", "DENSE_RANK","LEAD", "PERCENT_RANK","RANK","ROW_NUMBER" + private[this] val rankRelatedWindowFunctions = Set("CUME_DIST", "FIRST_VALUE", "LAST_VALUE", "NTH_VALUE", "NTILE") + + /** + * For rank-related window functions, Snowflake's default frame deviates from the ANSI standard. So in such cases, we must + * make the frame specification explicit.
see: + * https://docs.snowflake.com/en/sql-reference/functions-analytic#usage-notes-for-window-frames + */ + private def snowflakeDefaultFrameSpec(windowFunction: ir.Expression): Option[ir.WindowFrame] = { + val rankRelatedDefaultFrameSpec = ir.WindowFrame(ir.RowsFrame, ir.UnboundedPreceding, ir.UnboundedFollowing) + windowFunction match { + case fn: ir.Fn if rankRelatedWindowFunctions.contains(fn.prettyName) => Some(rankRelatedDefaultFrameSpec) + case _ => None + } + } + + private[snowflake] def buildSortOrder(ctx: OrderByClauseContext): Seq[ir.SortOrder] = { + ctx.orderItem().asScala.map(visitOrderItem) + } + + private def buildWindowFrame(ctx: RowOrRangeClauseContext): ir.WindowFrame = { + val frameType = if (ctx.ROWS() != null) ir.RowsFrame else ir.RangeFrame + val lower = buildFrameBound(ctx.windowFrameExtent().windowFrameBound(0)) + val upper = buildFrameBound(ctx.windowFrameExtent().windowFrameBound(1)) + ir.WindowFrame(frameType, lower, upper) + } + + private def buildFrameBound(ctx: WindowFrameBoundContext): ir.FrameBoundary = ctx match { + case c if c.UNBOUNDED() != null && c.PRECEDING != null => ir.UnboundedPreceding + case c if c.UNBOUNDED() != null && c.FOLLOWING() != null => ir.UnboundedFollowing + case c if c.INT() != null && c.PRECEDING() != null => ir.PrecedingN(ir.NumericLiteral(c.INT.getText)) + case c if c.INT() != null && c.FOLLOWING() != null => ir.FollowingN(ir.NumericLiteral(c.INT.getText)) + case c if c.CURRENT() != null => ir.CurrentRow + } + + override def visitStandardFunction(ctx: StandardFunctionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val functionName = fetchFunctionName(ctx) + val arguments = ctx match { + case c if c.exprList() != null => visitMany(c.exprList().expr()) + case c if c.paramAssocList() != null => c.paramAssocList().paramAssoc().asScala.map(_.accept(this)) + case _ => Seq.empty + } + functionBuilder.buildFunction(functionName, arguments) + } + + private def fetchFunctionName(ctx: StandardFunctionContext): String = { + if (ctx.functionName() != null) { + ctx.functionName() match { + case c if c.id() != null => buildId(c.id()).id + case c if c.nonReservedFunctionName() != null => c.nonReservedFunctionName().getText + } + } else { + ctx.functionOptionalBrackets().getText + } + } + + // aggregateFunction + + override def visitAggFuncExprList(ctx: AggFuncExprListContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val param = visitMany(ctx.exprList().expr()) + functionBuilder.buildFunction(buildId(ctx.id()), param) + } + + override def visitAggFuncStar(ctx: AggFuncStarContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + functionBuilder.buildFunction(buildId(ctx.id()), Seq(ir.Star(None))) + } + + override def visitAggFuncList(ctx: AggFuncListContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val param = ctx.expr().accept(this) + val separator = Option(ctx.string()).map(s => ir.Literal(removeQuotes(s.getText))) + ctx.op.getType match { + case LISTAGG => functionBuilder.buildFunction("LISTAGG", param +: separator.toSeq) + case ARRAY_AGG => functionBuilder.buildFunction("ARRAYAGG", Seq(param)) + } + } + + override def visitBuiltinExtract(ctx: BuiltinExtractContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val part = if (ctx.ID() != null) { 
ir.Id(removeQuotes(ctx.ID().getText)) } + else { + buildIdFromString(ctx.string()) + } + val date = ctx.expr().accept(this) + functionBuilder.buildFunction(ctx.EXTRACT().getText, Seq(part, date)) + } + + private def buildIdFromString(ctx: SnowflakeParser.StringContext): ir.Id = ctx.accept(this) match { + case ir.StringLiteral(s) => ir.Id(s) + case _ => throw new IllegalArgumentException("Expected a string literal") + } + + override def visitCaseExpression(ctx: CaseExpressionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val exprs = ctx.expr().asScala + val otherwise = Option(ctx.ELSE()).flatMap(els => exprs.find(occursBefore(els, _)).map(_.accept(this))) + ctx match { + case c if c.switchSection().size() > 0 => + val expression = exprs.find(occursBefore(_, ctx.switchSection(0))).map(_.accept(this)) + val branches = c.switchSection().asScala.map { branch => + ir.WhenBranch(branch.expr(0).accept(this), branch.expr(1).accept(this)) + } + ir.Case(expression, branches, otherwise) + case c if c.switchSearchConditionSection().size() > 0 => + val branches = c.switchSearchConditionSection().asScala.map { branch => + ir.WhenBranch(branch.searchCondition().accept(this), branch.expr().accept(this)) + } + ir.Case(None, branches, otherwise) + } + } + + // Search conditions and predicates + + override def visitScNot(ctx: ScNotContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Not(ctx.searchCondition().accept(this)) + } + + override def visitScAnd(ctx: ScAndContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.And(ctx.searchCondition(0).accept(this), ctx.searchCondition(1).accept(this)) + } + + override def visitScOr(ctx: ScOrContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Or(ctx.searchCondition(0).accept(this), ctx.searchCondition(1).accept(this)) + } + + override def visitScPred(ctx: ScPredContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.predicate().accept(this) + } + + override def visitScPrec(ctx: ScPrecContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.searchCondition.accept(this) + } + + override def visitPredExists(ctx: PredExistsContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Exists(ctx.subquery().accept(vc.relationBuilder)) + } + + override def visitPredBinop(ctx: PredBinopContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val right = ctx.expression(1).accept(this) + ctx.comparisonOperator match { + case op if op.LE != null => ir.LessThanOrEqual(left, right) + case op if op.GE != null => ir.GreaterThanOrEqual(left, right) + case op if op.LTGT != null => ir.NotEquals(left, right) + case op if op.NE != null => ir.NotEquals(left, right) + case op if op.EQ != null => ir.Equals(left, right) + case op if op.GT != null => ir.GreaterThan(left, right) + case op if op.LT != null => ir.LessThan(left, right) + } + } + + override def visitPredASA(ctx: PredASAContext): ir.Expression = + // TODO: build ASA + ir.UnresolvedExpression( + ruleText = contextText(ctx), + message = "ALL | SOME | ANY is not yet supported", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + + 
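  // Illustrative sketch (assumed IR shapes, not verified output): how visitCaseExpression
  // above is expected to map Snowflake CASE expressions into the IR. The node types
  // (ir.Case, ir.WhenBranch, ir.Id, ir.StringLiteral, ir.NumericLiteral, ir.GreaterThan)
  // are the ones used by the surrounding visitors; the concrete literal encodings are
  // assumptions for illustration only.
  //
  //   Simple CASE:   CASE col WHEN 1 THEN 'one' ELSE 'other' END
  //   roughly becomes:
  //     ir.Case(
  //       Some(ir.Id("col")),                                                  // expr before the first switchSection
  //       Seq(ir.WhenBranch(ir.NumericLiteral("1"), ir.StringLiteral("one"))), // one branch per switchSection
  //       Some(ir.StringLiteral("other")))                                     // first expr after the ELSE token
  //
  //   Searched CASE: CASE WHEN col > 1 THEN 'big' END
  //   roughly becomes:
  //     ir.Case(
  //       None,
  //       Seq(ir.WhenBranch(ir.GreaterThan(ir.Id("col"), ir.NumericLiteral("1")), ir.StringLiteral("big"))),
  //       None)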
override def visitPredBetween(ctx: PredBetweenContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val lowerBound = ctx.expression(1).accept(this) + val upperBound = ctx.expression(2).accept(this) + val expression = ctx.expression(0).accept(this) + val between = ir.Between(expression, lowerBound, upperBound) + Option(ctx.NOT()).fold[ir.Expression](between)(_ => ir.Not(between)) + } + + override def visitPredIn(ctx: PredInContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val in = if (ctx.subquery() != null) { + // In the result of a sub query + ir.In(ctx.expression().accept(this), Seq(ir.ScalarSubquery(ctx.subquery().accept(vc.relationBuilder)))) + } else { + // In a list of expressions + ir.In(ctx.expression().accept(this), ctx.exprList().expr().asScala.map(_.accept(this))) + } + Option(ctx.NOT()).fold[ir.Expression](in)(_ => ir.Not(in)) + } + + override def visitPredLikeSinglePattern(ctx: PredLikeSinglePatternContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val right = ctx.expression(1).accept(this) + // NB: The escape character is a complete expression that evaluates to a single char at runtime + // and not a single char at parse time. + val escape = Option(ctx.expression(2)) + .map(_.accept(this)) + val like = ctx.op.getType match { + case LIKE => ir.Like(left, right, escape) + case ILIKE => ir.ILike(left, right, escape) + } + Option(ctx.NOT()).fold[ir.Expression](like)(_ => ir.Not(like)) + } + + override def visitPredLikeMultiplePatterns(ctx: PredLikeMultiplePatternsContext): ir.Expression = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val patterns = visitMany(ctx.exprListInParentheses().exprList().expr()) + val normalizedPatterns = normalizePatterns(patterns, ctx.expression(1)) + val like = ctx.op.getType match { + case LIKE if ctx.ALL() != null => ir.LikeAll(left, normalizedPatterns) + case LIKE => ir.LikeAny(left, normalizedPatterns) + case ILIKE if ctx.ALL() != null => ir.ILikeAll(left, normalizedPatterns) + case ILIKE => ir.ILikeAny(left, normalizedPatterns) + } + Option(ctx.NOT()).fold[ir.Expression](like)(_ => ir.Not(like)) + } + + override def visitPredRLike(ctx: PredRLikeContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val right = ctx.expression(1).accept(this) + val rLike = ir.RLike(left, right) + Option(ctx.NOT()).fold[ir.Expression](rLike)(_ => ir.Not(rLike)) + } + + override def visitPredIsNull(ctx: PredIsNullContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val expression = ctx.expression().accept(this) + if (ctx.NOT() != null) ir.IsNotNull(expression) else ir.IsNull(expression) + } + + override def visitPredExpr(ctx: PredExprContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.expression().accept(this) + } + + private def normalizePatterns(patterns: Seq[ir.Expression], escape: ExpressionContext): Seq[ir.Expression] = { + Option(escape) + .map(_.accept(this)) + .collect { case ir.StringLiteral(esc) => + patterns.map { + case ir.StringLiteral(pat) => + val escapedPattern = pat.replace(esc, s"\\$esc") + ir.StringLiteral(escapedPattern) + case e => ir.StringReplace(e, 
ir.Literal(esc), ir.Literal("\\")) + } + } + .getOrElse(patterns) + } + + override def visitParamAssoc(ctx: ParamAssocContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + NamedArgumentExpression(ctx.assocId().getText.toUpperCase(), ctx.expr().accept(this)) + } + + override def visitSetColumnValue(ctx: SetColumnValueContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Assign(ctx.columnName().accept(this), ctx.expr().accept(this)) + } + + override def visitExprSubquery(ctx: ExprSubqueryContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.ScalarSubquery(ctx.subquery().accept(vc.relationBuilder)) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeFunctionBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeFunctionBuilder.scala new file mode 100644 index 0000000000..cc4b5212b1 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeFunctionBuilder.scala @@ -0,0 +1,843 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeFunctionConverters.SynonymOf +import com.databricks.labs.remorph.parsers.{ConversionStrategy, FunctionBuilder, FunctionDefinition} +import com.databricks.labs.remorph.{intermediate => ir} + +class SnowflakeFunctionBuilder extends FunctionBuilder { + + private[this] val SnowflakeFunctionDefinitionPf: PartialFunction[String, FunctionDefinition] = { + + case "ABS" => FunctionDefinition.standard(1) + case "ACOS" => FunctionDefinition.standard(1) + case "ACOSH" => FunctionDefinition.standard(1) + case "ADD_MONTHS" => FunctionDefinition.standard(2) + case "ALERT_HISTORY" => FunctionDefinition.standard(4) + case "ALL_USER_NAMES" => FunctionDefinition.standard(0) + case "ANY_VALUE" => FunctionDefinition.standard(1) + case "APPLICATION_JSON" => FunctionDefinition.standard(1) + case "APPROXIMATE_JACCARD_INDEX" => FunctionDefinition.standard(1) + case "APPROXIMATE_SIMILARITY" => FunctionDefinition.standard(1) + case "APPROX_COUNT_DISTINCT" => FunctionDefinition.standard(1, Int.MaxValue) + case "APPROX_PERCENTILE" => FunctionDefinition.standard(2) + case "APPROX_PERCENTILE_ACCUMULATE" => FunctionDefinition.standard(1) + case "APPROX_PERCENTILE_COMBINE" => FunctionDefinition.standard(1) + case "APPROX_PERCENTILE_ESTIMATE" => FunctionDefinition.standard(2) + case "APPROX_TOP_K" => FunctionDefinition.standard(1, 3) + case "APPROX_TOP_K_ACCUMULATE" => FunctionDefinition.standard(2) + case "APPROX_TOP_K_COMBINE" => FunctionDefinition.standard(1, 2) + case "APPROX_TOP_K_ESTIMATE" => FunctionDefinition.standard(2) + case "ARRAYS_OVERLAP" => FunctionDefinition.standard(2) + case "ARRAYS_TO_OBJECT" => FunctionDefinition.standard(2) + case "ARRAYAGG" => FunctionDefinition.standard(1).withConversionStrategy(SynonymOf("ARRAY_AGG")) + case "ARRAY_AGG" => FunctionDefinition.standard(1) + case "ARRAY_APPEND" => FunctionDefinition.standard(2) + case "ARRAY_CAT" => FunctionDefinition.standard(2) + case "ARRAY_COMPACT" => FunctionDefinition.standard(1) + case "ARRAY_CONSTRUCT" => FunctionDefinition.standard(0, Int.MaxValue) + case "ARRAY_CONSTRUCT_COMPACT" => FunctionDefinition.standard(0, Int.MaxValue) + case "ARRAY_CONTAINS" => FunctionDefinition.standard(2) + case "ARRAY_DISTINCT" => FunctionDefinition.standard(1) + case "ARRAY_EXCEPT" => 
FunctionDefinition.standard(2) + case "ARRAY_FLATTEN" => FunctionDefinition.standard(1) + case "ARRAY_GENERATE_RANGE" => FunctionDefinition.standard(2, 3) + case "ARRAY_INSERT" => FunctionDefinition.standard(3) + case "ARRAY_INTERSECTION" => FunctionDefinition.standard(2) + case "ARRAY_MAX" => FunctionDefinition.standard(1) + case "ARRAY_MIN" => FunctionDefinition.standard(1) + case "ARRAY_POSITION" => FunctionDefinition.standard(2) + case "ARRAY_PREPEND" => FunctionDefinition.standard(2) + case "ARRAY_REMOVE" => FunctionDefinition.standard(2) + case "ARRAY_REMOVE_AT" => FunctionDefinition.standard(2) + case "ARRAY_SIZE" => FunctionDefinition.standard(1) + case "ARRAY_SLICE" => FunctionDefinition.standard(3) + case "ARRAY_SORT" => FunctionDefinition.standard(1, 3) + case "ARRAY_TO_STRING" => FunctionDefinition.standard(2) + case "ARRAY_UNION_AGG" => FunctionDefinition.standard(1) + case "ARRAY_UNIQUE_AGG" => FunctionDefinition.standard(1) + case "ASCII" => FunctionDefinition.standard(1) + case "ASIN" => FunctionDefinition.standard(1) + case "ASINH" => FunctionDefinition.standard(1) + case "AS_ARRAY" => FunctionDefinition.standard(1) + case "AS_BINARY" => FunctionDefinition.standard(1) + case "AS_BOOLEAN" => FunctionDefinition.standard(1) + case "AS_CHAR" => FunctionDefinition.standard(1) + case "AS_DATE" => FunctionDefinition.standard(1) + case "AS_DECIMAL" => FunctionDefinition.standard(1, 3) + case "AS_DOUBLE" => FunctionDefinition.standard(1) + case "AS_INTEGER" => FunctionDefinition.standard(1) + case "AS_NUMBER" => FunctionDefinition.standard(1, 3) + case "AS_OBJECT" => FunctionDefinition.standard(1) + case "AS_REAL" => FunctionDefinition.standard(1) + case "AS_TIME" => FunctionDefinition.standard(1) + case "AS_TIMESTAMP_LTZ" => FunctionDefinition.standard(1) + case "AS_TIMESTAMP_NTZ" => FunctionDefinition.standard(1) + case "AS_TIMESTAMP_TZ" => FunctionDefinition.standard(1) + case "AS_VARCHAR" => FunctionDefinition.standard(1) + case "ATAN" => FunctionDefinition.standard(1) + case "ATAN2" => FunctionDefinition.standard(2) + case "ATANH" => FunctionDefinition.standard(1) + case "AUTOMATIC_CLUSTERING_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "TABLE_NAME")) + case "AUTO_REFRESH_REGISTRATION_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "OBJECT_TYPE", "OBJECT_NAME")) + case "AVG" => FunctionDefinition.standard(1) + case "BASE64_DECODE_BINARY" => FunctionDefinition.standard(1, 2) + case "BASE64_DECODE_STRING" => FunctionDefinition.standard(1, 2) + case "BASE64_ENCODE" => FunctionDefinition.standard(1, 3) + case "BITAND" => FunctionDefinition.standard(2, 3) + case "BITAND_AGG" => FunctionDefinition.standard(1) + case "BITMAP_BIT_POSITION" => FunctionDefinition.standard(1) + case "BITMAP_BUCKET_NUMBER" => FunctionDefinition.standard(1) + case "BITMAP_CONSTRUCT_AGG" => FunctionDefinition.standard(1) + case "BITMAP_COUNT" => FunctionDefinition.standard(1) + case "BITMAP_OR_AGG" => FunctionDefinition.standard(1) + case "BITNOT" => FunctionDefinition.standard(1) + case "BITOR" => FunctionDefinition.standard(2, 3) + case "BITOR_AGG" => FunctionDefinition.standard(1) + case "BITSHIFTLEFT" => FunctionDefinition.standard(2) + case "BITSHIFTRIGHT" => FunctionDefinition.standard(2) + case "BITXOR" => FunctionDefinition.standard(2, 3) + case "BITXOR_AGG" => FunctionDefinition.standard(1) + case "BIT_LENGTH" => FunctionDefinition.standard(1) + case "BLANK_COUNT (system data metric function)" => 
FunctionDefinition.standard(1) + case "BLANK_PERCENT (system data metric function)" => FunctionDefinition.standard(1) + case "BOOLAND" => FunctionDefinition.standard(2) + case "BOOLAND_AGG" => FunctionDefinition.standard(1) + case "BOOLNOT" => FunctionDefinition.standard(1) + case "BOOLOR" => FunctionDefinition.standard(2) + case "BOOLOR_AGG" => FunctionDefinition.standard(1) + case "BOOLXOR" => FunctionDefinition.standard(2) + case "BOOLXOR_AGG" => FunctionDefinition.standard(1) + case "BUILD_SCOPED_FILE_URL" => FunctionDefinition.standard(0) + case "BUILD_STAGE_FILE_URL" => FunctionDefinition.standard(0) + case "CAST , " => FunctionDefinition.standard(2) + case "CBRT" => FunctionDefinition.symbolic(Set("STR"), Set("DISABLE_AUTO_CONVERT")) + case "CEIL" => FunctionDefinition.standard(1, 2) + case "CHAR" => FunctionDefinition.standard(1) + case "CHARINDEX" => FunctionDefinition.standard(2, 3) + case "CHECK_JSON" => FunctionDefinition.standard(1) + case "CHECK_XML" => FunctionDefinition.symbolic(Set("STR"), Set("DISABLE_AUTO_CONVERT")) + case "CHR" => FunctionDefinition.standard(1) + case "COALESCE" => FunctionDefinition.standard(1, Int.MaxValue) + case "COLLATE" => FunctionDefinition.standard(2) + case "COLLATION" => FunctionDefinition.standard(1) + case "COMPLETE (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(2, 3) + case "COMPLETE_TASK_GRAPHS" => + FunctionDefinition.symbolic(Set.empty, Set("RESULT_LIMIT", "ROOT_TASK_NAME", "ERROR_ONLY")) + case "COMPRESS" => FunctionDefinition.standard(2) + case "CONCAT , " => FunctionDefinition.standard(1) + case "CONCAT_WS" => FunctionDefinition.standard(1, Int.MaxValue) + case "CONDITIONAL_CHANGE_EVENT" => FunctionDefinition.standard(1) + case "CONDITIONAL_TRUE_EVENT" => FunctionDefinition.standard(1) + case "CONTAINS" => FunctionDefinition.standard(2) + case "CONVERT_TIMEZONE" => FunctionDefinition.standard(2, 3) + case "COPY_HISTORY" => FunctionDefinition.standard(2, 3) + case "CORR" => FunctionDefinition.standard(2) + case "COS" => FunctionDefinition.standard(1) + case "COSH" => FunctionDefinition.standard(1) + case "COT" => FunctionDefinition.standard(1) + case "COUNT" => FunctionDefinition.standard(1, Int.MaxValue) + case "COUNT_IF" => FunctionDefinition.standard(1) + case "COUNT_TOKENS (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(3) + case "COVAR_POP" => FunctionDefinition.standard(2) + case "COVAR_SAMP" => FunctionDefinition.standard(2) + case "CUME_DIST" => FunctionDefinition.standard(0) + case "CURRENT_ACCOUNT" => FunctionDefinition.standard(0) + case "CURRENT_ACCOUNT_NAME" => FunctionDefinition.standard(0) + case "CURRENT_AVAILABLE_ROLES" => FunctionDefinition.standard(0) + case "CURRENT_CLIENT" => FunctionDefinition.standard(0) + case "CURRENT_DATABASE" => FunctionDefinition.standard(0) + case "CURRENT_DATE" => FunctionDefinition.standard(0) + case "CURRENT_IP_ADDRESS" => FunctionDefinition.standard(0) + case "CURRENT_ORGANIZATION_NAME" => FunctionDefinition.standard(0) + case "CURRENT_REGION" => FunctionDefinition.standard(0) + case "CURRENT_ROLE" => FunctionDefinition.standard(0) + case "CURRENT_ROLE_TYPE" => FunctionDefinition.standard(0) + case "CURRENT_SCHEMA" => FunctionDefinition.standard(0) + case "CURRENT_SCHEMAS" => FunctionDefinition.standard(0) + case "CURRENT_SECONDARY_ROLES" => FunctionDefinition.standard(0) + case "CURRENT_SESSION" => FunctionDefinition.standard(0) + case "CURRENT_STATEMENT" => FunctionDefinition.standard(0) + case "CURRENT_TASK_GRAPHS" => FunctionDefinition.symbolic(Set.empty, Set("RESULT_LIMIT", 
"ROOT_TASK_NAME")) + case "CURRENT_TIME" => FunctionDefinition.standard(0, 1) + case "CURRENT_TIMESTAMP" => FunctionDefinition.standard(0, 1) + case "CURRENT_TRANSACTION" => FunctionDefinition.standard(0) + case "CURRENT_USER" => FunctionDefinition.standard(0) + case "CURRENT_VERSION" => FunctionDefinition.standard(0) + case "CURRENT_WAREHOUSE" => FunctionDefinition.standard(0) + case "DATABASE_REFRESH_HISTORY" => FunctionDefinition.standard(0, 1) + case "DATABASE_REFRESH_PROGRESS" => FunctionDefinition.standard(0, 1) + case "DATABASE_REFRESH_PROGRESS_BY_JOB" => FunctionDefinition.standard(0, 1) + case "DATABASE_REPLICATION_USAGE_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "DATABASE_NAME")) + case "DATABASE_STORAGE_USAGE_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "DATABASE_NAME")) + case "DATA_METRIC_FUNCTION_REFERENCES" => FunctionDefinition.standard(3) + case "DATA_METRIC_SCHEDULED_TIME (system data metric function)" => FunctionDefinition.standard(0) + case "DATA_TRANSFER_HISTORY" => FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END")) + case "DATE" => FunctionDefinition.standard(1, 2).withConversionStrategy(SynonymOf("TO_DATE")) + case "DATEADD" => FunctionDefinition.standard(3) + case "DATEDIFF" => FunctionDefinition.standard(3) + case "DATEFROMPARTS" => FunctionDefinition.standard(3).withConversionStrategy(SynonymOf("DATE_FROM_PARTS")) + case "DATE_FROM_PARTS" => FunctionDefinition.standard(3) + case "DATE_FORMAT" => FunctionDefinition.standard(2) + case "DATE_PART" => FunctionDefinition.standard(2) + case "DATE_TRUNC" => FunctionDefinition.standard(2) + case "DAY" => FunctionDefinition.standard(1) + case "DAYNAME" => FunctionDefinition.standard(1) + case "DAYOFMONTH" => FunctionDefinition.standard(1) + case "DAYOFWEEK" => FunctionDefinition.standard(1) + case "DAYOFWEEKISO" => FunctionDefinition.standard(1) + case "DAYOFYEAR" => FunctionDefinition.standard(1) + case "DECODE" => FunctionDefinition.standard(3, Int.MaxValue) + case "DECOMPRESS_BINARY" => FunctionDefinition.standard(2) + case "DECOMPRESS_STRING" => FunctionDefinition.standard(2) + case "DECRYPT" => FunctionDefinition.standard(2, 4) + case "DECRYPT_RAW" => FunctionDefinition.standard(3, 6) + case "DEGREES" => FunctionDefinition.standard(1) + case "DENSE_RANK" => FunctionDefinition.standard(0) + case "DIV0" => FunctionDefinition.standard(2) + case "DIV0NULL" => FunctionDefinition.standard(2) + case "DUPLICATE_COUNT (system data metric function)" => FunctionDefinition.standard(1) + case "DYNAMIC_TABLES" => FunctionDefinition.standard(3) + case "DYNAMIC_TABLE_GRAPH_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("AS_OF", "HISTORY_START", "HISTORY_END")) + case "DYNAMIC_TABLE_REFRESH_HISTORY" => + FunctionDefinition.symbolic( + Set.empty, + Set("DATA_TIMESTAMP_START", "DATA_TIMESTAMP_END", "RESULT_LIMIT", "NAME", "NAME_PREFIX", "ERROR_ONLY")) + case "EDITDISTANCE" => FunctionDefinition.standard(2, 3) + case "EMAIL_INTEGRATION_CONFIG" => FunctionDefinition.standard(5) + case "EMBED_TEXT_1024 (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(2) + case "EMBED_TEXT_768 (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(2) + case "ENCRYPT" => FunctionDefinition.standard(2, 4) + case "ENCRYPT_RAW" => FunctionDefinition.standard(3, 5) + case "ENDSWITH" => FunctionDefinition.standard(2) + case "EQUAL_NULL" => FunctionDefinition.standard(2) + case "EXP" => FunctionDefinition.standard(1) + case 
"EXPLAIN_JSON" => FunctionDefinition.standard(1) + case "EXTERNAL_FUNCTIONS_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "FUNCTION_SIGNATURE")) + case "EXTERNAL_TABLE_FILES" => FunctionDefinition.standard(1) + case "EXTERNAL_TABLE_FILE_REGISTRATION_HISTORY" => FunctionDefinition.standard(1, 2) + case "EXTRACT" => FunctionDefinition.standard(2) + case "EXTRACT_ANSWER (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(2) + case "EXTRACT_SEMANTIC_CATEGORIES" => FunctionDefinition.standard(1, 2) + case "FACTORIAL" => FunctionDefinition.standard(1) + case "FILTER" => FunctionDefinition.standard(2) + case "FINETUNE ('CANCEL') (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(0) + case "FINETUNE ('CREATE') (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(0) + case "FINETUNE ('DESCRIBE') (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(0) + case "FINETUNE ('SHOW') (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(0) + case "FINETUNE (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(0) + case "FIRST_VALUE" => FunctionDefinition.standard(1) + case "FLATTEN" => FunctionDefinition.symbolic(Set("INPUT"), Set("PATH", "OUTER", "RECURSIVE", "MODE")) + case "FLOOR" => FunctionDefinition.standard(1, 2) + case "FRESHNESS (system data metric function)" => FunctionDefinition.standard(1) + case "GENERATE_COLUMN_DESCRIPTION" => FunctionDefinition.standard(2) + case "GENERATOR" => FunctionDefinition.symbolic(Set.empty, Set("ROWCOUNT", "TIMELIMIT")) + case "GET" => FunctionDefinition.standard(2) + case "GETBIT" => FunctionDefinition.standard(2) + case "GETDATE" => FunctionDefinition.standard(0) + case "GETVARIABLE" => FunctionDefinition.standard(1) + case "GET_ABSOLUTE_PATH" => FunctionDefinition.standard(0) + case "GET_ANACONDA_PACKAGES_REPODATA" => FunctionDefinition.standard(1) + case "GET_CONDITION_QUERY_UUID" => FunctionDefinition.standard(0) + case "GET_DDL" => FunctionDefinition.standard(2, 3) + case "GET_IGNORE_CASE" => FunctionDefinition.standard(2) + case "GET_OBJECT_REFERENCES" => FunctionDefinition.standard(3) + case "GET_PATH" => FunctionDefinition.standard(2) + case "GET_PRESIGNED_URL" => FunctionDefinition.standard(0) + case "GET_QUERY_OPERATOR_STATS" => FunctionDefinition.standard(1) + case "GET_RELATIVE_PATH" => FunctionDefinition.standard(0) + case "GET_STAGE_LOCATION" => FunctionDefinition.standard(0) + case "GREATEST" => FunctionDefinition.standard(1, Int.MaxValue) + case "GREATEST_IGNORE_NULLS" => FunctionDefinition.standard(1) + case "GROUPING" => FunctionDefinition.standard(1, Int.MaxValue) + case "GROUPING_ID" => FunctionDefinition.standard(1, Int.MaxValue) + case "H3_CELL_TO_BOUNDARY" => FunctionDefinition.standard(1) + case "H3_CELL_TO_CHILDREN" => FunctionDefinition.standard(2) + case "H3_CELL_TO_CHILDREN_STRING" => FunctionDefinition.standard(2) + case "H3_CELL_TO_PARENT" => FunctionDefinition.standard(2) + case "H3_CELL_TO_POINT" => FunctionDefinition.standard(1) + case "H3_COMPACT_CELLS" => FunctionDefinition.standard(1) + case "H3_COMPACT_CELLS_STRINGS" => FunctionDefinition.standard(1) + case "H3_COVERAGE" => FunctionDefinition.standard(2) + case "H3_COVERAGE_STRINGS" => FunctionDefinition.standard(2) + case "H3_GET_RESOLUTION" => FunctionDefinition.standard(1) + case "H3_GRID_DISK" => FunctionDefinition.standard(2) + case "H3_GRID_DISTANCE" => FunctionDefinition.standard(2) + case "H3_GRID_PATH" => FunctionDefinition.standard(2) + case "H3_INT_TO_STRING" => FunctionDefinition.standard(1) + case "H3_IS_PENTAGON" => 
FunctionDefinition.standard(1) + case "H3_IS_VALID_CELL" => FunctionDefinition.standard(1) + case "H3_LATLNG_TO_CELL" => FunctionDefinition.standard(3) + case "H3_LATLNG_TO_CELL_STRING" => FunctionDefinition.standard(3) + case "H3_POINT_TO_CELL" => FunctionDefinition.standard(2) + case "H3_POINT_TO_CELL_STRING" => FunctionDefinition.standard(2) + case "H3_POLYGON_TO_CELLS" => FunctionDefinition.standard(2) + case "H3_POLYGON_TO_CELLS_STRINGS" => FunctionDefinition.standard(2) + case "H3_STRING_TO_INT" => FunctionDefinition.standard(1) + case "H3_TRY_COVERAGE" => FunctionDefinition.standard(2) + case "H3_TRY_COVERAGE_STRINGS" => FunctionDefinition.standard(2) + case "H3_TRY_GRID_DISTANCE" => FunctionDefinition.standard(2) + case "H3_TRY_GRID_PATH" => FunctionDefinition.standard(2) + case "H3_TRY_POLYGON_TO_CELLS" => FunctionDefinition.standard(2) + case "H3_TRY_POLYGON_TO_CELLS_STRINGS" => FunctionDefinition.standard(2) + case "H3_UNCOMPACT_CELLS" => FunctionDefinition.standard(2) + case "H3_UNCOMPACT_CELLS_STRINGS" => FunctionDefinition.standard(2) + case "HASH" => FunctionDefinition.standard(1, Int.MaxValue) + case "HASH_AGG" => FunctionDefinition.standard(1, Int.MaxValue) + case "HAVERSINE" => FunctionDefinition.standard(4) + case "HEX_DECODE_BINARY" => FunctionDefinition.standard(1) + case "HEX_DECODE_STRING" => FunctionDefinition.standard(1) + case "HEX_ENCODE" => FunctionDefinition.standard(1, 2) + case "HLL" => FunctionDefinition.standard(1, Int.MaxValue) + case "HLL_ACCUMULATE" => FunctionDefinition.standard(1) + case "HLL_COMBINE" => FunctionDefinition.standard(1) + case "HLL_ESTIMATE" => FunctionDefinition.standard(1) + case "HLL_EXPORT" => FunctionDefinition.standard(1) + case "HLL_IMPORT" => FunctionDefinition.standard(1) + case "HOUR / MINUTE / SECOND" => FunctionDefinition.standard(0) + case "HOUR" => FunctionDefinition.standard(1) + case "IFF" => FunctionDefinition.standard(3) + case "IFNULL" => FunctionDefinition.standard(1, 2) + case "ILIKE ANY" => FunctionDefinition.standard(2, 3) + case "INFER_SCHEMA" => FunctionDefinition.standard(0) + case "INITCAP" => FunctionDefinition.standard(1, 2) + case "INSERT" => FunctionDefinition.standard(4) + case "INTEGRATION" => FunctionDefinition.standard(1) + case "INVOKER_ROLE" => FunctionDefinition.standard(0) + case "INVOKER_SHARE" => FunctionDefinition.standard(0) + case "IS [ NOT ] DISTINCT FROM" => FunctionDefinition.standard(0) + case "IS [ NOT ] NULL" => FunctionDefinition.standard(0) + case "ISNULL" => FunctionDefinition.standard(1) + case "IS_ARRAY" => FunctionDefinition.standard(1) + case "IS_BINARY" => FunctionDefinition.standard(1) + case "IS_BOOLEAN" => FunctionDefinition.standard(1) + case "IS_CHAR" => FunctionDefinition.standard(1) + case "IS_DATABASE_ROLE_IN_SESSION" => FunctionDefinition.standard(1) + case "IS_DATE" => FunctionDefinition.standard(1) + case "IS_DATE_VALUE" => FunctionDefinition.standard(1) + case "IS_DECIMAL" => FunctionDefinition.standard(1) + case "IS_DOUBLE" => FunctionDefinition.standard(1) + case "IS_GRANTED_TO_INVOKER_ROLE" => FunctionDefinition.standard(1) + case "IS_INSTANCE_ROLE_IN_SESSION" => FunctionDefinition.standard(1) + case "IS_INTEGER" => FunctionDefinition.standard(1) + case "IS_NULL_VALUE" => FunctionDefinition.standard(1) + case "IS_OBJECT" => FunctionDefinition.standard(1) + case "IS_REAL" => FunctionDefinition.standard(1) + case "IS_ROLE_IN_SESSION" => FunctionDefinition.standard(1) + case "IS_TIME" => FunctionDefinition.standard(1) + case "IS_TIMESTAMP_LTZ" => 
FunctionDefinition.standard(1) + case "IS_TIMESTAMP_NTZ" => FunctionDefinition.standard(1) + case "IS_TIMESTAMP_TZ" => FunctionDefinition.standard(1) + case "IS_VARCHAR" => FunctionDefinition.standard(1) + case "JAROWINKLER_SIMILARITY" => FunctionDefinition.standard(2) + case "JSON_EXTRACT_PATH_TEXT" => FunctionDefinition.standard(2) + case "KURTOSIS" => FunctionDefinition.standard(1) + case "LAG" => FunctionDefinition.standard(1, 3) + case "LAST_DAY" => FunctionDefinition.standard(1, 2) + case "LAST_QUERY_ID" => FunctionDefinition.standard(0, 1) + case "LAST_SUCCESSFUL_SCHEDULED_TIME" => FunctionDefinition.standard(0) + case "LAST_TRANSACTION" => FunctionDefinition.standard(0) + case "LAST_VALUE" => FunctionDefinition.standard(1) + case "LEAD" => FunctionDefinition.standard(1, 3) + case "LEAST" => FunctionDefinition.standard(1, Int.MaxValue) + case "LEAST_IGNORE_NULLS" => FunctionDefinition.standard(1) + case "LEFT" => FunctionDefinition.standard(2) + case "LEN" => FunctionDefinition.standard(1) + case "LENGTH" => FunctionDefinition.standard(1) + case "LIKE ALL" => FunctionDefinition.standard(2, 3) + case "LIKE ANY" => FunctionDefinition.standard(2, 3) + case "LISTAGG" => FunctionDefinition.standard(1, 2) + case "LN" => FunctionDefinition.standard(1) + case "LOCALTIME" => FunctionDefinition.standard(0) + case "LOCALTIMESTAMP" => FunctionDefinition.standard(0, 1) + case "LOG" => FunctionDefinition.standard(2) + case "LOGIN_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("TIME_RANGE_START", "TIME_RANGE_END", "RESULT_LIMIT")) + case "LOGIN_HISTORY_BY_USER" => + FunctionDefinition.symbolic(Set.empty, Set("USER_NAME", "TIME_RANGE_START", "TIME_RANGE_END", "RESULT_LIMIT")) + case "LOWER" => FunctionDefinition.standard(1) + case "LPAD" => FunctionDefinition.standard(2, 3) + case "LTRIM" => FunctionDefinition.standard(1, 2) + case "MAP_CAT" => FunctionDefinition.standard(2) + case "MAP_CONTAINS_KEY" => FunctionDefinition.standard(2) + case "MAP_DELETE" => FunctionDefinition.standard(2) + case "MAP_INSERT" => FunctionDefinition.standard(3, 4) + case "MAP_KEYS" => FunctionDefinition.standard(1) + case "MAP_PICK" => FunctionDefinition.standard(4) + case "MAP_SIZE" => FunctionDefinition.standard(1) + case "MATERIALIZED_VIEW_REFRESH_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "MATERIALIZED_VIEW_NAME")) + case "MAX" => FunctionDefinition.standard(1) + case "MAX_BY" => FunctionDefinition.standard(2, 3) + case "MD5" => FunctionDefinition.standard(1) + case "MD5_BINARY" => FunctionDefinition.standard(1) + case "MD5_HEX" => FunctionDefinition.standard(1) + case "MD5_NUMBER — " => FunctionDefinition.standard(1) + case "MD5_NUMBER_LOWER64" => FunctionDefinition.standard(1) + case "MD5_NUMBER_UPPER64" => FunctionDefinition.standard(1) + case "MEDIAN" => FunctionDefinition.standard(1) + case "MIN (system data metric function)" => FunctionDefinition.standard(1) + case "MIN" => FunctionDefinition.standard(1) + case "MINHASH" => FunctionDefinition.standard(2, Int.MaxValue) + case "MINHASH_COMBINE" => FunctionDefinition.standard(0) + case "MINUTE" => FunctionDefinition.standard(1) + case "MIN_BY" => FunctionDefinition.standard(2, 3) + case "MOD" => FunctionDefinition.standard(2) + case "MODE" => FunctionDefinition.standard(1) + case "MONTH" => FunctionDefinition.standard(1) + case "MONTHNAME" => FunctionDefinition.standard(1) + case "MONTHS_BETWEEN" => FunctionDefinition.standard(2) + case "NETWORK_RULE_REFERENCES" => FunctionDefinition.standard(2) + case 
"NEXT_DAY" => FunctionDefinition.standard(2) + case "NORMAL" => FunctionDefinition.standard(3) + case "NOTIFICATION_HISTORY" => FunctionDefinition.standard(4) + case "NTH_VALUE" => FunctionDefinition.standard(2) + case "NTILE" => FunctionDefinition.standard(1) + case "NULLIF" => FunctionDefinition.standard(2) + case "NULLIFZERO" => FunctionDefinition.standard(1) + case "NULL_COUNT (system data metric function)" => FunctionDefinition.standard(1) + case "NULL_PERCENT (system data metric function)" => FunctionDefinition.standard(1) + case "NVL" => FunctionDefinition.standard(2).withConversionStrategy(SynonymOf("IFNULL")) + case "NVL2" => FunctionDefinition.standard(3) + case "OBJECT_AGG" => FunctionDefinition.standard(2) + case "OBJECT_CONSTRUCT" => FunctionDefinition.standard(0, Int.MaxValue) + case "OBJECT_CONSTRUCT_KEEP_NULL" => FunctionDefinition.standard(1, Int.MaxValue) + case "OBJECT_DELETE" => FunctionDefinition.standard(3) + case "OBJECT_INSERT" => FunctionDefinition.standard(3, 4) + case "OBJECT_KEYS" => FunctionDefinition.standard(1) + case "OBJECT_PICK" => FunctionDefinition.standard(2, Int.MaxValue) + case "OCTET_LENGTH" => FunctionDefinition.standard(1) + case "PARSE_IP" => FunctionDefinition.standard(2, 3) + case "PARSE_JSON" => FunctionDefinition.standard(1) + case "PARSE_URL" => FunctionDefinition.standard(1, 2) + case "PARSE_XML" => FunctionDefinition.symbolic(Set("STR"), Set("DISABLE_AUTO_CONVERT")) + case "PERCENTILE_CONT" => FunctionDefinition.standard(1) + case "PERCENTILE_DISC" => FunctionDefinition.standard(1) + case "PERCENT_RANK" => FunctionDefinition.standard(0) + case "PI" => FunctionDefinition.standard(0) + case "PIPE_USAGE_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "PIPE_NAME")) + case "POLICY_CONTEXT" => FunctionDefinition.standard(7) + case "POLICY_REFERENCES" => + FunctionDefinition.symbolic(Set.empty, Set("POLICY_NAME", "POLICY_KIND", "REF_ENTITY_NAME", "REF_ENTITY_DOMAIN")) + case "POSITION" => FunctionDefinition.standard(2, 3) + case "POW" => FunctionDefinition.standard(2) + case "POWER" => FunctionDefinition.standard(2) + case "PREVIOUS_DAY" => FunctionDefinition.standard(2) + case "QUARTER" => FunctionDefinition.standard(1) + case "QUERY_ACCELERATION_HISTORY" => FunctionDefinition.standard(0) + case "QUERY_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("END_TIME_RANGE_START", "END_TIME_RANGE_END", "RESULT_LIMIT")) + case "QUERY_HISTORY_BY_SESSION" => + FunctionDefinition.symbolic( + Set.empty, + Set("SESSION_ID", "END_TIME_RANGE_START", "END_TIME_RANGE_END", "RESULT_LIMIT")) + case "QUERY_HISTORY_BY_USER" => + FunctionDefinition.symbolic( + Set.empty, + Set("USER_NAME", "END_TIME_RANGE_START", "END_TIME_RANGE_END", "RESULT_LIMIT")) + case "QUERY_HISTORY_BY_WAREHOUSE" => + FunctionDefinition.symbolic( + Set.empty, + Set("WAREHOUSE_NAME", "END_TIME_RANGE_START", "END_TIME_RANGE_END", "RESULT_LIMIT")) + case "RADIANS" => FunctionDefinition.standard(1) + case "RANDOM" => FunctionDefinition.standard(0, 1) + case "RANDSTR" => FunctionDefinition.standard(2) + case "RANK" => FunctionDefinition.standard(0) + case "RATIO_TO_REPORT" => FunctionDefinition.standard(1) + case "REGEXP_COUNT" => FunctionDefinition.standard(2, 4) + case "REGEXP_INSTR" => FunctionDefinition.standard(2, 7) + case "REGEXP_LIKE" => FunctionDefinition.standard(2, 3) + case "REGEXP_REPLACE" => FunctionDefinition.standard(2, 6) + case "REGEXP_SUBSTR" => FunctionDefinition.standard(2, 6) + case "REGEXP_SUBSTR_ALL" => 
FunctionDefinition.standard(2, 6) + case "REGR_AVGX" => FunctionDefinition.standard(2) + case "REGR_AVGY" => FunctionDefinition.standard(2) + case "REGR_COUNT" => FunctionDefinition.standard(2) + case "REGR_INTERCEPT" => FunctionDefinition.standard(2) + case "REGR_R2" => FunctionDefinition.standard(2) + case "REGR_SLOPE" => FunctionDefinition.standard(2) + case "REGR_SXX" => FunctionDefinition.standard(2) + case "REGR_SXY" => FunctionDefinition.standard(2) + case "REGR_SYY" => FunctionDefinition.standard(2) + case "REGR_VALX" => FunctionDefinition.standard(2) + case "REGR_VALY" => FunctionDefinition.standard(2) + case "REPEAT" => FunctionDefinition.standard(2) + case "REPLACE" => FunctionDefinition.standard(2, 3) + case "REPLICATION_GROUP_REFRESH_HISTORY" => FunctionDefinition.standard(1) + case "REPLICATION_GROUP_REFRESH_PROGRESS" => FunctionDefinition.standard(1) + case "REPLICATION_GROUP_REFRESH_PROGRESS_BY_JOB" => FunctionDefinition.standard(1) + case "REPLICATION_GROUP_USAGE_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "REPLICATION_GROUP_NAME")) + case "REPLICATION_USAGE_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "DATABASE_NAME")) + case "REST_EVENT_HISTORY" => FunctionDefinition.standard(1, 4) + case "RESULT_SCAN" => FunctionDefinition.standard(1) + case "REVERSE" => FunctionDefinition.standard(1) + case "RIGHT" => FunctionDefinition.standard(2) + case "RLIKE" => FunctionDefinition.standard(2, 3) + case "ROUND" => FunctionDefinition.standard(1, 3) + case "ROW_COUNT (system data metric function)" => FunctionDefinition.standard(0) + case "ROW_NUMBER" => FunctionDefinition.standard(0) + case "RPAD" => FunctionDefinition.standard(2, 3) + case "RTRIM" => FunctionDefinition.standard(1, 2) + case "RTRIMMED_LENGTH" => FunctionDefinition.standard(1) + case "SCHEDULED_TIME" => FunctionDefinition.standard(0) + case "SEARCH_OPTIMIZATION_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "TABLE_NAME")) + case "SECOND" => FunctionDefinition.standard(1) + case "SENTIMENT (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(1) + case "SEQ1" => FunctionDefinition.standard(0, 1) + case "SEQ2" => FunctionDefinition.standard(0, 1) + case "SEQ4" => FunctionDefinition.standard(0, 1) + case "SEQ8" => FunctionDefinition.standard(0, 1) + case "SERVERLESS_TASK_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "TASK_NAME")) + case "SHA1" => FunctionDefinition.standard(1) + case "SHA1_BINARY" => FunctionDefinition.standard(1) + case "SHA1_HEX" => FunctionDefinition.standard(1) + case "SHA2" => FunctionDefinition.standard(1, 2) + case "SHA2_BINARY" => FunctionDefinition.standard(1, 2) + case "SHA2_HEX" => FunctionDefinition.standard(1, 2) + case "SHOW_PYTHON_PACKAGES_DEPENDENCIES" => FunctionDefinition.standard(2) + case "SIGN" => FunctionDefinition.standard(1) + case "SIN" => FunctionDefinition.standard(1) + case "SINH" => FunctionDefinition.standard(1) + case "SKEW" => FunctionDefinition.standard(1) + case "SOUNDEX" => FunctionDefinition.standard(1) + case "SOUNDEX_P123" => FunctionDefinition.standard(1) + case "SPACE" => FunctionDefinition.standard(1) + case "SPLIT" => FunctionDefinition.standard(2) + case "SPLIT_PART" => FunctionDefinition.standard(3) + case "SPLIT_TO_TABLE" => FunctionDefinition.standard(2) + case "SQRT" => FunctionDefinition.standard(1) + case "SQUARE" => FunctionDefinition.standard(1) + case 
"STAGE_DIRECTORY_FILE_REGISTRATION_HISTORY" => FunctionDefinition.standard(1, 2) + case "STAGE_STORAGE_USAGE_HISTORY" => FunctionDefinition.standard(2) + case "STARTSWITH" => FunctionDefinition.standard(2) + case "STDDEV (system data metric function)" => FunctionDefinition.standard(1) + case "STDDEV" => FunctionDefinition.standard(1) + case "STDDEV, STDDEV_SAMP" => FunctionDefinition.standard(3) + case "STDDEV_POP" => FunctionDefinition.standard(1) + case "STDDEV_SAMP" => FunctionDefinition.standard(1) + case "STRIP_NULL_VALUE" => FunctionDefinition.standard(1) + case "STRTOK" => FunctionDefinition.standard(1, 3) + case "STRTOK_SPLIT_TO_TABLE" => FunctionDefinition.standard(1, 2) + case "STRTOK_TO_ARRAY" => FunctionDefinition.standard(1, 2) + case "ST_AREA" => FunctionDefinition.standard(1) + case "ST_ASBINARY" => FunctionDefinition.standard(1) + case "ST_ASEWKB" => FunctionDefinition.standard(1) + case "ST_ASEWKT" => FunctionDefinition.standard(1) + case "ST_ASGEOJSON" => FunctionDefinition.standard(1) + case "ST_ASTEXT" => FunctionDefinition.standard(1) + case "ST_ASWKB" => FunctionDefinition.standard(1) + case "ST_ASWKT" => FunctionDefinition.standard(1) + case "ST_AZIMUTH" => FunctionDefinition.standard(2) + case "ST_BUFFER" => FunctionDefinition.standard(2) + case "ST_CENTROID" => FunctionDefinition.standard(1) + case "ST_COLLECT" => FunctionDefinition.standard(1, 2) + case "ST_CONTAINS" => FunctionDefinition.standard(2) + case "ST_COVEREDBY" => FunctionDefinition.standard(2) + case "ST_COVERS" => FunctionDefinition.standard(2) + case "ST_DIFFERENCE" => FunctionDefinition.standard(2) + case "ST_DIMENSION" => FunctionDefinition.standard(1) + case "ST_DISJOINT" => FunctionDefinition.standard(2) + case "ST_DISTANCE" => FunctionDefinition.standard(2) + case "ST_DWITHIN" => FunctionDefinition.standard(3) + case "ST_ENDPOINT" => FunctionDefinition.standard(1) + case "ST_ENVELOPE" => FunctionDefinition.standard(1) + case "ST_GEOGFROMGEOHASH" => FunctionDefinition.standard(1, 2) + case "ST_GEOGPOINTFROMGEOHASH" => FunctionDefinition.standard(1) + case "ST_GEOGRAPHYFROMWKB" => FunctionDefinition.standard(1, 2) + case "ST_GEOGRAPHYFROMWKT" => FunctionDefinition.standard(1, 2) + case "ST_GEOHASH" => FunctionDefinition.standard(1, 2) + case "ST_GEOMETRYFROMWKB" => FunctionDefinition.standard(1, 3) + case "ST_GEOMETRYFROMWKT" => FunctionDefinition.standard(1, 3) + case "ST_GEOMFROMGEOHASH" => FunctionDefinition.standard(1, 2) + case "ST_GEOMPOINTFROMGEOHASH" => FunctionDefinition.standard(1) + case "ST_GEOM_POINT" => FunctionDefinition.standard(2) + case "ST_HAUSDORFFDISTANCE" => FunctionDefinition.standard(2) + case "ST_INTERSECTION" => FunctionDefinition.standard(2) + case "ST_INTERSECTION_AGG" => FunctionDefinition.standard(1) + case "ST_INTERSECTS" => FunctionDefinition.standard(2) + case "ST_ISVALID" => FunctionDefinition.standard(1) + case "ST_LENGTH" => FunctionDefinition.standard(1) + case "ST_MAKEGEOMPOINT" => FunctionDefinition.standard(2) + case "ST_MAKELINE" => FunctionDefinition.standard(2) + case "ST_MAKEPOINT" => FunctionDefinition.standard(2) + case "ST_MAKEPOLYGON" => FunctionDefinition.standard(1) + case "ST_MAKEPOLYGONORIENTED" => FunctionDefinition.standard(1) + case "ST_NPOINTS" => FunctionDefinition.standard(1) + case "ST_NUMPOINTS" => FunctionDefinition.standard(1) + case "ST_PERIMETER" => FunctionDefinition.standard(1) + case "ST_POINT" => FunctionDefinition.standard(2) + case "ST_POINTN" => FunctionDefinition.standard(2) + case "ST_POLYGON" => 
FunctionDefinition.standard(1) + case "ST_SETSRID" => FunctionDefinition.standard(2) + case "ST_SIMPLIFY" => FunctionDefinition.standard(2, 3) + case "ST_SRID" => FunctionDefinition.standard(1) + case "ST_STARTPOINT" => FunctionDefinition.standard(1) + case "ST_SYMDIFFERENCE" => FunctionDefinition.standard(2) + case "ST_TRANSFORM" => FunctionDefinition.standard(2, 3) + case "ST_UNION" => FunctionDefinition.standard(2) + case "ST_UNION_AGG" => FunctionDefinition.standard(1) + case "ST_WITHIN" => FunctionDefinition.standard(2) + case "ST_X" => FunctionDefinition.standard(1) + case "ST_XMAX" => FunctionDefinition.standard(1) + case "ST_XMIN" => FunctionDefinition.standard(1) + case "ST_Y" => FunctionDefinition.standard(1) + case "ST_YMAX" => FunctionDefinition.standard(1) + case "ST_YMIN" => FunctionDefinition.standard(1) + case "SUM" => FunctionDefinition.standard(1) + case "SUMMARIZE (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(1) + case "SYSDATE" => FunctionDefinition.standard(0) + case "SYSTEM$ABORT_SESSION" => FunctionDefinition.standard(1) + case "SYSTEM$ABORT_TRANSACTION" => FunctionDefinition.standard(1) + case "SYSTEM$ADD_EVENT (for Snowflake Scripting)" => FunctionDefinition.standard(2) + case "SYSTEM$ALLOWLIST" => FunctionDefinition.standard(0) + case "SYSTEM$ALLOWLIST_PRIVATELINK" => FunctionDefinition.standard(0) + case "SYSTEM$AUTHORIZE_PRIVATELINK" => FunctionDefinition.standard(3) + case "SYSTEM$AUTHORIZE_STAGE_PRIVATELINK_ACCESS" => FunctionDefinition.standard(1) + case "SYSTEM$BEHAVIOR_CHANGE_BUNDLE_STATUS" => FunctionDefinition.standard(1) + case "SYSTEM$BLOCK_INTERNAL_STAGES_PUBLIC_ACCESS" => FunctionDefinition.standard(0) + case "SYSTEM$CANCEL_ALL_QUERIES" => FunctionDefinition.standard(1) + case "SYSTEM$CANCEL_QUERY" => FunctionDefinition.standard(1) + case "SYSTEM$CLEANUP_DATABASE_ROLE_GRANTS" => FunctionDefinition.standard(2) + case "SYSTEM$CLIENT_VERSION_INFO" => FunctionDefinition.standard(0) + case "SYSTEM$CLUSTERING_DEPTH" => FunctionDefinition.standard(4) + case "SYSTEM$CLUSTERING_INFORMATION" => FunctionDefinition.standard(4) + case "SYSTEM$CLUSTERING_RATIO — " => FunctionDefinition.standard(4) + case "SYSTEM$CONVERT_PIPES_SQS_TO_SNS" => FunctionDefinition.standard(2) + case "SYSTEM$CREATE_BILLING_EVENT" => FunctionDefinition.standard(3, 7) + case "SYSTEM$CURRENT_USER_TASK_NAME" => FunctionDefinition.standard(0) + case "SYSTEM$DATABASE_REFRESH_HISTORY — " => FunctionDefinition.standard(1) + case "SYSTEM$DATABASE_REFRESH_PROGRESS , SYSTEM$DATABASE_REFRESH_PROGRESS_BY_JOB — " => + FunctionDefinition.standard(2) + case "SYSTEM$DISABLE_BEHAVIOR_CHANGE_BUNDLE" => FunctionDefinition.standard(1) + case "SYSTEM$DISABLE_DATABASE_REPLICATION" => FunctionDefinition.standard(1) + case "SYSTEM$ENABLE_BEHAVIOR_CHANGE_BUNDLE" => FunctionDefinition.standard(1) + case "SYSTEM$ESTIMATE_AUTOMATIC_CLUSTERING_COSTS" => FunctionDefinition.standard(3) + case "SYSTEM$ESTIMATE_QUERY_ACCELERATION" => FunctionDefinition.standard(0) + case "SYSTEM$ESTIMATE_SEARCH_OPTIMIZATION_COSTS" => FunctionDefinition.standard(1, 2) + case "SYSTEM$EXPLAIN_JSON_TO_TEXT" => FunctionDefinition.standard(1) + case "SYSTEM$EXPLAIN_PLAN_JSON" => FunctionDefinition.standard(2) + case "SYSTEM$EXTERNAL_TABLE_PIPE_STATUS" => FunctionDefinition.standard(1) + case "SYSTEM$FINISH_OAUTH_FLOW" => FunctionDefinition.standard(1) + case "SYSTEM$GENERATE_SAML_CSR" => FunctionDefinition.standard(2) + case "SYSTEM$GENERATE_SCIM_ACCESS_TOKEN" => FunctionDefinition.standard(0) + case "SYSTEM$GET_AWS_SNS_IAM_POLICY" => 
FunctionDefinition.standard(1) + case "SYSTEM$GET_CLASSIFICATION_RESULT" => FunctionDefinition.standard(1) + case "SYSTEM$GET_CMK_AKV_CONSENT_URL" => FunctionDefinition.standard(2) + case "SYSTEM$GET_CMK_CONFIG" => FunctionDefinition.standard(1) + case "SYSTEM$GET_CMK_INFO" => FunctionDefinition.standard(0) + case "SYSTEM$GET_CMK_KMS_KEY_POLICY" => FunctionDefinition.standard(0) + case "SYSTEM$GET_COMPUTE_POOL_STATUS" => FunctionDefinition.standard(1) + case "SYSTEM$GET_DIRECTORY_TABLE_STATUS" => FunctionDefinition.standard(0, 1) + case "SYSTEM$GET_GCP_KMS_CMK_GRANT_ACCESS_CMD" => FunctionDefinition.standard(0) + case "SYSTEM$GET_ICEBERG_TABLE_INFORMATION" => FunctionDefinition.standard(1) + case "SYSTEM$GET_LOGIN_FAILURE_DETAILS" => FunctionDefinition.standard(1) + case "SYSTEM$GET_PREDECESSOR_RETURN_VALUE" => FunctionDefinition.standard(1) + case "SYSTEM$GET_PRIVATELINK" => FunctionDefinition.standard(3) + case "SYSTEM$GET_PRIVATELINK_AUTHORIZED_ENDPOINTS" => FunctionDefinition.standard(0) + case "SYSTEM$GET_PRIVATELINK_CONFIG" => FunctionDefinition.standard(0) + case "SYSTEM$GET_SERVICE_LOGS" => FunctionDefinition.standard(3, 4) + case "SYSTEM$GET_SERVICE_STATUS" => FunctionDefinition.standard(1, 2) + case "SYSTEM$GET_SNOWFLAKE_PLATFORM_INFO" => FunctionDefinition.standard(0) + case "SYSTEM$GET_TAG" => FunctionDefinition.standard(3) + case "SYSTEM$GET_TAG_ALLOWED_VALUES" => FunctionDefinition.standard(1) + case "SYSTEM$GET_TAG_ON_CURRENT_COLUMN" => FunctionDefinition.standard(1) + case "SYSTEM$GET_TAG_ON_CURRENT_TABLE" => FunctionDefinition.standard(1) + case "SYSTEM$GET_TASK_GRAPH_CONFIG" => FunctionDefinition.standard(1) + case "SYSTEM$GLOBAL_ACCOUNT_SET_PARAMETER" => FunctionDefinition.standard(0) + case "SYSTEM$INTERNAL_STAGES_PUBLIC_ACCESS_STATUS" => FunctionDefinition.standard(0) + case "SYSTEM$IS_APPLICATION_INSTALLED_FROM_SAME_ACCOUNT" => FunctionDefinition.standard(0) + case "SYSTEM$IS_APPLICATION_SHARING_EVENTS_WITH_PROVIDER" => FunctionDefinition.standard(0) + case "SYSTEM$LAST_CHANGE_COMMIT_TIME" => FunctionDefinition.standard(1) + case "SYSTEM$LINK_ACCOUNT_OBJECTS_BY_NAME" => FunctionDefinition.standard(1) + case "SYSTEM$LIST_APPLICATION_RESTRICTED_FEATURES" => FunctionDefinition.standard(0) + case "SYSTEM$LOG, SYSTEM$LOG_ (for Snowflake Scripting)" => FunctionDefinition.standard(2) + case "SYSTEM$MIGRATE_SAML_IDP_REGISTRATION" => FunctionDefinition.standard(2) + case "SYSTEM$PIPE_FORCE_RESUME" => FunctionDefinition.standard(1) + case "SYSTEM$PIPE_REBINDING_WITH_NOTIFICATION_CHANNEL" => FunctionDefinition.standard(1) + case "SYSTEM$PIPE_STATUS" => FunctionDefinition.standard(1) + case "SYSTEM$QUERY_REFERENCE" => FunctionDefinition.standard(1, 2) + case "SYSTEM$REFERENCE" => FunctionDefinition.standard(2, 4) + case "SYSTEM$REGISTER_CMK_INFO" => FunctionDefinition.standard(7) + case "SYSTEM$REGISTRY_LIST_IMAGES — " => FunctionDefinition.standard(3) + case "SYSTEM$REVOKE_PRIVATELINK" => FunctionDefinition.standard(3) + case "SYSTEM$REVOKE_STAGE_PRIVATELINK_ACCESS" => FunctionDefinition.standard(1) + case "SYSTEM$SET_APPLICATION_RESTRICTED_FEATURE_ACCESS" => FunctionDefinition.standard(0) + case "SYSTEM$SET_EVENT_SHARING_ACCOUNT_FOR_REGION" => FunctionDefinition.standard(3) + case "SYSTEM$SET_RETURN_VALUE" => FunctionDefinition.standard(1) + case "SYSTEM$SET_SPAN_ATTRIBUTES (for Snowflake Scripting)" => FunctionDefinition.standard(1) + case "SYSTEM$SHOW_ACTIVE_BEHAVIOR_CHANGE_BUNDLES" => FunctionDefinition.standard(0) + case "SYSTEM$SHOW_BUDGETS_IN_ACCOUNT" => 
FunctionDefinition.standard(0) + case "SYSTEM$SHOW_EVENT_SHARING_ACCOUNTS" => FunctionDefinition.standard(0) + case "SYSTEM$SHOW_OAUTH_CLIENT_SECRETS" => FunctionDefinition.standard(1) + case "SYSTEM$SNOWPIPE_STREAMING_UPDATE_CHANNEL_OFFSET_TOKEN" => FunctionDefinition.standard(5) + case "SYSTEM$START_OAUTH_FLOW" => FunctionDefinition.standard(1) + case "SYSTEM$STREAM_BACKLOG" => FunctionDefinition.standard(1) + case "SYSTEM$STREAM_GET_TABLE_TIMESTAMP" => FunctionDefinition.standard(1) + case "SYSTEM$STREAM_HAS_DATA" => FunctionDefinition.standard(1) + case "SYSTEM$TASK_DEPENDENTS_ENABLE" => FunctionDefinition.standard(1) + case "SYSTEM$TASK_RUNTIME_INFO" => FunctionDefinition.standard(1) + case "SYSTEM$TYPEOF" => FunctionDefinition.standard(1) + case "SYSTEM$UNBLOCK_INTERNAL_STAGES_PUBLIC_ACCESS" => FunctionDefinition.standard(0) + case "SYSTEM$UNSET_EVENT_SHARING_ACCOUNT_FOR_REGION" => FunctionDefinition.standard(3) + case "SYSTEM$USER_TASK_CANCEL_ONGOING_EXECUTIONS" => FunctionDefinition.standard(1) + case "SYSTEM$VALIDATE_STORAGE_INTEGRATION" => FunctionDefinition.standard(4) + case "SYSTEM$VERIFY_CMK_INFO" => FunctionDefinition.standard(0) + case "SYSTEM$VERIFY_EXTERNAL_OAUTH_TOKEN" => FunctionDefinition.standard(1) + case "SYSTEM$WAIT" => FunctionDefinition.standard(1, 2) + case "SYSTEM$WHITELIST — " => FunctionDefinition.standard(0) + case "SYSTEM$WHITELIST_PRIVATELINK — " => FunctionDefinition.standard(0) + case "SYSTIMESTAMP" => FunctionDefinition.standard(0) + case "TAG_REFERENCES" => FunctionDefinition.standard(2) + case "TAG_REFERENCES_ALL_COLUMNS" => FunctionDefinition.standard(2) + case "TAG_REFERENCES_WITH_LINEAGE" => FunctionDefinition.standard(1) + case "TAN" => FunctionDefinition.standard(1) + case "TANH" => FunctionDefinition.standard(1) + case "TASK_DEPENDENTS" => FunctionDefinition.standard(2) + case "TASK_HISTORY" => + FunctionDefinition.symbolic( + Set.empty, + Set( + "SCHEDULED_TIME_RANGE_START", + "SCHEDULED_TIME_RANGE_END", + "RESULT_LIMIT", + "TASK_NAME", + "ERROR_ONLY", + "ROOT_TASK_ID")) + case "TEXT_HTML" => FunctionDefinition.standard(1) + case "TEXT_PLAIN" => FunctionDefinition.standard(1) + case "TIME" => FunctionDefinition.standard(1, 2).withConversionStrategy(SynonymOf("TO_TIME")) + case "TIMEADD" => FunctionDefinition.standard(3).withConversionStrategy(SynonymOf("DATEADD")) + case "TIMEDIFF" => FunctionDefinition.standard(3) + case "TIMESTAMPADD" => FunctionDefinition.standard(3) + case "TIMESTAMPDIFF" => FunctionDefinition.standard(3).withConversionStrategy(SynonymOf("DATEDIFF")) + case "TIMESTAMP_FROM_PARTS" => FunctionDefinition.standard(2, 8) + case "TIMESTAMP_LTZ_FROM_PARTS" => FunctionDefinition.standard(6, 7) + case "TIMESTAMP_NTZ_FROM_PARTS" => FunctionDefinition.standard(2, 7) + case "TIMESTAMP_TZ_FROM_PARTS" => FunctionDefinition.standard(6, 7) + case "TIME_FROM_PARTS" => FunctionDefinition.standard(3, 4) + case "TIME_SLICE" => FunctionDefinition.standard(3, 4) + case "TOP_INSIGHTS (SNOWFLAKE.ML)" => FunctionDefinition.standard(4) + case "TO_ARRAY" => FunctionDefinition.standard(1, 2) + case "TO_BINARY" => FunctionDefinition.standard(1, 2) + case "TO_BOOLEAN" => FunctionDefinition.standard(1) + case "TO_CHAR , TO_VARCHAR" => FunctionDefinition.standard(4, 5) + case "TO_CHAR" => FunctionDefinition.standard(1, 2).withConversionStrategy(SynonymOf("TO_VARCHAR")) + case "TO_DATE , DATE" => FunctionDefinition.standard(4, 5) + case "TO_DATE" => FunctionDefinition.standard(1, 2) + case "TO_DECIMAL" => FunctionDefinition.standard(1, 
4).withConversionStrategy(SynonymOf("TO_NUMBER")) + case "TO_DOUBLE" => FunctionDefinition.standard(1, 2) + case "TO_GEOGRAPHY" => FunctionDefinition.standard(1, 2) + case "TO_GEOMETRY" => FunctionDefinition.standard(1, 3) + case "TO_JSON" => FunctionDefinition.standard(1) + case "TO_NUMBER" => FunctionDefinition.standard(1, 4) + case "TO_NUMERIC" => FunctionDefinition.standard(1, 4).withConversionStrategy(SynonymOf("TO_NUMBER")) + case "TO_OBJECT" => FunctionDefinition.standard(1) + case "TO_QUERY" => FunctionDefinition.standard(1, 3) + case "TO_TIME , TIME" => FunctionDefinition.standard(4, 5) + case "TO_TIME" => FunctionDefinition.standard(1, 2) + case "TO_TIMESTAMP" => FunctionDefinition.standard(1, 2) + case "TO_TIMESTAMP_LTZ" => FunctionDefinition.standard(1, 2) + case "TO_TIMESTAMP_NTZ" => FunctionDefinition.standard(1, 2) + case "TO_TIMESTAMP_TZ" => FunctionDefinition.standard(1, 2) + case "TO_VARCHAR" => FunctionDefinition.standard(1, 2) + case "TO_VARIANT" => FunctionDefinition.standard(1) + case "TO_XML" => FunctionDefinition.standard(1) + case "TRANSFORM" => FunctionDefinition.standard(2) + case "TRANSLATE" => FunctionDefinition.standard(3) + case "TRIM" => FunctionDefinition.standard(1, 2) + case "TRUNC" => FunctionDefinition.standard(2) + case "TRUNCATE" => FunctionDefinition.standard(2) + case "TRY_BASE64_DECODE_BINARY" => FunctionDefinition.standard(1, 2) + case "TRY_BASE64_DECODE_STRING" => FunctionDefinition.standard(1, 2) + case "TRY_CAST" => FunctionDefinition.standard(0) + case "TRY_COMPLETE (SNOWFLAKE.CORTEX)" => FunctionDefinition.standard(2, 3) + case "TRY_DECRYPT" => FunctionDefinition.standard(2, 4) + case "TRY_DECRYPT_RAW" => FunctionDefinition.standard(3, 6) + case "TRY_HEX_DECODE_BINARY" => FunctionDefinition.standard(1) + case "TRY_HEX_DECODE_STRING" => FunctionDefinition.standard(1) + case "TRY_PARSE_JSON" => FunctionDefinition.standard(1) + case "TRY_TO_BINARY" => FunctionDefinition.standard(1, 2) + case "TRY_TO_BOOLEAN" => FunctionDefinition.standard(1) + case "TRY_TO_DATE" => FunctionDefinition.standard(1, 2) + case "TRY_TO_DECIMAL" => FunctionDefinition.standard(1, 4).withConversionStrategy(SynonymOf("TRY_TO_NUMBER")) + case "TRY_TO_DOUBLE" => FunctionDefinition.standard(1, 2) + case "TRY_TO_GEOGRAPHY" => FunctionDefinition.standard(1, 2) + case "TRY_TO_GEOMETRY" => FunctionDefinition.standard(1, 3) + case "TRY_TO_NUMBER" => FunctionDefinition.standard(1, 4) + case "TRY_TO_NUMERIC" => FunctionDefinition.standard(1, 4).withConversionStrategy(SynonymOf("TRY_TO_NUMBER")) + case "TRY_TO_TIME" => FunctionDefinition.standard(1, 2) + case "TRY_TO_TIMESTAMP / TRY_TO_TIMESTAMP_*" => FunctionDefinition.standard(2, 3) + case "TRY_TO_TIMESTAMP" => FunctionDefinition.standard(1, 2) + case "TRY_TO_TIMESTAMP_LTZ" => FunctionDefinition.standard(1, 2) + case "TRY_TO_TIMESTAMP_NTZ" => FunctionDefinition.standard(1, 2) + case "TRY_TO_TIMESTAMP_TZ" => FunctionDefinition.standard(1, 2) + case "TYPEOF" => FunctionDefinition.standard(1) + case "UNICODE" => FunctionDefinition.standard(1) + case "UNIFORM" => FunctionDefinition.standard(3) + case "UNIQUE_COUNT (system data metric function)" => FunctionDefinition.standard(1) + case "UPPER" => FunctionDefinition.standard(1) + case "UUID_STRING" => FunctionDefinition.standard(0, 2) + case "VALIDATE" => FunctionDefinition.standard(3) + case "VALIDATE_PIPE_LOAD" => FunctionDefinition.standard(2, 3) + case "VARIANCE" => FunctionDefinition.standard(1) + case "VARIANCE_POP" => FunctionDefinition.standard(1) + case "VARIANCE_SAMP" => 
FunctionDefinition.standard(1) + case "VAR_POP" => FunctionDefinition.standard(1) + case "VAR_SAMP" => FunctionDefinition.standard(1) + case "VECTOR_COSINE_SIMILARITY" => FunctionDefinition.standard(2) + case "VECTOR_INNER_PRODUCT" => FunctionDefinition.standard(2) + case "VECTOR_L2_DISTANCE" => FunctionDefinition.standard(2) + case "WAREHOUSE_LOAD_HISTORY" => + FunctionDefinition.symbolic(Set.empty, Set("DATE_RANGE_START", "DATE_RANGE_END", "WAREHOUSE_NAME")) + case "WAREHOUSE_METERING_HISTORY" => FunctionDefinition.standard(1, 3) + case "WEEK" => FunctionDefinition.standard(1) + case "WEEKISO" => FunctionDefinition.standard(1) + case "WEEKOFYEAR" => FunctionDefinition.standard(1) + case "WIDTH_BUCKET" => FunctionDefinition.standard(4) + case "XMLGET" => FunctionDefinition.standard(2, 3) + case "YEAR" => FunctionDefinition.standard(1) + case "YEAROFWEEK" => FunctionDefinition.standard(1) + case "YEAROFWEEKISO" => FunctionDefinition.standard(1) + case "ZEROIFNULL" => FunctionDefinition.standard(1) + case "ZIPF" => FunctionDefinition.standard(3) + case "[ NOT ] BETWEEN" => FunctionDefinition.standard(0) + case "[ NOT ] ILIKE" => FunctionDefinition.standard(2, 3) + case "[ NOT ] IN" => FunctionDefinition.standard(0) + case "[ NOT ] LIKE" => FunctionDefinition.standard(2, 3) + case "[ NOT ] REGEXP" => FunctionDefinition.standard(2) + case "[ NOT ] RLIKE" => FunctionDefinition.standard(2, 3) + + } + + override def functionDefinition(name: String): Option[FunctionDefinition] = + // If not found, check common functions + SnowflakeFunctionDefinitionPf.orElse(commonFunctionsPf).lift(name.toUpperCase()) + + def applyConversionStrategy( + functionArity: FunctionDefinition, + args: Seq[ir.Expression], + irName: String): ir.Expression = { + functionArity.conversionStrategy match { + case Some(strategy) => strategy.convert(irName, args) + case _ => ir.CallFunction(irName, args) + } + } +} + +object SnowflakeFunctionConverters { + + case class SynonymOf(canonicalName: String) extends ConversionStrategy { + override def convert(irName: String, args: Seq[ir.Expression]): ir.Expression = ir.CallFunction(canonicalName, args) + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakePlanParser.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakePlanParser.scala new file mode 100644 index 0000000000..82c4d50161 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakePlanParser.scala @@ -0,0 +1,35 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.parsers.snowflake.rules._ +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime.{CharStream, Lexer, ParserRuleContext, TokenStream} + +class SnowflakePlanParser extends PlanParser[SnowflakeParser] { + + private[this] val vc = new SnowflakeVisitorCoordinator(SnowflakeParser.VOCABULARY, SnowflakeParser.ruleNames) + + override protected def createLexer(input: CharStream): Lexer = new SnowflakeLexer(input) + override protected def createParser(stream: TokenStream): SnowflakeParser = new SnowflakeParser(stream) + override protected def createTree(parser: SnowflakeParser): ParserRuleContext = parser.snowflakeFile() + override protected def createPlan(tree: ParserRuleContext): ir.LogicalPlan = vc.astBuilder.visit(tree) + override protected def addErrorStrategy(parser: SnowflakeParser): Unit = + parser.setErrorHandler(new 
SnowflakeErrorStrategy) + def dialect: String = "snowflake" + + // TODO: Note that this is not the correct place for the optimizer, but it is here for now + override protected def createOptimizer: ir.Rules[ir.LogicalPlan] = { + ir.Rules( + new DealiasLCAs, + new ConvertFractionalSecond, + new FlattenLateralViewToExplode(), + new SnowflakeCallMapper, + ir.AlwaysUpperNameForCallFunction, + new UpdateToMerge, + new CastParseJsonToFromJson, + new TranslateWithinGroup, + new FlattenNestedConcat, + new CompactJsonAccess, + new DealiasInlineColumnExpressions) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala new file mode 100644 index 0000000000..983af966a1 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala @@ -0,0 +1,374 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser._ +import com.databricks.labs.remorph.parsers.snowflake.rules.InlineColumnExpression +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime.ParserRuleContext + +import scala.collection.JavaConverters._ + +class SnowflakeRelationBuilder(override val vc: SnowflakeVisitorCoordinator) + extends SnowflakeParserBaseVisitor[ir.LogicalPlan] + with ParserCommon[ir.LogicalPlan] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.LogicalPlan = + ir.UnresolvedRelation(ruleText = ruleText, message = message) + + // Concrete visitors + + override def visitSelectStatement(ctx: SelectStatementContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // Note that the optional select clauses may be null on very simple statements + // such as SELECT 1; + val select = Option(ctx.selectOptionalClauses()).map(_.accept(this)).getOrElse(ir.NoTable) + val relation = buildLimitOffset(ctx.limitClause(), select) + val (top, allOrDistinct, selectListElements) = ctx match { + case c if ctx.selectClause() != null => + ( + None, + c.selectClause().selectListNoTop().allDistinct(), + c.selectClause().selectListNoTop().selectList().selectListElem().asScala) + case c if ctx.selectTopClause() != null => + ( + Option(c.selectTopClause().selectListTop().topClause()), + c.selectTopClause().selectListTop().allDistinct(), + c.selectTopClause().selectListTop().selectList().selectListElem().asScala) + + // Note that we must cater for error recovery where neither clause is present + case _ => (None, null, Seq.empty) + } + val expressions = selectListElements.map(_.accept(vc.expressionBuilder)) + + if (Option(allOrDistinct).exists(_.DISTINCT() != null)) { + buildTop(top, buildDistinct(relation, expressions)) + } else { + ir.Project(buildTop(top, relation), expressions) + } + } + + private def buildLimitOffset(ctx: LimitClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = { + Option(ctx).fold(input) { c => + if (c.LIMIT() != null) { + val limit = ir.Limit(input, ctx.expr(0).accept(vc.expressionBuilder)) + if (c.OFFSET() != null) { + ir.Offset(limit, ctx.expr(1).accept(vc.expressionBuilder)) + } else { + limit + } + } else {
ir.Offset(input, ctx.expr(0).accept(vc.expressionBuilder)) + } + } + } + + private def buildDistinct(input: ir.LogicalPlan, projectExpressions: Seq[ir.Expression]): ir.LogicalPlan = { + val columnNames = projectExpressions.collect { + case ir.Id(i, _) => i + case ir.Column(_, c) => c + case ir.Alias(_, a) => a + } + ir.Deduplicate(input, projectExpressions, columnNames.isEmpty, within_watermark = false) + } + + private def buildTop(ctxOpt: Option[TopClauseContext], input: ir.LogicalPlan): ir.LogicalPlan = + ctxOpt.fold(input) { top => + ir.Limit(input, top.expr().accept(vc.expressionBuilder)) + } + + override def visitSelectOptionalClauses(ctx: SelectOptionalClausesContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val from = Option(ctx.fromClause()).map(_.accept(this)).getOrElse(ir.NoTable) + buildOrderBy( + ctx.orderByClause(), + buildQualify( + ctx.qualifyClause(), + buildHaving(ctx.havingClause(), buildGroupBy(ctx.groupByClause(), buildWhere(ctx.whereClause(), from))))) + } + + override def visitFromClause(ctx: FromClauseContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableSources = visitMany(ctx.tableSources().tableSource()) + // The tableSources seq cannot be empty (as empty FROM clauses are not allowed) + tableSources match { + case Seq(tableSource) => tableSource + case sources => + sources.reduce( + ir.Join(_, _, None, ir.InnerJoin, Seq(), ir.JoinDataType(is_left_struct = false, is_right_struct = false))) + } + } + + private def buildFilter[A](ctx: A, conditionRule: A => ParserRuleContext, input: ir.LogicalPlan): ir.LogicalPlan = + Option(ctx).fold(input) { c => + ir.Filter(input, conditionRule(c).accept(vc.expressionBuilder)) + } + private def buildHaving(ctx: HavingClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = + buildFilter[HavingClauseContext](ctx, _.searchCondition(), input) + + private def buildQualify(ctx: QualifyClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = + buildFilter[QualifyClauseContext](ctx, _.expr(), input) + private def buildWhere(ctx: WhereClauseContext, from: ir.LogicalPlan): ir.LogicalPlan = + buildFilter[WhereClauseContext](ctx, _.searchCondition(), from) + + private def buildGroupBy(ctx: GroupByClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = { + Option(ctx).fold(input) { c => + val groupingExpressions = + Option(c.groupByList()).toSeq + .flatMap(_.groupByElem().asScala) + .map(_.accept(vc.expressionBuilder)) + val groupType = if (c.ALL() != null) ir.GroupByAll else ir.GroupBy + val aggregate = + ir.Aggregate(child = input, group_type = groupType, grouping_expressions = groupingExpressions, pivot = None) + buildHaving(c.havingClause(), aggregate) + } + } + + private def buildOrderBy(ctx: OrderByClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = { + Option(ctx).fold(input) { c => + val sortOrders = c.orderItem().asScala.map(vc.expressionBuilder.visitOrderItem) + ir.Sort(input, sortOrders, is_global = false) + } + } + + override def visitObjRefTableFunc(ctx: ObjRefTableFuncContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableFunc = ir.TableFunction(ctx.functionCall().accept(vc.expressionBuilder)) + buildSubqueryAlias(ctx.tableAlias(), buildPivotOrUnpivot(ctx.pivotUnpivot(), tableFunc)) + } + + // @see https://docs.snowflake.com/en/sql-reference/functions/flatten + // @see https://docs.snowflake.com/en/sql-reference/functions-table + override def
visitObjRefSubquery(ctx: ObjRefSubqueryContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val relation = ctx match { + case c if c.subquery() != null => c.subquery().accept(this) + case c if c.functionCall() != null => ir.TableFunction(c.functionCall().accept(vc.expressionBuilder)) + } + val maybeLateral = if (ctx.LATERAL() != null) { + ir.Lateral(relation) + } else { + relation + } + buildSubqueryAlias(ctx.tableAlias(), buildPivotOrUnpivot(ctx.pivotUnpivot(), maybeLateral)) + } + + private def buildSubqueryAlias(ctx: TableAliasContext, input: ir.LogicalPlan): ir.LogicalPlan = { + Option(ctx) + .map(a => + ir.SubqueryAlias( + input, + vc.expressionBuilder.buildId(a.alias().id()), + a.id().asScala.map(vc.expressionBuilder.buildId))) + .getOrElse(input) + } + + override def visitValuesTable(ctx: ValuesTableContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.valuesTableBody().accept(this) + } + + override def visitValuesTableBody(ctx: ValuesTableBodyContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val expressions = + ctx + .exprListInParentheses() + .asScala + .map(l => vc.expressionBuilder.visitMany(l.exprList().expr())) + ir.Values(expressions) + } + + override def visitObjRefDefault(ctx: ObjRefDefaultContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildTableAlias(ctx.tableAlias(), buildPivotOrUnpivot(ctx.pivotUnpivot(), ctx.dotIdentifier().accept(this))) + } + + override def visitTableRef(ctx: TableRefContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val table = ctx.dotIdentifier().accept(this) + Option(ctx.asAlias()) + .map { a => + ir.TableAlias(table, a.alias().getText, Seq()) + } + .getOrElse(table) + } + + override def visitDotIdentifier(ctx: DotIdentifierContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableName = ctx.id().asScala.map(vc.expressionBuilder.buildId).map(_.id).mkString(".") + ir.NamedTable(tableName, Map.empty, is_streaming = false) + } + + override def visitObjRefValues(ctx: ObjRefValuesContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val values = ctx.valuesTable().accept(this) + buildTableAlias(ctx.tableAlias(), values) + } + + private def buildTableAlias(ctx: TableAliasContext, relation: ir.LogicalPlan): ir.LogicalPlan = { + Option(ctx) + .map { c => + val alias = c.alias().getText + val columns = Option(c.id()).map(_.asScala.map(vc.expressionBuilder.buildId)).getOrElse(Seq.empty) + ir.TableAlias(relation, alias, columns) + } + .getOrElse(relation) + } + + private def buildPivotOrUnpivot(ctx: PivotUnpivotContext, relation: ir.LogicalPlan): ir.LogicalPlan = { + if (ctx == null) { + relation + } else if (ctx.PIVOT() != null) { + buildPivot(ctx, relation) + } else { + buildUnpivot(ctx, relation) + } + } + + private def buildPivot(ctx: PivotUnpivotContext, relation: ir.LogicalPlan): ir.LogicalPlan = { + val pivotValues: Seq[ir.Literal] = + vc.expressionBuilder.visitMany(ctx.values).collect { case lit: ir.Literal => lit } + val argument = ir.Column(None, vc.expressionBuilder.buildId(ctx.pivotColumn)) + val column = ir.Column(None, vc.expressionBuilder.buildId(ctx.valueColumn)) + val aggFunc = vc.expressionBuilder.buildId(ctx.aggregateFunc) + val aggregateFunction 
= vc.functionBuilder.buildFunction(aggFunc, Seq(argument)) + ir.Aggregate( + child = relation, + group_type = ir.Pivot, + grouping_expressions = Seq(aggregateFunction), + pivot = Some(ir.Pivot(column, pivotValues))) + } + + private def buildUnpivot(ctx: PivotUnpivotContext, relation: ir.LogicalPlan): ir.LogicalPlan = { + val unpivotColumns = ctx + .columnList() + .columnName() + .asScala + .map(_.accept(vc.expressionBuilder)) + val variableColumnName = vc.expressionBuilder.buildId(ctx.valueColumn) + val valueColumnName = vc.expressionBuilder.buildId(ctx.nameColumn) + ir.Unpivot( + child = relation, + ids = unpivotColumns, + values = None, + variable_column_name = variableColumnName, + value_column_name = valueColumnName) + } + + override def visitTableSource(ctx: TableSourceContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableSource = ctx match { + case c if c.tableSourceItemJoined() != null => c.tableSourceItemJoined().accept(this) + case c if c.tableSource() != null => c.tableSource().accept(this) + } + buildSample(ctx.sample(), tableSource) + } + + override def visitTableSourceItemJoined(ctx: TableSourceItemJoinedContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.objectRef().accept(this) + ctx.joinClause().asScala.foldLeft(left)(buildJoin) + } + + private def buildJoin(left: ir.LogicalPlan, right: JoinClauseContext): ir.Join = { + val usingColumns = Option(right.columnList()).map(_.columnName().asScala.map(_.getText)).getOrElse(Seq()) + val joinType = if (right.NATURAL() != null) { + ir.NaturalJoin(translateOuterJoinType(right.outerJoin())) + } else if (right.CROSS() != null) { + ir.CrossJoin + } else { + translateJoinType(right.joinType()) + } + ir.Join( + left, + right.objectRef().accept(this), + Option(right.searchCondition()).map(_.accept(vc.expressionBuilder)), + joinType, + usingColumns, + ir.JoinDataType(is_left_struct = false, is_right_struct = false)) + + } + + private[snowflake] def translateJoinType(joinType: JoinTypeContext): ir.JoinType = { + Option(joinType) + .map { jt => + if (jt.INNER() != null) { + ir.InnerJoin + } else { + translateOuterJoinType(jt.outerJoin()) + } + } + .getOrElse(ir.UnspecifiedJoin) + } + + private def translateOuterJoinType(ctx: OuterJoinContext): ir.JoinType = { + Option(ctx) + .collect { + case c if c.LEFT() != null => ir.LeftOuterJoin + case c if c.RIGHT() != null => ir.RightOuterJoin + case c if c.FULL() != null => ir.FullOuterJoin + } + .getOrElse(ir.UnspecifiedJoin) + } + + override def visitCTETable(ctx: CTETableContext): ir.LogicalPlan = + errorCheck(ctx).getOrElse { + val tableName = vc.expressionBuilder.buildId(ctx.tableName) + val columns = ctx.columnList() match { + case null => Seq.empty[ir.Id] + case c => c.columnName().asScala.flatMap(_.id.asScala.map(vc.expressionBuilder.buildId)) + } + val queryExpression = ctx.queryExpression().accept(vc.astBuilder) + ir.SubqueryAlias(queryExpression, tableName, columns) + } + + override def visitCTEColumn(ctx: CTEColumnContext): ir.LogicalPlan = { + InlineColumnExpression(vc.expressionBuilder.buildId(ctx.id()), ctx.expr().accept(vc.expressionBuilder)) + } + + private def buildSampleMethod(ctx: SampleMethodContext): ir.SamplingMethod = ctx match { + case c: SampleMethodRowFixedContext => ir.RowSamplingFixedAmount(BigDecimal(c.INT().getText)) + case c: SampleMethodRowProbaContext => ir.RowSamplingProbabilistic(BigDecimal(c.INT().getText)) + case c: 
SampleMethodBlockContext => ir.BlockSampling(BigDecimal(c.INT().getText)) + } + + private def buildSample(ctx: SampleContext, input: ir.LogicalPlan): ir.LogicalPlan = { + Option(ctx) + .map { sampleCtx => + val seed = Option(sampleCtx.sampleSeed()).map(s => BigDecimal(s.INT().getText)) + val sampleMethod = buildSampleMethod(sampleCtx.sampleMethod()) + ir.TableSample(input, sampleMethod, seed) + } + .getOrElse(input) + } + + override def visitTableOrQuery(ctx: TableOrQueryContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.tableRef() != null => c.tableRef().accept(this) + case c if c.subquery() != null => + val subquery = c.subquery().accept(this) + Option(c.asAlias()) + .map { a => + ir.SubqueryAlias(subquery, vc.expressionBuilder.buildId(a.alias().id()), Seq()) + } + .getOrElse(subquery) + + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeTypeBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeTypeBuilder.scala new file mode 100644 index 0000000000..b67e696b79 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeTypeBuilder.scala @@ -0,0 +1,69 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.DataTypeContext +import com.databricks.labs.remorph.{intermediate => ir} + +import scala.collection.JavaConverters._ + +/** + * @see + * https://spark.apache.org/docs/latest/sql-ref-datatypes.html + * @see + * https://docs.snowflake.com/en/sql-reference-data-types + */ +class SnowflakeTypeBuilder { + // see https://docs.snowflake.com/en/sql-reference/data-types-numeric#int-integer-bigint-smallint-tinyint-byteint + private def defaultNumber = ir.DecimalType(Some(38), Some(0)) + + def buildDataType(ctx: DataTypeContext): ir.DataType = ctx match { + case null => ir.UnresolvedType + case _ if ctx.ARRAY() != null => ir.ArrayType(buildDataType(ctx.dataType())) + case _ if ctx.OBJECT() != null => ir.UnparsedType("OBJECT") // TODO: get more examples + case _ => + val typeDef = ctx.id.getText.toUpperCase() + typeDef match { + // precision types - losing precision except for decimal + case "CHAR" | "NCHAR" | "CHARACTER" => ir.StringType // character types + case "CHAR_VARYING" | "NCHAR_VARYING" | "NVARCHAR2" | "NVARCHAR" | "STRING" | "TEXT" | "VARCHAR" => + ir.StringType // VARCHAR is string type in Databricks + case "NUMBER" | "NUMERIC" | "DECIMAL" => decimal(ctx) + case "TIMESTAMP" | "TIMESTAMP_LTZ" | "TIMESTAMP_TZ" | "TIMESTAMPTZ" => ir.TimestampType + case "TIMESTAMP_NTZ" => ir.TimestampNTZType + + // non-precision types + case "BIGINT" => defaultNumber + case "BINARY" => ir.BinaryType + case "BOOLEAN" => ir.BooleanType + case "BYTEINT" => defaultNumber + case "DATE" => ir.DateType + case "DOUBLE" => ir.DoubleType + case "DOUBLE PRECISION" => ir.DoubleType + case "FLOAT" => ir.DoubleType + case "FLOAT4" => ir.DoubleType + case "FLOAT8" => ir.DoubleType + case "INT" => defaultNumber + case "INTEGER" => defaultNumber + case "REAL" => ir.DoubleType + case "SMALLINT" => defaultNumber + case "TIME" => ir.TimestampType + case "TINYINT" => ir.TinyintType + case "VARBINARY" => ir.BinaryType + case "VARIANT" => ir.VariantType + + // TODO: GEOGRAPHY is not yet catered for in Snowflake type builder + // TODO: GEOMETRY is not yet catered for in Snowflake type builder + + // and everything else must be an input error or a type we
don't know about yet + case _ => ir.UnparsedType(typeDef) + } + } + + private def decimal(c: DataTypeContext) = { + val nums = c.INT().asScala + // https://docs.snowflake.com/en/sql-reference/data-types-numeric#number + // Per Docs defaulting the precision to 38 and scale to 0 + val precision = nums.headOption.map(_.getText.toInt).getOrElse(38) + val scale = nums.drop(1).headOption.map(_.getText.toInt).getOrElse(0) + ir.DecimalType(precision, scale) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeVisitorCoordinator.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeVisitorCoordinator.scala new file mode 100644 index 0000000000..0a1f874b3d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeVisitorCoordinator.scala @@ -0,0 +1,19 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.VisitorCoordinator +import org.antlr.v4.runtime.Vocabulary + +class SnowflakeVisitorCoordinator(parserVocab: Vocabulary, ruleNames: Array[String]) + extends VisitorCoordinator(parserVocab, ruleNames) { + + val astBuilder = new SnowflakeAstBuilder(this) + val relationBuilder = new SnowflakeRelationBuilder(this) + val expressionBuilder = new SnowflakeExpressionBuilder(this) + val dmlBuilder = new SnowflakeDMLBuilder(this) + val ddlBuilder = new SnowflakeDDLBuilder(this) + val functionBuilder = new SnowflakeFunctionBuilder + + // Snowflake extensions + val commandBuilder = new SnowflakeCommandBuilder(this) + val typeBuilder = new SnowflakeTypeBuilder +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/expressions.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/expressions.scala new file mode 100644 index 0000000000..556e5afd5d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/expressions.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.{intermediate => ir} + +case class NamedArgumentExpression(key: String, value: ir.Expression) extends ir.Expression { + override def children: Seq[ir.Expression] = value :: Nil + override def dataType: ir.DataType = value.dataType +} + +case class NextValue(sequenceName: String) extends ir.LeafExpression { + override def dataType: ir.DataType = ir.LongType +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/CastParseJsonToFromJson.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/CastParseJsonToFromJson.scala new file mode 100644 index 0000000000..b75e01db7f --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/CastParseJsonToFromJson.scala @@ -0,0 +1,14 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.TranspilerState +import com.databricks.labs.remorph.generators.sql.DataTypeGenerator +import com.databricks.labs.remorph.intermediate._ + +class CastParseJsonToFromJson extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { case Cast(CallFunction("PARSE_JSON", Seq(payload)), dt, _, _, _) => + val dataType = DataTypeGenerator.generateDataType(dt).runAndDiscardState(TranspilerState()) + JsonToStructs(payload, Literal(dataType), None) + } + } +} diff --git 
a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/CompactJsonAccess.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/CompactJsonAccess.scala new file mode 100644 index 0000000000..d6b6d7f662 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/CompactJsonAccess.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate._ + +class CompactJsonAccess extends Rule[LogicalPlan] with IRHelpers { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case expression: Expression => + expression transform { + case JsonAccess(JsonAccess(l1, r1), JsonAccess(l2, r2)) => JsonAccess(l1, Dot(r1, Dot(l2, r2))) + case JsonAccess(JsonAccess(l1, r1), r2) => JsonAccess(l1, Dot(r1, r2)) + case JsonAccess(l1, JsonAccess(l2, r2)) => JsonAccess(l1, Dot(l2, r2)) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/ConvertFractionalSecond.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/ConvertFractionalSecond.scala new file mode 100644 index 0000000000..3d9f0a6e45 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/ConvertFractionalSecond.scala @@ -0,0 +1,63 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.{intermediate => ir} + +class ConvertFractionalSecond extends ir.Rule[ir.LogicalPlan] { + + // Please read the note here : https://docs.snowflake.com/en/sql-reference/functions/current_timestamp#arguments + // TODO Fractional seconds are only displayed if they have been explicitly + // set in the TIME_OUTPUT_FORMAT parameter for the session (e.g. 'HH24:MI:SS.FF'). 
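+  // For example, CURRENT_TIME(2) is rewritten to ir.DateFormatClass(ir.CurrentTimestamp(), ir.Literal("HH:mm:ss"))
+  // and CURRENT_TIMESTAMP(3) to ir.DateFormatClass(ir.CurrentTimestamp(), ir.Literal("yyyy-MM-dd HH:mm:ss.SSS")):
+  // every precision argument from 0 to 9 maps onto the same second-level pattern below, so the requested
+  // fractional-second precision is effectively dropped.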
+ + private[this] val timeMapping: Map[Int, String] = Map( + 0 -> "HH:mm:ss", + 1 -> "HH:mm:ss", + 2 -> "HH:mm:ss", + 3 -> "HH:mm:ss", + 4 -> "HH:mm:ss", + 5 -> "HH:mm:ss", + 6 -> "HH:mm:ss", + 7 -> "HH:mm:ss", + 8 -> "HH:mm:ss", + 9 -> "HH:mm:ss") + + override def apply(plan: ir.LogicalPlan): ir.LogicalPlan = { + plan transformAllExpressions { + case ir.CallFunction("CURRENT_TIME", right) => handleSpecialTSFunctions("CURRENT_TIME", right) + case ir.CallFunction("LOCALTIME", right) => handleSpecialTSFunctions("LOCALTIME", right) + case ir.CallFunction("CURRENT_TIMESTAMP", right) => + if (right.isEmpty) { + ir.CurrentTimestamp() + } else { + handleSpecialTSFunctions("CURRENT_TIMESTAMP", right) + } + case ir.CallFunction("LOCALTIMESTAMP", right) => + if (right.isEmpty) { + ir.CurrentTimestamp() + } else { + handleSpecialTSFunctions("LOCALTIMESTAMP", right) + } + } + } + + private def getIntegerValue(literal: Option[ir.Literal]): Option[Int] = literal match { + case Some(ir.Literal(value: Int, _)) => Some(value) + case _ => None + } + + private def handleSpecialTSFunctions(functionName: String, arguments: Seq[ir.Expression]): ir.Expression = { + val timeFormat = timeMapping(getIntegerValue(arguments.headOption.flatMap { + case lit: ir.Literal => Some(lit) + case _ => None + }).getOrElse(0)) + + // https://docs.snowflake.com/en/sql-reference/functions/current_timestamp + // https://docs.snowflake.com/en/sql-reference/functions/current_time + // https://docs.snowflake.com/en/sql-reference/functions/localtimestamp + val formatString = functionName match { + case "CURRENT_TIME" | "LOCALTIME" => timeFormat + case _ => s"yyyy-MM-dd $timeFormat.SSS" + } + ir.DateFormatClass(ir.CurrentTimestamp(), ir.Literal(formatString)) + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasInlineColumnExpressions.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasInlineColumnExpressions.scala new file mode 100644 index 0000000000..872481f40a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasInlineColumnExpressions.scala @@ -0,0 +1,41 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate._ + +// SNOWFLAKE - Common Table Expression may be of the form: +// WITH +// a AS (1), +// b AS (2) +// SELECT ... +// CTEs like `a AS (1)` above act as columns of an anonymous table with a single row. +// In AstBuilding phase, we translate those as InlineColumnExpression. Later, in Optimizing phase, we'll combine all +// such expressions in a single table declaration (using VALUES). 
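+// For example, given WITH a AS (1), b AS (2) SELECT a + b, the items `a AS (1)` and `b AS (2)` are
+// captured as InlineColumnExpression(Id(a), Literal(1)) and InlineColumnExpression(Id(b), Literal(2));
+// the rule below then substitutes those value expressions wherever a and b are referenced in the query body.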
+private[snowflake] case class InlineColumnExpression(columnName: Id, value: Expression) extends LogicalPlan { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty +} + +class DealiasInlineColumnExpressions extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transform { case WithCTE(ctes, query) => + bundleInlineColumns(ctes, query) + } + + private def bundleInlineColumns(plans: Seq[LogicalPlan], query: LogicalPlan): LogicalPlan = { + val (inlineColumns, tables) = plans.foldLeft((Seq.empty[InlineColumnExpression], Seq.empty[LogicalPlan])) { + case ((ics, tbls), i: InlineColumnExpression) => ((ics :+ i, tbls)) + case ((ics, tbls), t) => ((ics, tbls :+ t)) + } + val columnNamesToValues = inlineColumns.map(ic => ic.columnName -> ic.value).toMap + + val fixedUpReferences = query transformUp { case p => + p.transformExpressionsUp { + case Column(None, id: Id) if columnNamesToValues.contains(id) => columnNamesToValues(id) + case id: Id if columnNamesToValues.contains(id) => columnNamesToValues(id) + } + } + + WithCTE(tables, fixedUpReferences) + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasLCAs.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasLCAs.scala new file mode 100644 index 0000000000..505580668a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasLCAs.scala @@ -0,0 +1,53 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate.{Expression, _} + +class DealiasLCAs extends Rule[LogicalPlan] with IRHelpers { + + override def apply(plan: LogicalPlan): LogicalPlan = transformPlan(plan) + + private[rules] def transformPlan(plan: LogicalPlan): LogicalPlan = + plan transform { case project: Project => + dealiasProject(project) + } + + private def dealiasProject(project: Project): Project = { + // Go through the Project's select list, collecting aliases + // and dealias expressions using the aliases collected thus far + val (aliases, dealiasedExpressions) = + project.expressions.foldLeft((Map.empty[String, Expression], Seq.empty[Expression])) { + case ((aliases, exprs), a @ Alias(expr, name)) => + // LCA aren't supported in WINDOW clauses, so we must dealias them + val dw = dealiasWindow(expr, aliases) + val accumulatedExprs = exprs :+ CurrentOrigin.withOrigin(a.origin)(Alias(dw, name)) + // An aliased expression may refer to a previous LCA, so before storing the mapping, + // we must dealias the expression to ensure that mapped expressions are fully dealiased. 
+ val newFoundAlias = dealiasExpression(dw, aliases) + val updatedAliases = aliases + (name.id -> newFoundAlias) + (updatedAliases, accumulatedExprs) + case ((aliases, exprs), e) => (aliases, exprs :+ dealiasWindow(e, aliases)) + } + + val dealiasedInput = project.input transformDown { case f @ Filter(in, cond) => + CurrentOrigin.withOrigin(f.origin)(Filter(in, dealiasExpression(cond, aliases))) + } + + CurrentOrigin.withOrigin(project.origin)(Project(dealiasedInput, dealiasedExpressions)) + } + + private def dealiasWindow(expr: Expression, aliases: Map[String, Expression]): Expression = { + expr transformDown { case w: Window => + w.mapChildren(dealiasExpression(_, aliases)) + } + } + + private def dealiasExpression(expr: Expression, aliases: Map[String, Expression]): Expression = { + expr transformUp { + case id: Id => aliases.getOrElse(id.id, id) + case n: Name => aliases.getOrElse(n.name, n) + case e: Exists => CurrentOrigin.withOrigin(e.origin)(Exists(transformPlan(e.relation))) + case s: ScalarSubquery => CurrentOrigin.withOrigin(s.origin)(ScalarSubquery(transformPlan(s.plan))) + } + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenLateralViewToExplode.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenLateralViewToExplode.scala new file mode 100644 index 0000000000..a66615b6b7 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenLateralViewToExplode.scala @@ -0,0 +1,67 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.parsers.snowflake.NamedArgumentExpression + +// @see https://docs.snowflake.com/en/sql-reference/functions/flatten +class FlattenLateralViewToExplode extends Rule[LogicalPlan] with IRHelpers { + + private[this] val FLATTEN_OUTPUT_COLUMNS = Set("seq", "key", "path", "index", "value", "this") + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j: Join if isLateralFlatten(j.left) => j.copy(left = translatePosExplode(plan, j.left)) + case j: Join if isLateralFlatten(j.right) => j.copy(right = translatePosExplode(plan, j.right)) + case p if isLateralFlatten(p) => translatePosExplode(plan, p) + } + + private def isLateralFlatten(plan: LogicalPlan): Boolean = plan match { + case SubqueryAlias(Lateral(TableFunction(CallFunction("FLATTEN", _)), _, _), _, _) => true + case _ => false + } + + private def translatePosExplode(plan: LogicalPlan, lateralFlatten: LogicalPlan): LogicalPlan = { + val SubqueryAlias(Lateral(TableFunction(CallFunction(_, args)), _, _), id, colNames) = lateralFlatten + val named = args.collect { case NamedArgumentExpression(key, value) => + key.toUpperCase() -> value + }.toMap + + val exprs = plan.expressions + + // FLATTEN produces a table with several columns (that we materialize as FLATTEN_OUTPUT_COLUMNS). + // We retain only the columns that are actually referenced elsewhere in the query. 
+ val flattenOutputReferencedColumns = FLATTEN_OUTPUT_COLUMNS.filter { col => + exprs.exists(_.find { + case Dot(x, Id(c, false)) => x == id && c.equalsIgnoreCase(col) + case Column(Some(r), Id(c, false)) => r.head == id && c.equalsIgnoreCase(col) + case _ => false + }.isDefined) + } + + val input = named("INPUT") + val outer = getFlag(named, "OUTER") + + // If the `index` column of FLATTEN's output is referenced elsewhere in the query, we need to translate that + // call to FLATTEN to POSEXPLODE (so that we get the actual index of each produced row). + if (flattenOutputReferencedColumns.contains("index")) { + // TODO: What if we need the `index` of FLATTEN-ing something that translates to a VARIANT? + // VARIANT_EXPLODE outputs a POS column, we need to add a test case for that. + SubqueryAlias( + Lateral(TableFunction(PosExplode(input)), outer = outer, isView = true), + id, + flattenOutputReferencedColumns.toSeq.map(Id(_))) + } else { + val translated = input.dataType match { + case VariantType if outer => Lateral(TableFunction(VariantExplodeOuter(input)), outer = false) + case VariantType => Lateral(TableFunction(VariantExplode(input)), outer = false) + case _ => Lateral(TableFunction(Explode(input)), outer = outer) + } + SubqueryAlias(translated, id, colNames) + } + } + + private def getFlag(named: Map[String, Expression], flagName: String): Boolean = named.get(flagName) match { + case Some(BooleanLiteral(value)) => value + case _ => false + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenNestedConcat.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenNestedConcat.scala new file mode 100644 index 0000000000..12f988afee --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenNestedConcat.scala @@ -0,0 +1,23 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules +import com.databricks.labs.remorph.{intermediate => ir} + +/** + * Flattens nested invocations of CONCAT into a single one. For example, `CONCAT(CONCAT(a, b), CONCAT(c, d))` becomes + * `CONCAT(a, b, c, d)`. 
+ */ +class FlattenNestedConcat extends ir.Rule[ir.LogicalPlan] { + + override def apply(plan: ir.LogicalPlan): ir.LogicalPlan = { + plan transformAllExpressions flattenConcat + } + + // Make the implementation accessible for testing without having to build a full LogicalPlan + private[rules] def flattenConcat: PartialFunction[ir.Expression, ir.Expression] = { case expression => + expression transformUp { case ir.Concat(items) => + ir.Concat(items.flatMap { + case ir.Concat(sub) => sub + case x => Seq(x) + }) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeCallMapper.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeCallMapper.scala new file mode 100644 index 0000000000..6706a17ebe --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeCallMapper.scala @@ -0,0 +1,651 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate.UnresolvedNamedLambdaVariable +import com.databricks.labs.remorph.{intermediate => ir} +import com.databricks.labs.remorph.transpilers.TranspileException + +import java.time.format.DateTimeFormatter +import scala.util.Try + +class SnowflakeCallMapper extends ir.CallMapper with ir.IRHelpers { + private[this] val zeroLiteral: ir.Literal = ir.IntLiteral(0) + private[this] val oneLiteral: ir.Literal = ir.IntLiteral(1) + + override def convert(call: ir.Fn): ir.Expression = { + withNormalizedName(call) match { + // keep all the names in alphabetical order + case ir.CallFunction("ARRAY_CAT", args) => ir.Concat(args) + case ir.CallFunction("ARRAY_CONSTRUCT", args) => ir.CreateArray(args) + case ir.CallFunction("ARRAY_CONSTRUCT_COMPACT", args) => + ir.ArrayExcept(ir.CreateArray(args), ir.CreateArray(Seq(ir.Literal.Null))) + case ir.CallFunction("ARRAY_CONTAINS", args) => ir.ArrayContains(args(1), args.head) + case ir.CallFunction("ARRAY_INTERSECTION", args) => ir.ArrayIntersect(args.head, args(1)) + case ir.CallFunction("ARRAY_SIZE", args) => ir.Size(args.head) + case ir.CallFunction("ARRAY_SLICE", args) => + // @see https://docs.snowflake.com/en/sql-reference/functions/array_slice + // @see https://docs.databricks.com/en/sql/language-manual/functions/slice.html + // TODO: optimize constants: ir.Add(ir.Literal(2), ir.Literal(2)) => ir.Literal(4) + ir.Slice(args.head, zeroIndexedToOneIndexed(args(1)), args.lift(2).getOrElse(oneLiteral)) + case ir.CallFunction("ARRAY_SORT", args) => arraySort(args) + case ir.CallFunction("ARRAY_TO_STRING", args) => ir.ArrayJoin(args.head, args(1), None) + case ir.CallFunction("BASE64_DECODE_STRING", args) => ir.UnBase64(args.head) + case ir.CallFunction("BASE64_DECODE_BINARY", args) => ir.UnBase64(args.head) + case ir.CallFunction("BASE64_ENCODE", args) => ir.Base64(args.head) + case ir.CallFunction("BITOR_AGG", args) => ir.BitOrAgg(args.head) + case ir.CallFunction("BOOLAND_AGG", args) => ir.BoolAnd(args.head) + case ir.CallFunction("DATEADD", args) => dateAdd(args) + case ir.CallFunction("DATEDIFF", args) => dateDiff(args) + case ir.CallFunction("DATE_FROM_PARTS", args) => ir.MakeDate(args.head, args(1), args(2)) + case ir.CallFunction("DATE_PART", args) => datePart(args) + case ir.CallFunction("DATE_TRUNC", args) => dateTrunc(args) + case ir.CallFunction("DAYNAME", args) => dayname(args) + case ir.CallFunction("DECODE", args) => decode(args) + case ir.CallFunction("DIV0", args) => div0(args) + case ir.CallFunction("DIV0NULL", args) => div0null(args) 
+ case ir.CallFunction("EDITDISTANCE", args) => ir.Levenshtein(args.head, args(1), args.lift(2)) + case ir.CallFunction("FIRST_VALUE", args) => ir.First(args.head, args.lift(1)) + case ir.CallFunction("FLATTEN", args) => + // @see https://docs.snowflake.com/en/sql-reference/functions/flatten + ir.Explode(args.head) + case ir.CallFunction("IFNULL", args) => ir.Coalesce(args) + case ir.CallFunction("IS_INTEGER", args) => isInteger(args) + case ir.CallFunction("JSON_EXTRACT_PATH_TEXT", args) => getJsonObject(args) + case ir.CallFunction("LAST_VALUE", args) => ir.Last(args.head, args.lift(1)) + case ir.CallFunction("LEN", args) => ir.Length(args.head) + case ir.CallFunction("LISTAGG", args) => + ir.ArrayJoin(ir.CollectList(args.head, None), args.lift(1).getOrElse(ir.Literal("")), None) + case ir.CallFunction("MONTHNAME", args) => ir.DateFormatClass(args.head, ir.Literal("MMM")) + case ir.CallFunction("MONTHS_BETWEEN", args) => ir.MonthsBetween(args.head, args(1), ir.Literal.True) + case ir.CallFunction("NULLIFZERO", args) => nullIfZero(args.head) + case ir.CallFunction("OBJECT_KEYS", args) => ir.JsonObjectKeys(args.head) + case ir.CallFunction("OBJECT_CONSTRUCT", args) => objectConstruct(args) + case ir.CallFunction("PARSE_JSON", args) => ir.ParseJson(args.head) + case ir.CallFunction("POSITION", args) => ir.CallFunction("LOCATE", args) + case ir.CallFunction("REGEXP_LIKE", args) => ir.RLike(args.head, args(1)) + case ir.CallFunction("REGEXP_SUBSTR", args) => regexpExtract(args) + case ir.CallFunction("SHA2", args) => ir.Sha2(args.head, args.lift(1).getOrElse(ir.Literal(256))) + case ir.CallFunction("SPLIT_PART", args) => splitPart(args) + case ir.CallFunction("SQUARE", args) => ir.Pow(args.head, ir.Literal(2)) + case ir.CallFunction("STRTOK", args) => strtok(args) + case ir.CallFunction("STRTOK_TO_ARRAY", args) => split(args) + case ir.CallFunction("SYSDATE", _) => ir.CurrentTimestamp() + case ir.CallFunction("TIMESTAMPADD", args) => timestampAdd(args) + case ir.CallFunction("TIMESTAMP_FROM_PARTS", args) => makeTimestamp(args) + case ir.CallFunction("TO_ARRAY", args) => toArray(args) + case ir.CallFunction("TO_BOOLEAN", args) => toBoolean(args) + case ir.CallFunction("TO_DATE", args) => toDate(args) + case ir.CallFunction("TO_DOUBLE", args) => ir.CallFunction("DOUBLE", args) + case ir.CallFunction("TO_NUMBER", args) => toNumber(args) + case ir.CallFunction("TO_OBJECT", args) => ir.StructsToJson(args.head, args.lift(1)) + case ir.CallFunction("TO_VARCHAR", args) => ir.CallFunction("TO_CHAR", args) + case ir.CallFunction("TO_VARIANT", args) => ir.StructsToJson(args.head, None) + case ir.CallFunction("TO_TIME", args) => toTime(args) + case ir.CallFunction("TO_TIMESTAMP", args) => toTimestamp(args) + case ir.CallFunction("TRY_BASE64_DECODE_STRING", args) => ir.UnBase64(args.head) + case ir.CallFunction("TRY_BASE64_DECODE_BINARY", args) => ir.UnBase64(args.head) + case ir.CallFunction("TRY_PARSE_JSON", args) => ir.ParseJson(args.head) + case ir.CallFunction("TRY_TO_BOOLEAN", args) => tryToBoolean(args) + case ir.CallFunction("TRY_TO_DATE", args) => tryToDate(args) + case ir.CallFunction("TRY_TO_NUMBER", args) => tryToNumber(args) + case ir.CallFunction("UUID_STRING", _) => ir.Uuid() + case ir.CallFunction("ZEROIFNULL", args) => ir.If(ir.IsNull(args.head), ir.Literal(0), args.head) + case x => super.convert(x) + } + } + + private def objectConstruct(args: Seq[ir.Expression]): ir.Expression = args match { + case Seq(s @ ir.Star(_)) => ir.StructExpr(Seq(s)) + case pairs: Seq[ir.Expression] => + 
ir.StructExpr( + pairs + .sliding(2, 2) + .collect { + case Seq(ir.StringLiteral(key), v) => ir.Alias(v, ir.Id(key)) + case args => throw TranspileException(ir.UnsupportedArguments("OBJECT_CONSTRUCT", args)) + } + .toList) + } + + private def nullIfZero(expr: ir.Expression): ir.Expression = + ir.If(ir.Equals(expr, zeroLiteral), ir.Literal.Null, expr) + + private def div0null(args: Seq[ir.Expression]): ir.Expression = args match { + case Seq(left, right) => + ir.If(ir.Or(ir.Equals(right, zeroLiteral), ir.IsNull(right)), zeroLiteral, ir.Divide(left, right)) + } + + private def div0(args: Seq[ir.Expression]): ir.Expression = args match { + case Seq(left, right) => + ir.If(ir.Equals(right, zeroLiteral), zeroLiteral, ir.Divide(left, right)) + } + + private def zeroIndexedToOneIndexed(expr: ir.Expression): ir.Expression = expr match { + case ir.IntLiteral(num) => ir.IntLiteral(num + 1) + case neg: ir.UMinus => neg + case x => ir.If(ir.GreaterThanOrEqual(x, zeroLiteral), ir.Add(x, oneLiteral), x) + } + + private def getJsonObject(args: Seq[ir.Expression]): ir.Expression = { + val translatedFmt = args match { + case Seq(_, ir.StringLiteral(path)) => ir.Literal("$." + path) + case Seq(_, id: ir.Id) => ir.Concat(Seq(ir.Literal("$."), id)) + + // As well as CallFunctions, we can receive concrete functions, which are already resolved, + // and don't need to be converted + case x: ir.Fn => x + + case a => throw TranspileException(ir.UnsupportedArguments("GET_JSON_OBJECT", a)) + } + ir.GetJsonObject(args.head, translatedFmt) + } + + private def split(args: Seq[ir.Expression]): ir.Expression = { + val delim = args.lift(1) match { + case None => ir.StringLiteral("[ ]") + case Some(ir.StringLiteral(d)) => ir.StringLiteral(s"[$d]") + case Some(e) => ir.Concat(Seq(ir.StringLiteral("["), e, ir.StringLiteral("]"))) + } + ir.StringSplit(args.head, delim, None) + } + + private def toNumber(args: Seq[ir.Expression]): ir.Expression = { + val getArg: Int => Option[ir.Expression] = args.lift + if (args.size < 2) { + ir.Cast(args.head, ir.DecimalType(38, 0)) + } else if (args.size == 2) { + ir.ToNumber(args.head, args(1)) + } else { + val fmt = getArg(1).collect { case f @ ir.StringLiteral(_) => + f + } + val precPos = fmt.fold(1)(_ => 2) + val prec = getArg(precPos).collect { case ir.IntLiteral(p) => + p + } + val scale = getArg(precPos + 1).collect { case ir.IntLiteral(s) => + s + } + val castedExpr = fmt.fold(args.head)(_ => ir.ToNumber(args.head, args(1))) + ir.Cast(castedExpr, ir.DecimalType(prec, scale)) + } + } + + private def tryToNumber(args: Seq[ir.Expression]): ir.Expression = { + val getArg: Int => Option[ir.Expression] = args.lift + if (args.size == 1) { + ir.Cast(args.head, ir.DecimalType(Some(38), Some(0))) + } else { + val fmt = getArg(1).collect { case f @ ir.StringLiteral(_) => + f + } + val precPos = fmt.fold(1)(_ => 2) + val prec = getArg(precPos) + .collect { case ir.IntLiteral(p) => + p + } + .orElse(Some(38)) + val scale = getArg(precPos + 1) + .collect { case ir.IntLiteral(s) => + s + } + .orElse(Some(0)) + val castedExpr = fmt.fold(args.head)(f => ir.TryToNumber(args.head, f)) + ir.Cast(castedExpr, ir.DecimalType(prec, scale)) + } + } + + private def strtok(args: Seq[ir.Expression]): ir.Expression = { + if (args.size == 1) { + splitPart(Seq(args.head, ir.Literal(" "), oneLiteral)) + } else if (args.size == 2) { + splitPart(Seq(args.head, args(1), oneLiteral)) + } else splitPart(args) + } + + /** + * Snowflake and DB SQL differ in the `partNumber` argument: in Snowflake, a value of 0 is 
interpreted as "get the first part" while it raises an error in DB SQL. + */ + private def splitPart(args: Seq[ir.Expression]): ir.Expression = args match { + case Seq(str, delim, ir.IntLiteral(0)) => ir.StringSplitPart(str, delim, oneLiteral) + case Seq(str, delim, ir.IntLiteral(p)) => ir.StringSplitPart(str, delim, ir.Literal(p)) + case Seq(str, delim, expr) => + ir.StringSplitPart(str, delim, ir.If(ir.Equals(expr, zeroLiteral), oneLiteral, expr)) + case other => + throw TranspileException(ir.WrongNumberOfArguments("SPLIT_PART", other.size, "3")) + } + + // REGEXP_SUBSTR( <subject> , <pattern> [ , <position> [ , <occurrence> [ , <regex_parameters> [ , <group_num> ]]]] ) + private def regexpExtract(args: Seq[ir.Expression]): ir.Expression = { + val subject = if (args.size >= 3) { + ir.Substring(args.head, args(2)) + } else args.head + if (args.size <= 3) { + ir.RegExpExtract(subject, args(1), Some(zeroLiteral)) + } else { + val occurrence = args(3) match { + case ir.IntLiteral(o) => ir.Literal(o - 1) + case o => ir.Subtract(o, oneLiteral) + } + val pattern = args.lift(4) match { + case None => args(1) + case Some(ir.StringLiteral(regexParams)) => translateLiteralRegexParameters(regexParams, args(1)) + case Some(regexParams) => translateRegexParameters(regexParams, args(1)) + } + val groupNumber = args.lift(5).orElse(Some(zeroLiteral)) + ir.ArrayAccess(ir.RegExpExtractAll(subject, pattern, groupNumber), occurrence) + } + } + + private def translateLiteralRegexParameters(regexParams: String, pattern: ir.Expression): ir.Expression = { + val filtered = regexParams.foldLeft("") { case (agg, item) => + if (item == 'c') agg.filter(_ != 'i') + else if ("ism".contains(item)) agg + item + else agg + } + pattern match { + case ir.StringLiteral(pat) => ir.Literal(s"(?$filtered)$pat") + case e => ir.Concat(Seq(ir.Literal(s"(?$filtered)"), e)) + } + } + + /** + * regex_params may be any expression (a literal, but also a column, etc.), this changes it to + * + * aggregate( + * split(regex_params, ''), + * cast(array() as array<string>), + * (agg, item) -> + * case + * when item = 'c' then filter(agg, c -> c != 'i') + * when item in ('i', 's', 'm') then array_append(agg, item) + * else agg + * end, + * filtered -> '(?'
|| array_join(array_distinct(filtered), '') || ')' + * ) + */ + private def translateRegexParameters(regexParameters: ir.Expression, pattern: ir.Expression): ir.Expression = { + ir.ArrayAggregate( + ir.StringSplit(regexParameters, ir.Literal(""), None), + ir.Cast(ir.CreateArray(Seq()), ir.ArrayType(ir.StringType)), + ir.LambdaFunction( + ir.Case( + expression = None, + branches = Seq( + ir.WhenBranch( + ir.Equals(ir.Id("item"), ir.Literal("c")), + ir.ArrayFilter( + ir.Id("agg"), + ir.LambdaFunction( + ir.NotEquals(ir.Id("item"), ir.Literal("i")), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("item")))))), + ir.WhenBranch( + ir.In(ir.Id("item"), Seq(ir.Literal("i"), ir.Literal("s"), ir.Literal("m"))), + ir.ArrayAppend(ir.Id("agg"), ir.Id("item")))), + otherwise = Some(ir.Id("agg"))), + Seq(UnresolvedNamedLambdaVariable(Seq("agg")), UnresolvedNamedLambdaVariable(Seq("item")))), + ir.LambdaFunction( + ir.Concat(Seq(ir.Literal("(?"), ir.ArrayJoin(ir.Id("filtered"), ir.Literal("")), ir.Literal(")"), pattern)), + Seq(UnresolvedNamedLambdaVariable(Seq("filtered"))))) + } + + private def dateDiff(args: Seq[ir.Expression]): ir.Expression = { + val datePart = SnowflakeTimeUnits.translateDateOrTimePart(args.head) + ir.TimestampDiff(datePart, args(1), args(2)) + } + + private def tryToDate(args: Seq[ir.Expression]): ir.Expression = { + ir.CallFunction("DATE", Seq(ir.TryToTimestamp(args.head, args.lift(1)))) + } + + private def dateAdd(args: Seq[ir.Expression]): ir.Expression = { + if (args.size == 2) { + ir.DateAdd(args.head, args(1)) + } else if (args.size == 3) { + timestampAdd(args) + } else { + throw TranspileException(ir.WrongNumberOfArguments("DATEADD", args.size, "2 or 3")) + + } + } + + private def timestampAdd(args: Seq[ir.Expression]): ir.Expression = { + val dateOrTimePart = SnowflakeTimeUnits.translateDateOrTimePart(args.head) + ir.TimestampAdd(dateOrTimePart, args(1), args(2)) + } + + private def datePart(args: Seq[ir.Expression]): ir.Expression = { + val part = SnowflakeTimeUnits.translateDateOrTimePart(args.head) + ir.Extract(ir.Id(part), args(1)) + } + + private def dateTrunc(args: Seq[ir.Expression]): ir.Expression = { + val part = SnowflakeTimeUnits.translateDateOrTimePart(args.head) + ir.TruncTimestamp(ir.Literal(part.toUpperCase()), args(1)) + } + + private def makeTimestamp(args: Seq[ir.Expression]): ir.Expression = { + if (args.size == 2) { + // Snowflake's TIMESTAMP_FROM_PARTS can be invoked with only two arguments + // that, in this case, represent a date and a time. In such case, we need to + // extract the components of both date and time and feed them to MAKE_TIMESTAMP + // accordingly + val year = ir.DatePart(ir.Id("year"), args.head) + val month = ir.DatePart(ir.Id("month"), args.head) + val day = ir.DatePart(ir.Id("day"), args.head) + val hour = ir.Hour(args(1)) + val minute = ir.Minute(args(1)) + val second = ir.Second(args(1)) + ir.MakeTimestamp(year, month, day, hour, minute, second, None) + } else if (args.size == 6) { + ir.MakeTimestamp(args.head, args(1), args(2), args(3), args(4), args(5), None) + } else if (args.size == 7) { + // When call with individual parts (as opposed to the two-arguments scenario above) + // Snowflake allows for two additional optional parameters: an amount of nanoseconds + // and/or a timezone. So when we get 7 arguments, we need to inspect the last one to + // determine whether it's an amount of nanoseconds (ie. a number) or a timezone reference + // (ie. 
a string) + args(6) match { + case ir.IntLiteral(_) => + // We ignore that last parameter as DB SQL doesn't handle nanoseconds + // TODO warn the user about this + ir.MakeTimestamp(args.head, args(1), args(2), args(3), args(4), args(5), None) + case timezone @ ir.StringLiteral(_) => + ir.MakeTimestamp(args.head, args(1), args(2), args(3), args(4), args(5), Some(timezone)) + case _ => throw TranspileException(ir.UnsupportedArguments("TIMESTAMP_FROM_PART", Seq(args(6)))) + } + } else if (args.size == 8) { + // Here the situation is simpler, we just ignore the 7th argument (nanoseconds) + ir.MakeTimestamp(args.head, args(1), args(2), args(3), args(4), args(5), Some(args(7))) + } else { + throw TranspileException(ir.WrongNumberOfArguments("TIMESTAMP_FROM_PART", args.size, "either 2, 6, 7 or 8")) + } + } + + private def toTime(args: Seq[ir.Expression]): ir.Expression = { + val timeFormat = ir.Literal("HH:mm:ss") + args match { + case Seq(a) => + ir.DateFormatClass( + inferTemporalFormat(a, unsupportedAutoTimestampFormats ++ unsupportedAutoTimeFormats), + timeFormat) + case Seq(a, b) => ir.DateFormatClass(ir.ParseToTimestamp(a, Some(b)), timeFormat) + case _ => throw TranspileException(ir.WrongNumberOfArguments("TO_TIMESTAMP", args.size, "1 or 2")) + } + } + + private def toTimestamp(args: Seq[ir.Expression]): ir.Expression = args match { + case Seq(a) => inferTemporalFormat(a, unsupportedAutoTimestampFormats) + case Seq(a, lit: ir.Literal) => toTimestampWithLiteralFormat(a, lit) + case Seq(a, b) => toTimestampWithVariableFormat(a, b) + case _ => throw TranspileException(ir.WrongNumberOfArguments("TO_TIMESTAMP", args.size, "1 or 2")) + } + + private def toTimestampWithLiteralFormat(expression: ir.Expression, fmt: ir.Literal): ir.Expression = fmt match { + case num @ ir.IntLiteral(_) => + ir.ParseToTimestamp(expression, Some(ir.Pow(ir.Literal(10), num))) + case ir.StringLiteral(str) => + ir.ParseToTimestamp( + expression, + Some(ir.StringLiteral(temporalFormatMapping.foldLeft(str) { case (s, (sf, dbx)) => s.replace(sf, dbx) }))) + } + + private def toTimestampWithVariableFormat(expression: ir.Expression, fmt: ir.Expression): ir.Expression = { + val translatedFmt = temporalFormatMapping.foldLeft(fmt) { case (s, (sf, dbx)) => + ir.StringReplace(s, ir.Literal(sf), ir.Literal(dbx)) + } + ir.If( + ir.StartsWith(fmt, ir.Literal("DY")), + ir.ParseToTimestamp(ir.Substring(expression, ir.Literal(4)), Some(ir.Substring(translatedFmt, ir.Literal(4)))), + ir.ParseToTimestamp(expression, Some(translatedFmt))) + } + + // Timestamp formats that can be automatically inferred by Snowflake but not by Databricks + private[this] val unsupportedAutoTimestampFormats = Seq( + "yyyy-MM-dd'T'HH:mmXXX", + "yyyy-MM-dd HH:mmXXX", + "EEE, dd MMM yyyy HH:mm:ss ZZZ", + "EEE, dd MMM yyyy HH:mm:ss.SSSSSSSSS ZZZ", + "EEE, dd MMM yyyy hh:mm:ss a ZZZ", + "EEE, dd MMM yyyy hh:mm:ss.SSSSSSSSS a ZZZ", + "EEE, dd MMM yyyy HH:mm:ss", + "EEE, dd MMM yyyy HH:mm:ss.SSSSSSSSS", + "EEE, dd MMM yyyy hh:mm:ss a", + "EEE, dd MMM yyyy hh:mm:ss.SSSSSSSSS a", + "M/dd/yyyy HH:mm:ss", + "EEE MMM dd HH:mm:ss ZZZ yyyy") + + private[this] val unsupportedAutoTimeFormats = + Seq("HH:MM:ss.SSSSSSSSS", "HH:MM:ss", "HH:MM", "hh:MM:ss.SSSSSSSSS a", "hh:MM:ss a", "hh:MM a") + + // In Snowflake, when TO_TIME/TO_TIMESTAMP is called without a specific format, the system is capable of inferring the + // format from the string being parsed. Databricks has a similar behavior, but the set of formats it's capable of + // detecting automatically is narrower. 
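// Editor's note: the standalone sketch below is not part of this patch; it only illustrates the probing idea
// described above: try each Snowflake-only automatic format against a literal string and keep the first pattern
// that parses, falling back to None when nothing matches. The object and method names are hypothetical and exist
// purely for this illustration.
import java.time.format.DateTimeFormatter
import scala.util.Try

object FormatProbeSketch {
  // Two of the Snowflake-only patterns listed above, purely for demonstration.
  private val candidates = Seq("yyyy-MM-dd'T'HH:mmXXX", "M/dd/yyyy HH:mm:ss")

  // Returns the first candidate pattern that successfully parses the literal, if any.
  def firstMatchingFormat(timeStr: String): Option[String] =
    candidates.find(fmt => Try(DateTimeFormatter.ofPattern(fmt).parse(timeStr)).isSuccess)
}
// e.g. FormatProbeSketch.firstMatchingFormat("1/02/2024 10:30:00") returns Some("M/dd/yyyy HH:mm:ss")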
+ private def inferTemporalFormat(expression: ir.Expression, unsupportedAutoformats: Seq[String]): ir.Expression = + expression match { + // If the expression to be parsed is a Literal, we try the formats supported by Snowflake but not by Databricks + // and add an explicit parameter with the first that matches, or fallback to no format parameter if none has + // matched (which could indicate that either the implicit format is one Databricks can automatically infer, or the + // string to be parsed is malformed). + case ir.StringLiteral(timeStr) => + Try(timeStr.trim.toInt) + .map(int => ir.ParseToTimestamp(ir.Literal(int))) + .getOrElse( + ir.ParseToTimestamp( + expression, + unsupportedAutoformats + .find(fmt => Try(DateTimeFormatter.ofPattern(fmt).parse(timeStr)).isSuccess) + .map(ir.Literal(_)))) + // If the string to be parsed isn't a Literal, we do something similar but "at runtime". + case e => + ir.Case( + Some(ir.TypeOf(e)), + Seq( + ir.WhenBranch( + ir.Literal("string"), + ir.IfNull( + ir.Coalesce(ir.TryToTimestamp(ir.TryCast(e, ir.IntegerType)) +: unsupportedAutoformats.map( + makeAutoFormatExplicit(e, _))), + ir.ParseToTimestamp(e)))), + Some(ir.Cast(expression, ir.TimestampType))) + } + + private def makeAutoFormatExplicit(expr: ir.Expression, javaDateTimeFormatString: String): ir.Expression = + if (javaDateTimeFormatString.startsWith("EEE")) { + // Since version 3.0, Spark doesn't support day-of-week field in datetime parsing + // Considering that this is piece of information is irrelevant for parsing a timestamp + // we simply ignore it from the input string and the format. + ir.TryToTimestamp(ir.Substring(expr, ir.Literal(4)), Some(ir.Literal(javaDateTimeFormatString.substring(3)))) + } else { + ir.TryToTimestamp(expr, Some(ir.Literal(javaDateTimeFormatString))) + } + + private[this] val temporalFormatMapping = Seq( + "YYYY" -> "yyyy", + "YY" -> "yy", + "MON" -> "MMM", + "DD" -> "dd", + "DY" -> "EEE", // will be ignored down the line as it isn't supported anymore since Spark 3.0 + "HH24" -> "HH", + "HH12" -> "hh", + "AM" -> "a", + "PM" -> "a", + "MI" -> "mm", + "SS" -> "ss", + "FF9" -> "SSSSSSSSS", + "FF8" -> "SSSSSSSS", + "FF7" -> "SSSSSSS", + "FF6" -> "SSSSSS", + "FF5" -> "SSSSS", + "FF4" -> "SSSS", + "FF3" -> "SSS", + "FF2" -> "SS", + "FF1" -> "S", + "FF0" -> "", + "FF" -> "SSSSSSSSS", + "TZH:TZM" -> "ZZZ", + "TZHTZM" -> "ZZZ", + "TZH" -> "ZZZ", + "UUUU" -> "yyyy", + "\"" -> "'") + + private def dayname(args: Seq[ir.Expression]): ir.Expression = { + ir.DateFormatClass(args.head, ir.Literal("E")) + } + + private def toDate(args: Seq[ir.Expression]): ir.Expression = { + if (args.size == 1) { + ir.Cast(args.head, ir.DateType) + } else if (args.size == 2) { + ir.ParseToDate(args.head, Some(args(1))) + } else { + throw TranspileException(ir.WrongNumberOfArguments("TO_DATE", args.size, "1 or 2")) + } + } + + private def isInteger(args: Seq[ir.Expression]): ir.Expression = { + ir.Case( + None, + Seq( + ir.WhenBranch(ir.IsNull(args.head), ir.Literal.Null), + ir.WhenBranch( + ir.And(ir.RLike(args.head, ir.Literal("^-?[0-9]+$")), ir.IsNotNull(ir.TryCast(args.head, ir.IntegerType))), + ir.Literal(true))), + Some(ir.Literal(false))) + } + + private def toArray(args: Seq[ir.Expression]): ir.Expression = { + ir.If(ir.IsNull(args.head), ir.Literal.Null, ir.CreateArray(Seq(args.head))) + } + + private def toBoolean(args: Seq[ir.Expression]): ir.Expression = { + toBooleanLike(args.head, ir.RaiseError(ir.Literal("Invalid parameter type for TO_BOOLEAN"))) + } + + private def 
tryToBoolean(args: Seq[ir.Expression]): ir.Expression = { + toBooleanLike(args.head, ir.Literal.Null) + } + + private def toBooleanLike(arg: ir.Expression, otherwise: ir.Expression): ir.Expression = { + val castArgAsDouble = ir.Cast(arg, ir.DoubleType) + ir.Case( + None, + Seq( + ir.WhenBranch(ir.IsNull(arg), ir.Literal.Null), + ir.WhenBranch(ir.Equals(ir.TypeOf(arg), ir.Literal("boolean")), ir.CallFunction("BOOLEAN", Seq(arg))), + ir.WhenBranch( + ir.Equals(ir.TypeOf(arg), ir.Literal("string")), + ir.Case( + None, + Seq( + ir.WhenBranch( + ir.In( + ir.Lower(arg), + Seq( + ir.Literal("true"), + ir.Literal("t"), + ir.Literal("yes"), + ir.Literal("y"), + ir.Literal("on"), + ir.Literal("1"))), + ir.Literal(true)), + ir.WhenBranch( + ir.In( + ir.Lower(arg), + Seq( + ir.Literal("false"), + ir.Literal("f"), + ir.Literal("no"), + ir.Literal("n"), + ir.Literal("off"), + ir.Literal("0"))), + ir.Literal(false))), + Some(ir.RaiseError(ir.Literal(s"Boolean value of x is not recognized by TO_BOOLEAN"))))), + ir.WhenBranch( + ir.IsNotNull(ir.TryCast(arg, ir.DoubleType)), + ir.Case( + None, + Seq( + ir.WhenBranch( + ir.Or( + ir.IsNaN(castArgAsDouble), + ir.Equals(castArgAsDouble, ir.CallFunction("DOUBLE", Seq(ir.Literal("infinity"))))), + ir.RaiseError(ir.Literal("Invalid parameter type for TO_BOOLEAN")))), + Some(ir.NotEquals(castArgAsDouble, ir.DoubleLiteral(0.0d)))))), + Some(otherwise)) + } + + private def decode(args: Seq[ir.Expression]): ir.Expression = { + if (args.size >= 3) { + val expr = args.head + val groupedArgs = args.tail.sliding(2, 2).toList + ir.Case( + None, + groupedArgs.takeWhile(_.size == 2).map(l => makeWhenBranch(expr, l.head, l.last)), + groupedArgs.find(_.size == 1).map(_.head)) + } else { + throw TranspileException(ir.WrongNumberOfArguments("DECODE", args.size, "at least 3")) + } + } + + private def makeWhenBranch(expr: ir.Expression, cond: ir.Expression, out: ir.Expression): ir.WhenBranch = { + cond match { + case ir.Literal.Null => ir.WhenBranch(ir.IsNull(expr), out) + case any => ir.WhenBranch(ir.Equals(expr, any), out) + } + } + + private def arraySort(args: Seq[ir.Expression]): ir.Expression = { + makeArraySort(args.head, args.lift(1), args.lift(2)) + } + + private def makeArraySort( + arr: ir.Expression, + sortAscending: Option[ir.Expression], + nullsFirst: Option[ir.Expression]): ir.Expression = { + // Currently, only TRUE/FALSE Boolean literals are supported for Boolean parameters. 
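// Editor's note: an illustrative, self-contained model (not part of this patch) of the comparator assembled
// below. Snowflake's ARRAY_SORT(array, sort_ascending, nulls_first) controls null placement independently of
// the sort direction, which is what the generated lambda reproduces. The name sketchComparator and the use of
// Option[Int] for nullable elements are hypothetical simplifications.
def sketchComparator(ascending: Boolean, nullsFirst: Boolean)(left: Option[Int], right: Option[Int]): Int =
  (left, right) match {
    case (None, None) => 0
    case (None, _) => if (nullsFirst) -1 else 1
    case (_, None) => if (nullsFirst) 1 else -1
    case (Some(l), Some(r)) =>
      val cmp = Integer.compare(l, r)
      if (ascending) cmp else -cmp
  }
// e.g. Seq(Some(3), None, Some(1)).sortWith((a, b) => sketchComparator(ascending = true, nullsFirst = false)(a, b) < 0)
//      yields Seq(Some(1), Some(3), None), matching the "ascending, nulls last" default derived below.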
+ val paramSortAsc = sortAscending.getOrElse(ir.Literal.True) + val paramNullsFirst = nullsFirst.getOrElse { + paramSortAsc match { + case ir.Literal.True => ir.Literal.False + case ir.Literal.False => ir.Literal.True + case _ => throw TranspileException(ir.UnsupportedArguments("ARRAY_SORT", Seq(paramSortAsc))) + } + } + + def handleComparison(isNullOrSmallFirst: ir.Expression, nullOrSmallAtLeft: Boolean): ir.Expression = { + isNullOrSmallFirst match { + case ir.Literal.True => if (nullOrSmallAtLeft) ir.Literal(-1) else oneLiteral + case ir.Literal.False => if (nullOrSmallAtLeft) oneLiteral else ir.Literal(-1) + case _ => throw TranspileException(ir.UnsupportedArguments("ARRAY_SORT", Seq(isNullOrSmallFirst))) + } + } + + val comparator = ir.LambdaFunction( + ir.Case( + None, + Seq( + ir.WhenBranch(ir.And(ir.IsNull(ir.Id("left")), ir.IsNull(ir.Id("right"))), zeroLiteral), + ir.WhenBranch(ir.IsNull(ir.Id("left")), handleComparison(paramNullsFirst, nullOrSmallAtLeft = true)), + ir.WhenBranch(ir.IsNull(ir.Id("right")), handleComparison(paramNullsFirst, nullOrSmallAtLeft = false)), + ir.WhenBranch( + ir.LessThan(ir.Id("left"), ir.Id("right")), + handleComparison(paramSortAsc, nullOrSmallAtLeft = true)), + ir.WhenBranch( + ir.GreaterThan(ir.Id("left"), ir.Id("right")), + handleComparison(paramSortAsc, nullOrSmallAtLeft = false))), + Some(zeroLiteral)), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("left")), ir.UnresolvedNamedLambdaVariable(Seq("right")))) + + val irSortArray = (paramSortAsc, paramNullsFirst) match { + // We can make the IR much simpler for some cases + // by using DBSQL SORT_ARRAY function without needing a custom comparator + case (ir.Literal.True, ir.Literal.True) => ir.SortArray(arr, None) + case (ir.Literal.False, ir.Literal.False) => ir.SortArray(arr, Some(ir.Literal.False)) + case _ => ir.ArraySort(arr, comparator) + } + + irSortArray + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeTimeUnits.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeTimeUnits.scala new file mode 100644 index 0000000000..d615018fcd --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeTimeUnits.scala @@ -0,0 +1,44 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate.IRHelpers +import com.databricks.labs.remorph.{intermediate => ir} +import com.databricks.labs.remorph.transpilers.TranspileException + +object SnowflakeTimeUnits extends IRHelpers { + private[this] val dateOrTimeParts = Map( + Set("YEAR", "Y", "YY", "YYY", "YYYY", "YR", "YEARS", "YRS") -> "year", + Set("MONTH", "MM", "MON", "MONS", "MONTHS") -> "month", + Set("DAY", "D", "DD", "DAYS", "DAYOFMONTH") -> "day", + Set("DAYOFWEEK", "WEEKDAY", "DOW", "DW") -> "dayofweek", + Set("DAYOFWEEKISO", "WEEKDAY_ISO", "DOW_ISO", "DW_ISO") -> "dayofweekiso", + Set("DAYOFYEAR", "YEARDAY", "DOY", "DY") -> "dayofyear", + Set("WEEK", "W", "WK", "WEEKOFYEAR", "WOY", "WY") -> "week", + Set("WEEKISO", "WEEK_ISO", "WEEKOFYEARISO", "WEEKOFYEAR_ISO") -> "weekiso", + Set("QUARTER", "Q", "QTR", "QTRS", "QUARTERS") -> "quarter", + Set("YEAROFWEEK") -> "yearofweek", + Set("YEAROFWEEKISO") -> "yearofweekiso", + Set("HOUR", "H", "HH", "HR", "HOURS", "HRS") -> "hour", + Set("MINUTE", "M", "MI", "MIN", "MINUTES", "MINS") -> "minute", + Set("SECOND", "S", "SEC", "SECONDS", "SECS") -> "second", + Set("MILLISECOND", "MS", "MSEC", "MILLISECONDS") -> "millisecond", + 
Set("MICROSECOND", "US", "USEC", "MICROSECONDS") -> "microsecond", + Set("NANOSECOND", "NS", "NSEC", "NANOSEC", "NSECOND", "NANOSECONDS", "NANOSECS", "NSECONDS") -> "nanosecond", + Set("EPOCH_SECOND", "EPOCH", "EPOCH_SECONDS") -> "epoch_second", + Set("EPOCH_MILLISECOND", "EPOCH_MILLISECONDS") -> "epoch_millisecond", + Set("EPOCH_MICROSECOND", "EPOCH_MICROSECONDS") -> "epoch_microsecond", + Set("EPOCH_NANOSECOND", "EPOCH_NANOSECONDS") -> "epoch_nanosecond", + Set("TIMEZONE_HOUR", "TZH") -> "timezone_hour", + Set("TIMEZONE_MINUTE", "TZM") -> "timezone_minute") + + private def findDateOrTimePart(part: String): Option[String] = + dateOrTimeParts.find(_._1.contains(part.toUpperCase())).map(_._2) + + def translateDateOrTimePart(input: ir.Expression): String = input match { + case ir.Id(part, _) if SnowflakeTimeUnits.findDateOrTimePart(part).nonEmpty => + SnowflakeTimeUnits.findDateOrTimePart(part).get + case ir.StringLiteral(part) if SnowflakeTimeUnits.findDateOrTimePart(part).nonEmpty => + SnowflakeTimeUnits.findDateOrTimePart(part).get + case x => throw TranspileException(ir.UnsupportedDateTimePart(x)) + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/TranslateWithinGroup.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/TranslateWithinGroup.scala new file mode 100644 index 0000000000..6764c7ec61 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/TranslateWithinGroup.scala @@ -0,0 +1,62 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate.UnresolvedNamedLambdaVariable +import com.databricks.labs.remorph.{intermediate => ir} + +import scala.annotation.tailrec + +class TranslateWithinGroup extends ir.Rule[ir.LogicalPlan] { + + override def apply(plan: ir.LogicalPlan): ir.LogicalPlan = { + plan transformAllExpressions { + case ir.WithinGroup(ir.CallFunction("ARRAY_AGG", args), sorts) => sortArray(args.head, sorts) + case ir.WithinGroup(ir.CallFunction("LISTAGG", args), sorts) => + ir.ArrayJoin(sortArray(args.head, sorts), args(1)) + } + } + + private def sortArray(arg: ir.Expression, sort: Seq[ir.SortOrder]): ir.Expression = { + if (sort.size == 1 && sameReference(arg, sort.head.expr)) { + val sortOrder = if (sort.head.direction == ir.Descending) { Some(ir.Literal(false)) } + else { None } + ir.SortArray(ir.CollectList(arg), sortOrder) + } else { + + val namedStructFunc = ir.CreateNamedStruct(Seq(ir.Literal("value"), arg) ++ sort.zipWithIndex.flatMap { + case (s, index) => + Seq(ir.Literal(s"sort_by_$index"), s.expr) + }) + + ir.ArrayTransform( + ir.ArraySort(ir.CollectList(namedStructFunc), sortingLambda(sort)), + ir.LambdaFunction(ir.Dot(ir.Id("s"), ir.Id("value")), Seq(ir.UnresolvedNamedLambdaVariable(Seq("s"))))) + } + } + + @tailrec private def sameReference(left: ir.Expression, right: ir.Expression): Boolean = left match { + case ir.Distinct(e) => sameReference(e, right) + case l if l == right => true + case _ => false + } + + private def sortingLambda(sort: Seq[ir.SortOrder]): ir.Expression = { + ir.LambdaFunction( + ir.Case( + None, + sort.zipWithIndex.flatMap { case (s, index) => + Seq( + ir.WhenBranch( + ir.LessThan( + ir.Dot(ir.Id("left"), ir.Id(s"sort_by_$index")), + ir.Dot(ir.Id("right"), ir.Id(s"sort_by_$index"))), + if (s.direction == ir.Ascending) ir.Literal(-1) else ir.Literal(1)), + ir.WhenBranch( + ir.GreaterThan( + ir.Dot(ir.Id("left"), ir.Id(s"sort_by_$index")), + ir.Dot(ir.Id("right"), 
ir.Id(s"sort_by_$index"))), + if (s.direction == ir.Ascending) ir.Literal(1) else ir.Literal(-1))) + }, + Some(ir.Literal(0))), + Seq(UnresolvedNamedLambdaVariable(Seq("left")), UnresolvedNamedLambdaVariable(Seq("right")))) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/UpdateToMerge.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/UpdateToMerge.scala new file mode 100644 index 0000000000..bb21c53f14 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/rules/UpdateToMerge.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.intermediate.{Assign, Expression, Join, LogicalPlan, MergeAction, MergeIntoTable, Noop, NoopNode, Rule, UpdateAction, UpdateTable} + +class UpdateToMerge extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case update @ UpdateTable(_, None, _, _, _, _) => update + case update: UpdateTable => + MergeIntoTable(update.target, source(update), condition(update), matchedActions = matchedActions(update)) + } + + private def matchedActions(update: UpdateTable): Seq[MergeAction] = { + val set = update.set.collect { case a: Assign => a } + Seq(UpdateAction(None, set)) + } + + private def source(update: UpdateTable): LogicalPlan = update.source match { + case Some(plan) => + plan match { + case Join(_, source, _, _, _, _) => + // TODO: figure out why there's a join in the update plan + source + case _ => plan + } + case None => NoopNode + } + + private def condition(update: UpdateTable): Expression = update.where match { + case Some(condition) => condition + case None => Noop + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/DataTypeBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/DataTypeBuilder.scala new file mode 100644 index 0000000000..e84a304496 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/DataTypeBuilder.scala @@ -0,0 +1,60 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.tsql.TSqlParser.DataTypeContext +import com.databricks.labs.remorph.{intermediate => ir} + +class DataTypeBuilder { + + def build(ctx: DataTypeContext): ir.DataType = + ctx match { + case ident if ident.dataTypeIdentity() != null => buildIdentity(ident.dataTypeIdentity()) + case template if template.jinjaTemplate() != null => buildTemplate(template.jinjaTemplate()) + case _ => buildScalar(ctx) + + } + + private def buildScalar(ctx: DataTypeContext): ir.DataType = { + + val lenOpt = Option(ctx.INT(0)) map (_.getText.toInt) // A single length parameter + val scaleOpt = Option(ctx.INT(1)) map (_.getText.toInt) // A single scale parameter + + val typeDefinition = ctx.id().getText + typeDefinition.toLowerCase() match { + case "tinyint" => ir.ByteType(size = Some(1)) + case "smallint" => ir.ShortType + case "int" => ir.IntegerType + case "bigint" => ir.LongType + case "bit" => ir.BooleanType + case "money" => ir.DecimalType(precision = Some(19), scale = Some(4)) // Equivalent money + case "smallmoney" => ir.DecimalType(precision = Some(10), scale = Some(4)) // Equivalent smallmoney + case "float" => ir.FloatType + case "real" => ir.DoubleType + case "date" => ir.DateType + case "time" => ir.TimeType + case "datetime" => ir.TimestampType + case "datetime2" => ir.TimestampType + case "datetimeoffset" => ir.StringType // TODO: No direct equivalent + 
case "smalldatetime" => ir.TimestampType // Equivalent smalldatetime + case "char" => ir.CharType(size = lenOpt) + case "varchar" => ir.VarcharType(size = lenOpt) + case "nchar" => ir.CharType(size = lenOpt) + case "nvarchar" => ir.VarcharType(size = lenOpt) + case "text" => ir.VarcharType(None) + case "ntext" => ir.VarcharType(None) + case "image" => ir.BinaryType + case "decimal" | "numeric" => ir.DecimalType(precision = lenOpt, scale = scaleOpt) // Equivalent decimal + case "binary" => ir.BinaryType + case "varbinary" => ir.BinaryType + case "json" => ir.VarcharType(None) + case "uniqueidentifier" => ir.VarcharType(size = Some(16)) // Equivalent uniqueidentifier + case _ => ir.UnparsedType(typeDefinition) + } + } + + private def buildIdentity(ctx: TSqlParser.DataTypeIdentityContext): ir.DataType = + // As of right now, there is no way to implement the IDENTITY property declared as a column type in TSql + ir.UnparsedType(ctx.getText) + + private def buildTemplate(ctx: TSqlParser.JinjaTemplateContext): ir.DataType = + ir.JinjaAsDataType(ctx.getText) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/OptionBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/OptionBuilder.scala new file mode 100644 index 0000000000..66ce8f8e09 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/OptionBuilder.scala @@ -0,0 +1,50 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.tsql.TSqlParser.GenericOptionContext +import com.databricks.labs.remorph.{intermediate => ir} + +class OptionBuilder(vc: TSqlVisitorCoordinator) { + + private[tsql] def buildOptionList(opts: Seq[GenericOptionContext]): ir.OptionLists = { + val options = opts.map(this.buildOption) + val (stringOptions, boolFlags, autoFlags, exprValues) = options.foldLeft( + (Map.empty[String, String], Map.empty[String, Boolean], List.empty[String], Map.empty[String, ir.Expression])) { + case ((stringOptions, boolFlags, autoFlags, values), option) => + option match { + case ir.OptionString(key, value) => + (stringOptions + (key -> value.stripPrefix("'").stripSuffix("'")), boolFlags, autoFlags, values) + case ir.OptionOn(id) => (stringOptions, boolFlags + (id -> true), autoFlags, values) + case ir.OptionOff(id) => (stringOptions, boolFlags + (id -> false), autoFlags, values) + case ir.OptionAuto(id) => (stringOptions, boolFlags, id :: autoFlags, values) + case ir.OptionExpression(id, expr, _) => (stringOptions, boolFlags, autoFlags, values + (id -> expr)) + case _ => (stringOptions, boolFlags, autoFlags, values) + } + } + new ir.OptionLists(exprValues, stringOptions, boolFlags, autoFlags) + } + + private[tsql] def buildOption(ctx: TSqlParser.GenericOptionContext): ir.GenericOption = { + val id = ctx.id(0).getText.toUpperCase() + ctx match { + case c if c.DEFAULT() != null => ir.OptionDefault(id) + case c if c.ON() != null => ir.OptionOn(id) + case c if c.OFF() != null => ir.OptionOff(id) + case c if c.AUTO() != null => ir.OptionAuto(id) + case c if c.STRING() != null => ir.OptionString(id, c.STRING().getText) + + // FOR cannot be allowed as an id as it clashes with the FOR clause in SELECT et al. So + // we special case it here and elide the FOR. 
It handles just a few things such as OPTIMIZE FOR UNKNOWN, + // which becomes "OPTIMIZE", Id(UNKNOWN) + case c if c.FOR() != null => ir.OptionExpression(id, vc.expressionBuilder.buildId(c.id(1)), None) + case c if c.expression() != null => + val supplement = if (c.id(1) != null) Some(ctx.id(1).getText) else None + ir.OptionExpression(id, c.expression().accept(vc.expressionBuilder), supplement) + case _ if id == "DEFAULT" => ir.OptionDefault(id) + case _ if id == "ON" => ir.OptionOn(id) + case _ if id == "OFF" => ir.OptionOff(id) + case _ if id == "AUTO" => ir.OptionAuto(id) + // All other cases being OptionOn as it is a single keyword representing true + case _ => ir.OptionOn(id) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlAstBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlAstBuilder.scala new file mode 100644 index 0000000000..03f3d83955 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlAstBuilder.scala @@ -0,0 +1,118 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.intermediate.LogicalPlan +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.{intermediate => ir} + +import scala.collection.JavaConverters.asScalaBufferConverter + +/** + * @see + * org.apache.spark.sql.catalyst.parser.AstBuilder + */ +class TSqlAstBuilder(override val vc: TSqlVisitorCoordinator) + extends TSqlParserBaseVisitor[ir.LogicalPlan] + with ParserCommon[ir.LogicalPlan] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.LogicalPlan = + ir.UnresolvedRelation(ruleText = ruleText, message = message) + + // Concrete visitors + + override def visitTSqlFile(ctx: TSqlParser.TSqlFileContext): ir.LogicalPlan = { + + // This very top level visitor does not ignore any valid statements for the batch, instead + // we prepend any errors to the batch plan, so they are generated first in the output. + val errors = errorCheck(ctx) + val batchPlan = Option(ctx.batch()).map(buildBatch).getOrElse(Seq.empty) + errors match { + case Some(errorResult) => ir.Batch(errorResult +: batchPlan) + case None => ir.Batch(batchPlan) + } + } + + private def buildBatch(ctx: TSqlParser.BatchContext): Seq[LogicalPlan] = { + + // This very top level visitor does not ignore any valid statements for the batch, instead + // we prepend any errors to the batch plan, so they are generated first in the output. 
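// Editor's note: a minimal, standalone illustration (not part of this patch) of the error-prepending pattern
// used in visitTSqlFile and buildBatch: when error checking yields a result, it is emitted ahead of the valid
// statements rather than replacing them. The generic helper below is a hypothetical stand-in for that logic.
def withLeadingError[A](error: Option[A], statements: Seq[A]): Seq[A] =
  error.map(_ +: statements).getOrElse(statements)
// e.g. withLeadingError(Some("error"), Seq("stmt1", "stmt2")) == Seq("error", "stmt1", "stmt2")
//      withLeadingError(None, Seq("stmt1")) == Seq("stmt1")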
+ val errors = errorCheck(ctx) + val executeBodyBatchPlan = Option(ctx.executeBodyBatch()).map(_.accept(this)) + val sqlClausesPlans = ctx.sqlClauses().asScala.map(_.accept(this)).collect { case p: ir.LogicalPlan => p } + + val validStatements = executeBodyBatchPlan match { + case Some(plan) => plan +: sqlClausesPlans + case None => sqlClausesPlans + } + errors match { + case Some(errorResult) => errorResult +: validStatements + case None => validStatements + } + } + + // TODO: Stored procedure calls etc as batch start + override def visitExecuteBodyBatch(ctx: TSqlParser.ExecuteBodyBatchContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case templ if templ.jinjaTemplate() != null => templ.jinjaTemplate().accept(this) + case _ => + ir.UnresolvedRelation( + ruleText = contextText(ctx), + message = "Execute body batch is not supported yet", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + } + + override def visitJinjaTemplate(ctx: TSqlParser.JinjaTemplateContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.JinjaAsStatement(ctx.getText) + } + + override def visitSqlClauses(ctx: TSqlParser.SqlClausesContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case dml if dml.dmlClause() != null => dml.dmlClause().accept(this) + case cfl if cfl.cflStatement() != null => cfl.cflStatement().accept(this) + case another if another.anotherStatement() != null => another.anotherStatement().accept(this) + case ddl if ddl.ddlClause() != null => ddl.ddlClause().accept(vc.ddlBuilder) + case dbcc if dbcc.dbccClause() != null => dbcc.dbccClause().accept(this) + case backup if backup.backupStatement() != null => backup.backupStatement().accept(vc.ddlBuilder) + case coaFunction if coaFunction.createOrAlterFunction() != null => + coaFunction.createOrAlterFunction().accept(this) + case coaProcedure if coaProcedure.createOrAlterProcedure() != null => + coaProcedure.createOrAlterProcedure().accept(this) + case coaTrigger if coaTrigger.createOrAlterTrigger() != null => coaTrigger.createOrAlterTrigger().accept(this) + case cv if cv.createView() != null => cv.createView().accept(this) + case go if go.goStatement() != null => go.goStatement().accept(this) + case _ => + ir.UnresolvedRelation( + ruleText = contextText(ctx), + message = s"Unknown SQL clause ${ctx.getStart.getText} in TSqlAstBuilder.visitSqlClauses", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + } + + override def visitDmlClause(ctx: TSqlParser.DmlClauseContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val dml = ctx match { + case dml if dml.selectStatement() != null => + dml.selectStatement().accept(vc.relationBuilder) + case _ => + ctx.accept(vc.dmlBuilder) + } + + Option(ctx.withExpression()) + .map { withExpression => + val ctes = withExpression.commonTableExpression().asScala.map(_.accept(vc.relationBuilder)) + ir.WithCTE(ctes, dml) + } + .getOrElse(dml) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDDLBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDDLBuilder.scala new file mode 100644 index 0000000000..4edfe19f08 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDDLBuilder.scala @@ -0,0 +1,646 @@ +package 
com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.intermediate.Catalog +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.{intermediate => ir} + +import scala.collection.JavaConverters.asScalaBufferConverter + +class TSqlDDLBuilder(override val vc: TSqlVisitorCoordinator) + extends TSqlParserBaseVisitor[ir.Catalog] + with ParserCommon[ir.Catalog] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.Catalog = + ir.UnresolvedCatalog(ruleText = ruleText, message = message) + + // Concrete visitors + + override def visitCreateTable(ctx: TSqlParser.CreateTableContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case ci if ci.createInternal() != null => ci.createInternal().accept(this) + case ct if ct.createExternal() != null => ct.createExternal().accept(this) + case _ => + ir.UnresolvedCatalog( + ruleText = contextText(ctx), + message = "Unknown CREATE TABLE variant", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + } + + override def visitCreateInternal(ctx: TSqlParser.CreateInternalContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableName = ctx.tableName().getText + + val (columns, virtualColumns, constraints, indices) = Option(ctx.columnDefTableConstraints()).toSeq + .flatMap(_.columnDefTableConstraint().asScala) + .foldLeft((Seq.empty[TSqlColDef], Seq.empty[TSqlColDef], Seq.empty[ir.Constraint], Seq.empty[ir.Constraint])) { + case ((cols, virtualCols, cons, inds), constraint) => + val newCols = constraint.columnDefinition() match { + case null => cols + case columnDef => + cols :+ buildColumnDeclaration(columnDef) + } + + val newVirtualCols = constraint.computedColumnDefinition() match { + case null => virtualCols + case computedCol => + virtualCols :+ buildComputedColumn(computedCol) + } + + val newCons = constraint.tableConstraint() match { + case null => cons + case tableCons => cons :+ buildTableConstraint(tableCons) + } + + val newInds = constraint.tableIndices() match { + case null => inds + case tableInds => inds :+ buildIndex(tableInds) + } + + (newCols, newVirtualCols, newCons, newInds) + } + + // At this point we have all the columns, constraints and indices, so we can build the schema + val schema = ir.StructType((columns ++ virtualColumns).map(_.structField)) + + // Now we can build the create table statement or the create table as select statement + + val createTable = ctx.createTableAs() match { + case null => ir.CreateTable(tableName, None, None, None, schema) + case ctas if ctas.selectStatementStandalone() != null => + ir.CreateTableAsSelect( + tableName, + ctas.selectStatementStandalone().accept(vc.relationBuilder), + None, + None, + None) + case _ => + ir.UnresolvedCatalog( + ruleText = contextText(ctx), + message = "Unknown variant of CREATE TABLE is not yet supported", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + + // But because TSQL is so much more complicated than Databricks SQL, we need to build the table alterations + // in a wrapper above the raw create statement. 
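// Editor's note: a simplified, standalone sketch (not part of this patch) of the wrapping step described above:
// per-column constraints and options are gathered into maps keyed by column name so that a later generation phase
// can emit them around the plain CREATE TABLE. SketchColDef and collectPerColumn are hypothetical stand-ins, not
// the remorph IR.
case class SketchColDef(name: String, constraints: Seq[String], options: Seq[String])

def collectPerColumn(cols: Seq[SketchColDef]): (Map[String, Seq[String]], Map[String, Seq[String]]) =
  (cols.map(c => c.name -> c.constraints).toMap, cols.map(c => c.name -> c.options).toMap)
// e.g. collectPerColumn(Seq(SketchColDef("id", Seq("PRIMARY KEY"), Seq.empty)))._1 == Map("id" -> Seq("PRIMARY KEY"))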
+ + // First we want to iterate all the columns and build a map of all the column constraints where the key is the + // element structField.name and the value is the TSqlColDef.constraints + val columnConstraints = (columns ++ virtualColumns).map { colDef => + colDef.structField.name -> colDef.constraints + }.toMap + + // Next we create another map all options for each column where the key is the element structField.name and the + // value is the TSqlColDef.options + val columnOptions = (columns ++ virtualColumns).map { colDef => + colDef.structField.name -> colDef.options + }.toMap + + // And we want to collect any table level constraints that were generated in the TSqlColDef.tableConstraints + // by iterating columns and virtualColumns and gathering any TSqlColDef tableConstraints and + // creating a single Seq that also includes the constraints + // that were already accumulated above + val tableConstraints = constraints ++ (columns ++ virtualColumns).flatMap(_.tableConstraints) + + // We may have table level options as well as for each constraint and column + val options: Option[Seq[ir.GenericOption]] = Option(ctx.tableOptions).map { tableOptions => + tableOptions.asScala.flatMap { el => + el.tableOption().asScala.map(buildOption) + } + } + + val partitionOn = Option(ctx.onPartitionOrFilegroup()).map(_.getText) + + // Now we can build the table additions that wrap the primitive create table statement + createTable match { + case ct: ir.UnresolvedCatalog => + ct + case _ => + ir.CreateTableParams( + createTable, + columnConstraints, + columnOptions, + tableConstraints, + indices, + partitionOn, + options) + } + } + + override def visitCreateExternal(ctx: TSqlParser.CreateExternalContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.UnresolvedCatalog( + ruleText = contextText(ctx), + message = "CREATE EXTERNAL TABLE is not yet supported", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + + /** + * There seems to be no options given to TSQL: CREATE TABLE that make any sense in Databricks SQL, so we will just + * create them all as OptionUnresolved and store them as comments. If there turns out to be any compatibility, we can + * override the generation here. 
+ * @param ctx the parse tree + * @return the option we have parsed + */ + private def buildOption(ctx: TSqlParser.TableOptionContext): ir.GenericOption = { + ir.OptionUnresolved(contextText(ctx)) + } + + private case class TSqlColDef( + structField: ir.StructField, + computedValue: Option[ir.Expression], + constraints: Seq[ir.Constraint], + tableConstraints: Seq[ir.Constraint], + options: Seq[ir.GenericOption]) + + private def buildColumnDeclaration(ctx: TSqlParser.ColumnDefinitionContext): TSqlColDef = { + + val options = Seq.newBuilder[ir.GenericOption] + val constraints = Seq.newBuilder[ir.Constraint] + val tableConstraints = Seq.newBuilder[ir.Constraint] + var nullable: Option[Boolean] = None + + if (ctx.columnDefinitionElement() != null) { + ctx.columnDefinitionElement().asScala.foreach { + case rg if rg.ROWGUIDCOL() != null => + // ROWGUID is supported in Databricks SQL + options += ir.OptionOn("ROWGUIDCOL") + + case d if d.defaultValue() != null => + // Databricks SQL does not support the naming of the DEFAULT CONSTRAINT, so we will just use the default + // expression we are given, but if there is a name, we will store it as a comment + constraints += ir.DefaultValueConstraint(d.defaultValue().expression().accept(vc.expressionBuilder)) + if (d.defaultValue().id() != null) { + options += ir.OptionUnresolved( + s"Databricks SQL cannot name the DEFAULT CONSTRAINT ${d.defaultValue().id().getText}") + } + + case c if c.columnConstraint() != null => + // For some reason TSQL supports the naming of NOT NULL constraints, but if it is named + // we can generate a check constraint that is named to enforce the NOT NULL constraint. + if (c.columnConstraint().NULL() != null) { + nullable = if (c.columnConstraint().NOT() != null) { + + if (c.columnConstraint().id() != null) { + // If the nullable constraint is named, then we will generate a table level CHECK constraint + // to enforce it. + // So we will use true here so NOT NULL it is not specified in the column definition, then the + // table level CHECK can be named and therefore altered and dropped from the table as per TSQL. + tableConstraints += ir.NamedConstraint( + c.columnConstraint().id().getText, + ir.CheckConstraint(ir.IsNotNull(ir.Column(None, ir.Id(ctx.id().getText))))) + Some(true) + } else { + Some(false) + } + } else { + Some(true) + } + } else { + val con = buildColumnConstraint(ctx.id().getText, c.columnConstraint()) + + // TSQL allows FOREIGN KEY and CHECK constraints to be declared as column constraints, + // but Databricks SQL does not so we need to gather them as table constraints + + if (c.columnConstraint().FOREIGN() != null || c.columnConstraint().checkConstraint() != null) { + tableConstraints += con + } else { + constraints += con + } + } + + case d if d.identityColumn() != null => + // IDENTITY is supported in Databricks SQL but is done via GENERATED constraint + constraints += ir + .IdentityConstraint(Some(d.identityColumn().INT(0).getText), Some(d.identityColumn().INT(1).getText)) + + // Unsupported stuff + case m if m.MASKED() != null => + // MASKED WITH FUNCTION = 'functionName' is not supported in Databricks SQL + options += ir.OptionUnresolved(s"Unsupported Option: ${contextText(m)}") + + case f if f.ENCRYPTED() != null => + // ENCRYPTED WITH ... 
is not supported in Databricks SQL + options += ir.OptionUnresolved(s"Unsupported Option: ${contextText(f)}") + + case o if o.genericOption() != null => + options += ir.OptionUnresolved(s"Unsupported Option: ${contextText(o)}") + } + } + val dataType = vc.dataTypeBuilder.build(ctx.dataType()) + val sf = ir.StructField(ctx.id().getText, dataType, nullable.getOrElse(true)) + + // TODO: index options + + TSqlColDef(sf, None, constraints.result(), tableConstraints.result(), options.result()) + } + + /** + * builds a table constraint such as PRIMARY KEY, UNIQUE, FOREIGN KEY + */ + private def buildTableConstraint(ctx: TSqlParser.TableConstraintContext): ir.Constraint = { + + val options = Seq.newBuilder[ir.GenericOption] + + val constraint = ctx match { + + case pu if pu.PRIMARY() != null || pu.UNIQUE() != null => + if (pu.clustered() != null) { + if (pu.clustered().CLUSTERED() != null) { + options += ir.OptionUnresolved(contextText(pu.clustered())) + } + } + val colNames = ctx.columnNameListWithOrder().columnNameWithOrder().asScala.map { cnwo => + val colName = cnwo.id().getText + if (cnwo.DESC() != null || cnwo.ASC() != null) { + options += ir.OptionUnresolved(s"Cannot specify primary key order ASC/DESC on: $colName") + } + colName + } + options ++= buildPKOptions(pu.primaryKeyOptions()) + if (pu.PRIMARY() != null) { + ir.PrimaryKey(options.result(), Some(colNames)) + } else { + ir.Unique(options.result(), Some(colNames)) + } + + case fk if fk.FOREIGN() != null => + val refObject = fk.foreignKeyOptions().tableName().getText + val tableCols = fk.columnNameList().id().asScala.map(_.getText).mkString(", ") + val refCols = Option(fk.foreignKeyOptions()) + .map(_.columnNameList().id().asScala.map(_.getText).mkString(", ")) + .getOrElse("") + if (fk.foreignKeyOptions().onDelete() != null) { + options += buildFkOnDelete(fk.foreignKeyOptions().onDelete()) + } + if (fk.foreignKeyOptions().onUpdate() != null) { + options += buildFkOnUpdate(fk.foreignKeyOptions().onUpdate()) + } + ir.ForeignKey(tableCols, refObject, refCols, options.result()) + + case cc if cc.CONNECTION() != null => + // CONNECTION is not supported in Databricks SQL + ir.UnresolvedConstraint(contextText(ctx)) + + case defVal if defVal.DEFAULT() != null => + // DEFAULT is not supported in Databricks SQL at TABLE constraint level + ir.UnresolvedConstraint(contextText(ctx)) + + case cc if cc.checkConstraint() != null => + // Check constraint construction + val expr = cc.checkConstraint().searchCondition().accept(vc.expressionBuilder) + if (cc.checkConstraint().NOT() != null) { + options += ir.OptionUnresolved("NOT FOR REPLICATION") + } + ir.CheckConstraint(expr) + + case _ => ir.UnresolvedConstraint(contextText(ctx)) + } + + // Name the constraint if it is named and not unresolved + ctx.CONSTRAINT() match { + case null => constraint + case _ => + constraint match { + case _: ir.UnresolvedConstraint => constraint + case _ => ir.NamedConstraint(ctx.cid.getText, constraint) + } + } + + } + + /** + * Builds a column constraint such as PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK + * + * Note that TSQL is way more involved than Databricks SQL. We must record all the different options so that we can + * at least generate a comment. + * + * TSQL allows FOREIGN KEY and CHECK constraints to be declared as column constraints, but Databricks SQL does not + * So the caller needs to check for those circumstances and handle them accordingly. 
+ * + * @param ctx + * the parser context + * @return + * a constraint definition + */ + private def buildColumnConstraint(colName: String, ctx: TSqlParser.ColumnConstraintContext): ir.Constraint = { + val options = Seq.newBuilder[ir.GenericOption] + val constraint = ctx match { + case pu if pu.PRIMARY() != null || pu.UNIQUE() != null => + // Primary or unique key construction. + if (pu.clustered() != null) { + options += ir.OptionUnresolved(contextText(pu.clustered())) + } + + if (pu.primaryKeyOptions() != null) { + options ++= buildPKOptions(pu.primaryKeyOptions()) + } + + if (pu.PRIMARY() != null) { + ir.PrimaryKey(options.result()) + } else { + ir.Unique(options.result()) + } + + case fk if fk.FOREIGN() != null => + // Foreign key construction - note that this is a table level constraint in Databricks SQL + val refObject = fk.foreignKeyOptions().tableName().getText + val refCols = Option(fk.foreignKeyOptions()) + .map(_.columnNameList().id().asScala.map(_.getText).mkString(",")) + .getOrElse("") + if (fk.foreignKeyOptions().onDelete() != null) { + options += buildFkOnDelete(fk.foreignKeyOptions().onDelete()) + } + if (fk.foreignKeyOptions().onUpdate() != null) { + options += buildFkOnUpdate(fk.foreignKeyOptions().onUpdate()) + } + ir.ForeignKey(colName, refObject, refCols, options.result()) + + case cc if cc.checkConstraint() != null => + // Check constraint construction (will be gathered as a table level constraint) + val expr = cc.checkConstraint().searchCondition().accept(vc.expressionBuilder) + if (cc.checkConstraint().NOT() != null) { + options += ir.OptionUnresolved("NOT FOR REPLICATION") + } + ir.CheckConstraint(expr) + + case _ => ir.UnresolvedConstraint(contextText(ctx)) + } + + // Name the constraint if it is named and not unresolved + ctx.CONSTRAINT() match { + case null => constraint + case _ => + constraint match { + case _: ir.UnresolvedConstraint => constraint + case _ => ir.NamedConstraint(ctx.id.getText, constraint) + } + } + } + + private def buildFkOnDelete(ctx: TSqlParser.OnDeleteContext): ir.GenericOption = { + ctx match { + case c if c.CASCADE() != null => ir.OptionUnresolved("ON DELETE CASCADE") + case c if c.NULL() != null => ir.OptionUnresolved("ON DELETE SET NULL") + case c if c.DEFAULT() != null => ir.OptionUnresolved("ON DELETE SET DEFAULT") + case c if c.NO() != null => ir.OptionString("ON DELETE", "NO ACTION") + } + } + + private def buildFkOnUpdate(ctx: TSqlParser.OnUpdateContext): ir.GenericOption = { + ctx match { + case c if c.CASCADE() != null => ir.OptionUnresolved("ON UPDATE CASCADE") + case c if c.NULL() != null => ir.OptionUnresolved("ON UPDATE SET NULL") + case c if c.DEFAULT() != null => ir.OptionUnresolved("ON UPDATE SET DEFAULT") + case c if c.NO() != null => ir.OptionString("ON UPDATE", "NO ACTION") + } + } + + private def buildPKOptions(ctx: TSqlParser.PrimaryKeyOptionsContext): Seq[ir.GenericOption] = { + val options = Seq.newBuilder[ir.GenericOption] + if (ctx.FILLFACTOR() != null) { + options += ir.OptionUnresolved(s"WITH FILLFACTOR = ${ctx.getText}") + } + // TODO: index options + // TODO: partition options + options.result() + } + + private def buildComputedColumn(ctx: TSqlParser.ComputedColumnDefinitionContext): TSqlColDef = { + null + } + + /** + * Abstracted out here but Spark/Databricks SQL does not support indexes, as it is not a database and cannot reliably + * monitor data updates in external systems anyway. So it becomes an unresolved constraint here, but perhaps + * we do something moore with it later. 
+ * + * @param ctx the parse tree + * @return An unresolved constraint representing the index syntax + */ + private def buildIndex(ctx: TSqlParser.TableIndicesContext): ir.UnresolvedConstraint = { + ir.UnresolvedConstraint(contextText(ctx)) + } + + /** + * This is not actually implemented but was a quick way to exercise the genericOption builder before we had other + * syntax implemented to test it with. + * + * @param ctx + * the parse tree + */ + override def visitBackupStatement(ctx: TSqlParser.BackupStatementContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.backupDatabase().accept(this) + } + + override def visitBackupDatabase(ctx: TSqlParser.BackupDatabaseContext): ir.Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val database = ctx.id().getText + val opts = ctx.optionList() + val options = opts.asScala.flatMap(_.genericOption().asScala).toList.map(vc.optionBuilder.buildOption) + val (disks, boolFlags, autoFlags, values) = options.foldLeft( + (List.empty[String], Map.empty[String, Boolean], List.empty[String], Map.empty[String, ir.Expression])) { + case ((disks, boolFlags, autoFlags, values), option) => + option match { + case ir.OptionString("DISK", value) => + (value.stripPrefix("'").stripSuffix("'") :: disks, boolFlags, autoFlags, values) + case ir.OptionOn(id) => (disks, boolFlags + (id -> true), autoFlags, values) + case ir.OptionOff(id) => (disks, boolFlags + (id -> false), autoFlags, values) + case ir.OptionAuto(id) => (disks, boolFlags, id :: autoFlags, values) + case ir.OptionExpression(id, expr, _) => (disks, boolFlags, autoFlags, values + (id -> expr)) + case _ => (disks, boolFlags, autoFlags, values) + } + } + // Default flags generally don't need to be specified as they are by definition, the default + BackupDatabase(database, disks, boolFlags, autoFlags, values) + } + + override def visitDdlClause(ctx: TSqlParser.DdlClauseContext): Catalog = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case c if c.alterApplicationRole() != null => c.alterApplicationRole().accept(this) + case c if c.alterAssembly() != null => c.alterAssembly().accept(this) + case c if c.alterAsymmetricKey() != null => c.alterAsymmetricKey().accept(this) + case c if c.alterAuthorization() != null => c.alterAuthorization().accept(this) + case c if c.alterAvailabilityGroup() != null => c.alterAvailabilityGroup().accept(this) + case c if c.alterCertificate() != null => c.alterCertificate().accept(this) + case c if c.alterColumnEncryptionKey() != null => c.alterColumnEncryptionKey().accept(this) + case c if c.alterCredential() != null => c.alterCredential().accept(this) + case c if c.alterCryptographicProvider() != null => c.alterCryptographicProvider().accept(this) + case c if c.alterDatabase() != null => c.alterDatabase().accept(this) + case c if c.alterDatabaseAuditSpecification() != null => c.alterDatabaseAuditSpecification().accept(this) + case c if c.alterDbRole() != null => c.alterDbRole().accept(this) + case c if c.alterEndpoint() != null => c.alterEndpoint().accept(this) + case c if c.alterExternalDataSource() != null => c.alterExternalDataSource().accept(this) + case c if c.alterExternalLibrary() != null => c.alterExternalLibrary().accept(this) + case c if c.alterExternalResourcePool() != null => c.alterExternalResourcePool().accept(this) + case c if c.alterFulltextCatalog() != null => c.alterFulltextCatalog().accept(this) + case c if 
c.alterFulltextStoplist() != null => c.alterFulltextStoplist().accept(this) + case c if c.alterIndex() != null => c.alterIndex().accept(this) + case c if c.alterLoginAzureSql() != null => c.alterLoginAzureSql().accept(this) + case c if c.alterLoginAzureSqlDwAndPdw() != null => c.alterLoginAzureSqlDwAndPdw().accept(this) + case c if c.alterLoginSqlServer() != null => c.alterLoginSqlServer().accept(this) + case c if c.alterMasterKeyAzureSql() != null => c.alterMasterKeyAzureSql().accept(this) + case c if c.alterMasterKeySqlServer() != null => c.alterMasterKeySqlServer().accept(this) + case c if c.alterMessageType() != null => c.alterMessageType().accept(this) + case c if c.alterPartitionFunction() != null => c.alterPartitionFunction().accept(this) + case c if c.alterPartitionScheme() != null => c.alterPartitionScheme().accept(this) + case c if c.alterRemoteServiceBinding() != null => c.alterRemoteServiceBinding().accept(this) + case c if c.alterResourceGovernor() != null => c.alterResourceGovernor().accept(this) + case c if c.alterSchemaAzureSqlDwAndPdw() != null => c.alterSchemaAzureSqlDwAndPdw().accept(this) + case c if c.alterSchemaSql() != null => c.alterSchemaSql().accept(this) + case c if c.alterSequence() != null => c.alterSequence().accept(this) + case c if c.alterServerAudit() != null => c.alterServerAudit().accept(this) + case c if c.alterServerAuditSpecification() != null => c.alterServerAuditSpecification().accept(this) + case c if c.alterServerConfiguration() != null => c.alterServerConfiguration().accept(this) + case c if c.alterServerRole() != null => c.alterServerRole().accept(this) + case c if c.alterServerRolePdw() != null => c.alterServerRolePdw().accept(this) + case c if c.alterService() != null => c.alterService().accept(this) + case c if c.alterServiceMasterKey() != null => c.alterServiceMasterKey().accept(this) + case c if c.alterSymmetricKey() != null => c.alterSymmetricKey().accept(this) + case c if c.alterTable() != null => c.alterTable().accept(this) + case c if c.alterUser() != null => c.alterUser().accept(this) + case c if c.alterUserAzureSql() != null => c.alterUserAzureSql().accept(this) + case c if c.alterWorkloadGroup() != null => c.alterWorkloadGroup().accept(this) + case c if c.alterXmlSchemaCollection() != null => c.alterXmlSchemaCollection().accept(this) + case c if c.createApplicationRole() != null => c.createApplicationRole().accept(this) + case c if c.createAssembly() != null => c.createAssembly().accept(this) + case c if c.createAsymmetricKey() != null => c.createAsymmetricKey().accept(this) + case c if c.createColumnEncryptionKey() != null => c.createColumnEncryptionKey().accept(this) + case c if c.createColumnMasterKey() != null => c.createColumnMasterKey().accept(this) + case c if c.createColumnstoreIndex() != null => c.createColumnstoreIndex().accept(this) + case c if c.createCredential() != null => c.createCredential().accept(this) + case c if c.createCryptographicProvider() != null => c.createCryptographicProvider().accept(this) + case c if c.createDatabaseScopedCredential() != null => c.createDatabaseScopedCredential().accept(this) + case c if c.createDatabase() != null => c.createDatabase().accept(this) + case c if c.createDatabaseAuditSpecification() != null => c.createDatabaseAuditSpecification().accept(this) + case c if c.createDbRole() != null => c.createDbRole().accept(this) + case c if c.createEndpoint() != null => c.createEndpoint().accept(this) + case c if c.createEventNotification() != null => c.createEventNotification().accept(this) 
+ case c if c.createExternalLibrary() != null => c.createExternalLibrary().accept(this) + case c if c.createExternalResourcePool() != null => c.createExternalResourcePool().accept(this) + case c if c.createExternalDataSource() != null => c.createExternalDataSource().accept(this) + case c if c.createFulltextCatalog() != null => c.createFulltextCatalog().accept(this) + case c if c.createFulltextStoplist() != null => c.createFulltextStoplist().accept(this) + case c if c.createIndex() != null => c.createIndex().accept(this) + case c if c.createLoginAzureSql() != null => c.createLoginAzureSql().accept(this) + case c if c.createLoginPdw() != null => c.createLoginPdw().accept(this) + case c if c.createLoginSqlServer() != null => c.createLoginSqlServer().accept(this) + case c if c.createMasterKeyAzureSql() != null => c.createMasterKeyAzureSql().accept(this) + case c if c.createMasterKeySqlServer() != null => c.createMasterKeySqlServer().accept(this) + case c if c.createNonclusteredColumnstoreIndex() != null => c.createNonclusteredColumnstoreIndex().accept(this) + case c if c.createOrAlterBrokerPriority() != null => c.createOrAlterBrokerPriority().accept(this) + case c if c.createOrAlterEventSession() != null => c.createOrAlterEventSession().accept(this) + case c if c.createPartitionFunction() != null => c.createPartitionFunction().accept(this) + case c if c.createPartitionScheme() != null => c.createPartitionScheme().accept(this) + case c if c.createRemoteServiceBinding() != null => c.createRemoteServiceBinding().accept(this) + case c if c.createResourcePool() != null => c.createResourcePool().accept(this) + case c if c.createRoute() != null => c.createRoute().accept(this) + case c if c.createRule() != null => c.createRule().accept(this) + case c if c.createSchema() != null => c.createSchema().accept(this) + case c if c.createSchemaAzureSqlDwAndPdw() != null => c.createSchemaAzureSqlDwAndPdw().accept(this) + case c if c.createSearchPropertyList() != null => c.createSearchPropertyList().accept(this) + case c if c.createSecurityPolicy() != null => c.createSecurityPolicy().accept(this) + case c if c.createSequence() != null => c.createSequence().accept(this) + case c if c.createServerAudit() != null => c.createServerAudit().accept(this) + case c if c.createServerAuditSpecification() != null => c.createServerAuditSpecification().accept(this) + case c if c.createServerRole() != null => c.createServerRole().accept(this) + case c if c.createService() != null => c.createService().accept(this) + case c if c.createStatistics() != null => c.createStatistics().accept(this) + case c if c.createSynonym() != null => c.createSynonym().accept(this) + case c if c.createTable() != null => c.createTable().accept(this) + case c if c.createType() != null => c.createType().accept(this) + case c if c.createUser() != null => c.createUser().accept(this) + case c if c.createUserAzureSqlDw() != null => c.createUserAzureSqlDw().accept(this) + case c if c.createWorkloadGroup() != null => c.createWorkloadGroup().accept(this) + case c if c.createXmlIndex() != null => c.createXmlIndex().accept(this) + case c if c.createXmlSchemaCollection() != null => c.createXmlSchemaCollection().accept(this) + case c if c.triggerDisEn() != null => c.triggerDisEn().accept(this) + case c if c.dropAggregate() != null => c.dropAggregate().accept(this) + case c if c.dropApplicationRole() != null => c.dropApplicationRole().accept(this) + case c if c.dropAssembly() != null => c.dropAssembly().accept(this) + case c if c.dropAsymmetricKey() != null => 
c.dropAsymmetricKey().accept(this) + case c if c.dropAvailabilityGroup() != null => c.dropAvailabilityGroup().accept(this) + case c if c.dropBrokerPriority() != null => c.dropBrokerPriority().accept(this) + case c if c.dropCertificate() != null => c.dropCertificate().accept(this) + case c if c.dropColumnEncryptionKey() != null => c.dropColumnEncryptionKey().accept(this) + case c if c.dropColumnMasterKey() != null => c.dropColumnMasterKey().accept(this) + case c if c.dropContract() != null => c.dropContract().accept(this) + case c if c.dropCredential() != null => c.dropCredential().accept(this) + case c if c.dropCryptograhicProvider() != null => c.dropCryptograhicProvider().accept(this) + case c if c.dropDatabase() != null => c.dropDatabase().accept(this) + case c if c.dropDatabaseAuditSpecification() != null => c.dropDatabaseAuditSpecification().accept(this) + case c if c.dropDatabaseEncryptionKey() != null => c.dropDatabaseEncryptionKey().accept(this) + case c if c.dropDatabaseScopedCredential() != null => c.dropDatabaseScopedCredential().accept(this) + case c if c.dropDbRole() != null => c.dropDbRole().accept(this) + case c if c.dropDefault() != null => c.dropDefault().accept(this) + case c if c.dropEndpoint() != null => c.dropEndpoint().accept(this) + case c if c.dropEventNotifications() != null => c.dropEventNotifications().accept(this) + case c if c.dropEventSession() != null => c.dropEventSession().accept(this) + case c if c.dropExternalDataSource() != null => c.dropExternalDataSource().accept(this) + case c if c.dropExternalFileFormat() != null => c.dropExternalFileFormat().accept(this) + case c if c.dropExternalLibrary() != null => c.dropExternalLibrary().accept(this) + case c if c.dropExternalResourcePool() != null => c.dropExternalResourcePool().accept(this) + case c if c.dropExternalTable() != null => c.dropExternalTable().accept(this) + case c if c.dropFulltextCatalog() != null => c.dropFulltextCatalog().accept(this) + case c if c.dropFulltextIndex() != null => c.dropFulltextIndex().accept(this) + case c if c.dropFulltextStoplist() != null => c.dropFulltextStoplist().accept(this) + case c if c.dropFunction() != null => c.dropFunction().accept(this) + case c if c.dropIndex() != null => c.dropIndex().accept(this) + case c if c.dropLogin() != null => c.dropLogin().accept(this) + case c if c.dropMasterKey() != null => c.dropMasterKey().accept(this) + case c if c.dropMessageType() != null => c.dropMessageType().accept(this) + case c if c.dropPartitionFunction() != null => c.dropPartitionFunction().accept(this) + case c if c.dropPartitionScheme() != null => c.dropPartitionScheme().accept(this) + case c if c.dropProcedure() != null => c.dropProcedure().accept(this) + case c if c.dropQueue() != null => c.dropQueue().accept(this) + case c if c.dropRemoteServiceBinding() != null => c.dropRemoteServiceBinding().accept(this) + case c if c.dropResourcePool() != null => c.dropResourcePool().accept(this) + case c if c.dropRoute() != null => c.dropRoute().accept(this) + case c if c.dropRule() != null => c.dropRule().accept(this) + case c if c.dropSchema() != null => c.dropSchema().accept(this) + case c if c.dropSearchPropertyList() != null => c.dropSearchPropertyList().accept(this) + case c if c.dropSecurityPolicy() != null => c.dropSecurityPolicy().accept(this) + case c if c.dropSequence() != null => c.dropSequence().accept(this) + case c if c.dropServerAudit() != null => c.dropServerAudit().accept(this) + case c if c.dropServerAuditSpecification() != null => 
c.dropServerAuditSpecification().accept(this) + case c if c.dropServerRole() != null => c.dropServerRole().accept(this) + case c if c.dropService() != null => c.dropService().accept(this) + case c if c.dropSignature() != null => c.dropSignature().accept(this) + case c if c.dropStatistics() != null => c.dropStatistics().accept(this) + case c if c.dropStatisticsNameAzureDwAndPdw() != null => c.dropStatisticsNameAzureDwAndPdw().accept(this) + case c if c.dropSymmetricKey() != null => c.dropSymmetricKey().accept(this) + case c if c.dropSynonym() != null => c.dropSynonym().accept(this) + case c if c.dropTable() != null => c.dropTable().accept(this) + case c if c.dropTrigger() != null => c.dropTrigger().accept(this) + case c if c.dropType() != null => c.dropType().accept(this) + case c if c.dropUser() != null => c.dropUser().accept(this) + case c if c.dropView() != null => c.dropView().accept(this) + case c if c.dropWorkloadGroup() != null => c.dropWorkloadGroup().accept(this) + case c if c.dropXmlSchemaCollection() != null => c.dropXmlSchemaCollection().accept(this) + case c if c.triggerDisEn() != null => c.triggerDisEn().accept(this) + case c if c.lockTable() != null => c.lockTable().accept(this) + case c if c.truncateTable() != null => c.truncateTable().accept(this) + case c if c.updateStatistics() != null => c.updateStatistics().accept(this) + case _ => + ir.UnresolvedCatalog( + ruleText = contextText(ctx), + message = "Unknown DDL clause", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDMLBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDMLBuilder.scala new file mode 100644 index 0000000000..99e5d41fc6 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDMLBuilder.scala @@ -0,0 +1,219 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.tsql.TSqlParser.{StringContext => _, _} +import com.databricks.labs.remorph.parsers.tsql.rules.InsertDefaultsAction +import com.databricks.labs.remorph.{intermediate => ir} + +import scala.collection.JavaConverters.asScalaBufferConverter +class TSqlDMLBuilder(override val vc: TSqlVisitorCoordinator) + extends TSqlParserBaseVisitor[ir.Modification] + with ParserCommon[ir.Modification] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. 
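+ // For example, a DML statement we do not yet have a visitor for comes back as an ir.UnresolvedModification + // carrying the original rule text and a diagnostic message, so later stages can report it rather than drop it.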
+ protected override def unresolved(ruleText: String, message: String): ir.Modification = + ir.UnresolvedModification(ruleText = ruleText, message = message) + + // Concrete visitors + + override def visitDmlClause(ctx: DmlClauseContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + // NB: select is handled by the relationBuilder + case dml if dml.insert() != null => dml.insert.accept(this) + case dml if dml.delete() != null => dml.delete().accept(this) + case dml if dml.merge() != null => dml.merge().accept(this) + case dml if dml.update() != null => dml.update().accept(this) + case bulk if bulk.bulkStatement() != null => bulk.bulkStatement().accept(this) + case _ => + ir.UnresolvedModification( + ruleText = contextText(ctx), + message = s"Unknown DML clause ${ctx.getStart.getText} in TSqlDMLBuilder.visitDmlClause", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + } + + override def visitMerge(ctx: MergeContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val targetPlan = ctx.ddlObject().accept(vc.relationBuilder) + val hints = vc.relationBuilder.buildTableHints(Option(ctx.withTableHints())) + val finalTarget = if (hints.nonEmpty) { + ir.TableWithHints(targetPlan, hints) + } else { + targetPlan + } + + val mergeCondition = ctx.searchCondition().accept(vc.expressionBuilder) + val tableSourcesPlan = ctx.tableSources().tableSource().asScala.map(_.accept(vc.relationBuilder)) + // Reduce is safe: Grammar rule for tableSources ensures that there is always at least one tableSource. + val sourcePlan = tableSourcesPlan.reduceLeft( + ir.Join(_, _, None, ir.CrossJoin, Seq(), ir.JoinDataType(is_left_struct = false, is_right_struct = false))) + + // We may have a number of when clauses, each with a condition and an action. 
We keep the ANTLR syntax compact + // and lean and determine which of the three types of action we have in the whenMatch method based on + // the presence or absence of syntactical elements NOT and SOURCE as SOURCE can only be used with NOT + val (matchedActions, notMatchedActions, notMatchedBySourceActions) = Option(ctx.whenMatch()) + .map(_.asScala.foldLeft((List.empty[ir.MergeAction], List.empty[ir.MergeAction], List.empty[ir.MergeAction])) { + case ((matched, notMatched, notMatchedBySource), m) => + val action = buildWhenMatch(m) + (m.NOT(), m.SOURCE()) match { + case (null, _) => (action :: matched, notMatched, notMatchedBySource) + case (_, null) => (matched, action :: notMatched, notMatchedBySource) + case _ => (matched, notMatched, action :: notMatchedBySource) + } + }) + .getOrElse((List.empty, List.empty, List.empty)) + + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + val outputClause = Option(ctx.outputClause()).map(buildOutputClause) + + val mergeIntoTable = ir.MergeIntoTable( + finalTarget, + sourcePlan, + mergeCondition, + matchedActions, + notMatchedActions, + notMatchedBySourceActions) + + val withOptions = optionClause match { + case Some(option) => ir.WithModificationOptions(mergeIntoTable, option) + case None => mergeIntoTable + } + + outputClause match { + case Some(output) => WithOutputClause(withOptions, output) + case None => withOptions + } + } + + private def buildWhenMatch(ctx: WhenMatchContext): ir.MergeAction = { + val condition = Option(ctx.searchCondition()).map(_.accept(vc.expressionBuilder)) + ctx.mergeAction() match { + case action if action.DELETE() != null => ir.DeleteAction(condition) + case action if action.UPDATE() != null => buildUpdateAction(action, condition) + case action if action.INSERT() != null => buildInsertAction(action, condition) + } + } + + private def buildInsertAction(ctx: MergeActionContext, condition: Option[ir.Expression]): ir.MergeAction = { + + ctx match { + case action if action.DEFAULT() != null => InsertDefaultsAction(condition) + case _ => + val assignments = + (ctx.cols + .expression() + .asScala + .map(_.accept(vc.expressionBuilder)) zip ctx.vals.expression().asScala.map(_.accept(vc.expressionBuilder))) + .map { case (col, value) => + ir.Assign(col, value) + } + ir.InsertAction(condition, assignments) + } + } + + private def buildUpdateAction(ctx: MergeActionContext, condition: Option[ir.Expression]): ir.UpdateAction = { + val setElements = ctx.updateElem().asScala.collect { case elem => + elem.accept(vc.expressionBuilder) match { + case assign: ir.Assign => assign + } + } + ir.UpdateAction(condition, setElements) + } + + override def visitUpdate(ctx: UpdateContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.ddlObject().accept(vc.relationBuilder) + val hints = vc.relationBuilder.buildTableHints(Option(ctx.withTableHints())) + val hintTarget = if (hints.nonEmpty) { + ir.TableWithHints(target, hints) + } else { + target + } + + val finalTarget = vc.relationBuilder.buildTop(Option(ctx.topClause()), hintTarget) + val output = Option(ctx.outputClause()).map(buildOutputClause) + val setElements = ctx.updateElem().asScala.map(_.accept(vc.expressionBuilder)) + + val sourceRelation = buildTableSourcesPlan(Option(ctx.tableSources())) + val where = Option(ctx.updateWhereClause()) map (_.accept(vc.expressionBuilder)) + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + ir.UpdateTable(finalTarget, 
sourceRelation, setElements, where, output, optionClause) + } + + override def visitDelete(ctx: DeleteContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.ddlObject().accept(vc.relationBuilder) + val hints = vc.relationBuilder.buildTableHints(Option(ctx.withTableHints())) + val finalTarget = if (hints.nonEmpty) { + ir.TableWithHints(target, hints) + } else { + target + } + + val output = Option(ctx.outputClause()).map(buildOutputClause) + val sourceRelation = buildTableSourcesPlan(Option(ctx.tableSources())) + val where = Option(ctx.updateWhereClause()) map (_.accept(vc.expressionBuilder)) + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + ir.DeleteFromTable(finalTarget, sourceRelation, where, output, optionClause) + } + + private[this] def buildTableSourcesPlan(tableSources: Option[TableSourcesContext]): Option[ir.LogicalPlan] = { + val sources = tableSources + .map(_.tableSource().asScala) + .getOrElse(Seq()) + .map(_.accept(vc.relationBuilder)) + sources.reduceLeftOption( + ir.Join(_, _, None, ir.CrossJoin, Seq(), ir.JoinDataType(is_left_struct = false, is_right_struct = false))) + } + + override def visitInsert(ctx: InsertContext): ir.Modification = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.ddlObject().accept(vc.relationBuilder) + val hints = vc.relationBuilder.buildTableHints(Option(ctx.withTableHints())) + val finalTarget = if (hints.nonEmpty) { + ir.TableWithHints(target, hints) + } else { + target + } + + val columns = Option(ctx.expressionList()) + .map(_.expression().asScala.map(_.accept(vc.expressionBuilder)).collect { case col: ir.Column => + col.columnName + }) + + val output = Option(ctx.outputClause()).map(buildOutputClause) + val values = buildInsertStatementValue(ctx.insertStatementValue()) + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + ir.InsertIntoTable(finalTarget, columns, values, output, optionClause) + } + + private def buildInsertStatementValue(ctx: InsertStatementValueContext): ir.LogicalPlan = { + Option(ctx) match { + case Some(context) if context.derivedTable() != null => context.derivedTable().accept(vc.relationBuilder) + case Some(context) if context.VALUES() != null => DefaultValues() + case Some(context) => context.executeStatement().accept(vc.relationBuilder) + } + } + + private def buildOutputClause(ctx: OutputClauseContext): Output = { + val outputs = ctx.outputDmlListElem().asScala.map(_.accept(vc.expressionBuilder)) + val target = Option(ctx.ddlObject()).map(_.accept(vc.relationBuilder)) + val columns = + Option(ctx.columnNameList()) + .map(_.id().asScala.map(id => ir.Column(None, vc.expressionBuilder.buildId(id)))) + + // Databricks SQL does not support the OUTPUT clause, but we may be able to translate + // the clause to SELECT statements executed before or after the INSERT/DELETE/UPDATE/MERGE + // is executed + Output(target, outputs, columns) + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorStrategy.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorStrategy.scala new file mode 100644 index 0000000000..58b364387b --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorStrategy.scala @@ -0,0 +1,965 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.SqlErrorStrategy +import 
com.databricks.labs.remorph.parsers.tsql.TSqlParser.{StringContext => _, _} +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.misc.IntervalSet + +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +/** + * Custom error strategy for SQL parsing

+ * While we do not do anything super special here, we wish to override a couple of the message generating methods + * and the token insert and delete messages, which do not create an exception and don't allow us to create an error + * message in context. Additionally, we can now implement i18n, should that ever become necessary. + * + * At the moment, we require valid SQL as input to the conversion process, but if we ever change that strategy, then + * we can implement custom recovery steps here based upon context, though there is no improvement on the sync() + * method.

+ */ +class TSqlErrorStrategy extends SqlErrorStrategy { + + /** + * Generate a message for the error. + * + * The exception contains a stack trace, from which we can construct a more informative error message than just + * mismatched child and a huge list of things we were looking for. + * + * @param e + * the RecognitionException + * @return + * the error message + */ + override protected def generateMessage(recognizer: Parser, e: RecognitionException): String = { + val messages = new ListBuffer[String]() + + // We build the messages by looking at the stack trace of the exception, but if the + // rule translation is not found, or it is the same as the previous message, we skip it, + // to avoid repeating the same message multiple times. This is because a recognition error + // could be found in a parent rule or a child rule but there is no extra information + // provided by being more specific about the rule location. Also, in some productions + // we may be embedded very deeply in the stack trace, so we want to avoid too many contexts + // in a message. + e.getStackTrace.foreach { traceElement => + val methodName = traceElement.getMethodName + TSqlErrorStrategy.ruleTranslation.get(methodName).foreach { translatedMessage => + // Only mention a batch if we have recovered all the way to the top rule + val shouldAppend = if (methodName == "tSqlFile") { + messages.isEmpty + } else { + messages.isEmpty || messages.last != translatedMessage + } + + if (shouldAppend) { + messages.append(translatedMessage) + } + } + } + + if (messages.nonEmpty) { + val initialMessage = s"while parsing ${articleFor(messages.head)} ${messages.head}" + messages.drop(1).foldLeft(initialMessage) { (acc, message) => + s"$acc in ${articleFor(message)} $message" + } + } else "" + } + + private[this] val vowels = Set('a', 'e', 'i', 'o', 'u') + + def articleFor(word: String): String = { + if (word.nonEmpty && vowels.contains(word.head.toLower)) "an" else "a" + } + + /** + * When building the list of expected tokens, we do some custom manipulation so that we do not produce a list of 750 + * possible tokens because there are so many keywords that can be used as id/column names. If ID is a valid expected + * token, then we remove all the keywords that are there because they can be an ID. + * @param expected + * the set of valid tokens at this point in the parse, where the error was found + * @return + * the expected string with tokens renamed in more human friendly form + */ + override protected def buildExpectedMessage(recognizer: Parser, expected: IntervalSet): String = { + val expect = if (expected.contains(ID)) { + removeIdKeywords(expected) + } else { + expected + } + + val uniqueExpectedTokens = mutable.Set[String]() + + // Iterate through the expected tokens + expect.toList.foreach { tokenId => + // Check if the token ID has a custom translation + val tokenString = TSqlErrorStrategy.tokenTranslation.get(tokenId) match { + case Some(translatedName) => translatedName + case None => recognizer.getVocabulary.getDisplayName(tokenId) + } + uniqueExpectedTokens += tokenString + } + + // Join the unique expected token strings with a comma and space for readability + // but only take the first 12 tokens to avoid a huge list of expected tokens + if (uniqueExpectedTokens.size <= 12) { + return uniqueExpectedTokens.toSeq.sorted(capitalizedSort).mkString(", ") + } + uniqueExpectedTokens.toSeq.sorted(capitalizedSort).take(12).mkString(", ") + "..." 
+ } + + /** + * Runs through the given interval and removes all the keywords that are in the set. + * @param set + * The interval from whence to remove keywords that can be Identifiers + */ + private def removeIdKeywords(set: IntervalSet): IntervalSet = { + set.subtract(TSqlErrorStrategy.keywordIDs) + } +} + +object TSqlErrorStrategy { + + // A map that will override teh default display name for tokens that represent text with + // pattern matches like IDENTIFIER, STRING, etc. + private[TSqlErrorStrategy] val tokenTranslation: Map[Int, String] = Map( + AAPSEUDO -> "@@Reference", + DOUBLE_QUOTE_ID -> "Identifier", + FLOAT -> "Float", + ID -> "Identifier", + INT -> "Integer", + LOCAL_ID -> "@Local", + MONEY -> "$Currency", + REAL -> "Real", + SQUARE_BRACKET_ID -> "Identifier", + STRING -> "'String'", + TEMP_ID -> "Identifier", + -1 -> "End of batch", + JINJA_REF -> "Jinja Template Element", + + // When the next thing we expect can be every statement, we just say "statement" + ALTER -> "Statement", + BACKUP -> "Statement", + BEGIN -> "Statement", + BREAK -> "Statement", + CHECKPOINT -> "Statement", + CLOSE -> "Statement", + COMMIT -> "Statement", + CONTINUE -> "Statement", + CREATE -> "Statement", + DEALLOCATE -> "Statement", + DECLARE -> "Statement", + DELETE -> "Statement", + DROP -> "Statement", + END -> "Statement", + EXECUTE -> "Statement", + EXECUTE -> "Statement", + FETCH -> "Statement", + GOTO -> "Statement", + GRANT -> "Statement", + IF -> "Statement", + INSERT -> "Statement", + KILL -> "Statement", + MERGE -> "Statement", + OPEN -> "Statement", + PRINT -> "Statement", + RAISERROR -> "Statement", + RECONFIGURE -> "Statement", + RETURN -> "Statement", + REVERT -> "Statement", + ROLLBACK -> "Statement", + SAVE -> "Statement", + SELECT -> "Select Statement", + SET -> "Statement", + SETUSER -> "Statement", + SHUTDOWN -> "Statement", + TRUNCATE -> "Statement", + UPDATE -> "Statement", + USE -> "Statement", + WAITFOR -> "Statement", + WHILE -> "Statement", + WITHIN -> "Statement", + + // No need to distinguish between operators + + AND_ASSIGN -> "Assignment Operator", + OR_ASSIGN -> "Assignment Operator", + XOR_ASSIGN -> "Assignment Operator", + BANG -> "Operator", + BIT_AND -> "Operator", + BIT_NOT -> "Operator", + BIT_OR -> "Operator", + BIT_XOR -> "Operator", + DE -> "Operator", + DIV -> "Operator", + DOUBLE_BAR -> "Operator", + EQ -> "Operator", + GT -> "Operator", + LT -> "Operator", + ME -> "Operator", + MEA -> "Operator", + MINUS -> "Operator", + MOD -> "Operator", + PE -> "Operator", + PLUS -> "Operator", + SE -> "Operator") + + private[TSqlErrorStrategy] val ruleTranslation: Map[String, String] = Map( + "tSqlFile" -> "T-SQL batch", + "executeBodyBatch" -> "Stored procedure call", + "jingjaTemplate" -> "Jinja template element", + "selectStatement" -> "SELECT statement", + "selectStatementStandalone" -> "SELECT statement", + "selectList" -> "SELECT list", + "selectListElement" -> "SELECT list element", + "selectItem" -> "SELECT item", + "fromClause" -> "FROM clause", + "whereClause" -> "WHERE clause", + "groupByClause" -> "GROUP BY clause", + "havingClause" -> "HAVING clause", + "orderByClause" -> "ORDER BY clause", + "limitClause" -> "LIMIT clause", + "offsetClause" -> "OFFSET clause", + "joinClause" -> "JOIN clause", + "joinCondition" -> "JOIN condition", + "joinType" -> "JOIN type", + "joinOn" -> "JOIN ON", + "joinUsing" -> "JOIN USING", + "joinTable" -> "JOIN table", + "joinAlias" -> "JOIN alias", + "joinColumn" -> "JOIN column", + "joinExpression" -> "JOIN expression", + 
"joinOperator" -> "JOIN operator", + "joinSubquery" -> "JOIN subquery", + "joinSubqueryAlias" -> "JOIN subquery alias", + "joinSubqueryColumn" -> "JOIN subquery column", + "joinSubqueryExpression" -> "JOIN subquery expression", + "joinSubqueryOperator" -> "JOIN subquery operator", + "joinSubqueryTable" -> "JOIN subquery table", + "joinSubqueryTableAlias" -> "JOIN subquery table alias", + "joinSubqueryTableColumn" -> "JOIN subquery table column", + "joinSubqueryTableExpression" -> "JOIN subquery table expression", + "joinSubqueryTableOperator" -> "JOIN subquery table operator", + "joinSubqueryTableSubquery" -> "JOIN subquery table subquery", + "joinSubqueryTableSubqueryAlias" -> "JOIN subquery table subquery alias", + "joinSubqueryTableSubqueryColumn" -> "JOIN subquery table subquery column", + "joinSubqueryTableSubqueryExpression" -> "JOIN subquery table subquery expression", + "joinSubqueryTableSubqueryOperator" -> "JOIN subquery table subquery operator", + "joinSubqueryTableSubqueryTable" -> "JOIN subquery table subquery table", + "updateStatement" -> "UPDATE statement", + "update" -> "UPDATE statement", + "topClause" -> "TOP clause", + "ddlObject" -> "TABLE object", + "withTableHints" -> "WITH table hints", + "updateElem" -> "UPDATE element specification", + "outputClause" -> "OUTPUT clause", + "updateWhereClause" -> "WHERE clause", + "optionClause" -> "OPTION clause", + + // Etc + "tableSource" -> "table source", + "tableSourceItem" -> "table source") + + private[TSqlErrorStrategy] val keywordIDs: IntervalSet = new IntervalSet( + ABORT, + ABORT_AFTER_WAIT, + ABSENT, + ABSOLUTE, + ACCENT_SENSITIVITY, + ACCESS, + ACTION, + ACTIVATION, + ACTIVE, + ADD, + ADDRESS, + ADMINISTER, + AES, + AES_128, + AES_192, + AES_256, + AFFINITY, + AFTER, + AGGREGATE, + ALGORITHM, + ALL_CONSTRAINTS, + ALL_ERRORMSGS, + ALL_INDEXES, + ALL_LEVELS, + ALLOW_CONNECTIONS, + ALLOW_ENCRYPTED_VALUE_MODIFICATIONS, + ALLOW_MULTIPLE_EVENT_LOSS, + ALLOW_PAGE_LOCKS, + ALLOW_ROW_LOCKS, + ALLOW_SINGLE_EVENT_LOSS, + ALLOW_SNAPSHOT_ISOLATION, + ALLOWED, + ALWAYS, + ANONYMOUS, + ANSI_DEFAULTS, + ANSI_NULL_DEFAULT, + ANSI_NULL_DFLT_OFF, + ANSI_NULL_DFLT_ON, + ANSI_NULLS, + ANSI_PADDING, + ANSI_WARNINGS, + APPEND, + APPLICATION, + APPLICATION_LOG, + APPLY, + ARITHABORT, + ARITHIGNORE, + ASSEMBLY, + ASYMMETRIC, + ASYNCHRONOUS_COMMIT, + AT_KEYWORD, + AUDIT, + AUDIT_GUID, + AUTHENTICATE, + AUTHENTICATION, + AUTO, + AUTO_CLEANUP, + AUTO_CLOSE, + AUTO_CREATE_STATISTICS, + AUTO_DROP, + AUTO_SHRINK, + AUTO_UPDATE_STATISTICS, + AUTO_UPDATE_STATISTICS_ASYNC, + AUTOGROW_ALL_FILES, + AUTOGROW_SINGLE_FILE, + AUTOMATED_BACKUP_PREFERENCE, + AUTOMATIC, + AVAILABILITY, + AVAILABILITY_MODE, + BACKUP_CLONEDB, + BACKUP_PRIORITY, + BEFORE, + BEGIN_DIALOG, + BINARY, + BINDING, + BLOB_STORAGE, + BLOCK, + BLOCKERS, + BLOCKSIZE, + BROKER, + BROKER_INSTANCE, + BUFFER, + BUFFERCOUNT, + BULK_LOGGED, + CACHE, + CALLED, + CALLER, + CAP_CPU_PERCENT, + CAST, + CATALOG, + CATCH, + CERTIFICATE, + CHANGE, + CHANGE_RETENTION, + CHANGE_TRACKING, + CHANGES, + CHANGETABLE, + CHECK_EXPIRATION, + CHECK_POLICY, + CHECKALLOC, + CHECKCATALOG, + CHECKCONSTRAINTS, + CHECKDB, + CHECKFILEGROUP, + CHECKSUM, + CHECKTABLE, + CLASSIFIER_FUNCTION, + CLEANTABLE, + CLEANUP, + CLONEDATABASE, + CLUSTER, + COLLECTION, + COLUMN_ENCRYPTION_KEY, + COLUMN_MASTER_KEY, + COLUMNS, + COLUMNSTORE, + COLUMNSTORE_ARCHIVE, + COMMITTED, + COMPATIBILITY_LEVEL, + COMPRESS_ALL_ROW_GROUPS, + COMPRESSION, + COMPRESSION_DELAY, + CONCAT, + CONCAT_NULL_YIELDS_NULL, + CONFIGURATION, + CONNECT, + 
CONNECTION, + CONTAINMENT, + CONTENT, + CONTEXT, + CONTINUE_AFTER_ERROR, + CONTRACT, + CONTRACT_NAME, + CONTROL, + CONVERSATION, + COOKIE, + COPY_ONLY, + COUNTER, + CPU, + CREATE_NEW, + CREATION_DISPOSITION, + CREDENTIAL, + CRYPTOGRAPHIC, + CURSOR_CLOSE_ON_COMMIT, + CURSOR_DEFAULT, + CYCLE, + DATA, + DATA_COMPRESSION, + DATA_PURITY, + DATA_SOURCE, + DATABASE_MIRRORING, + DATASPACE, + DATE_CORRELATION_OPTIMIZATION, + DAYS, + DB_CHAINING, + DB_FAILOVER, + DBCC, + DBREINDEX, + DDL, + DECRYPTION, + DEFAULT, + DEFAULT_DATABASE, + DEFAULT_DOUBLE_QUOTE, + DEFAULT_FULLTEXT_LANGUAGE, + DEFAULT_LANGUAGE, + DEFAULT_SCHEMA, + DEFINITION, + DELAY, + DELAYED_DURABILITY, + DELETED, + DEPENDENTS, + DES, + DESCRIPTION, + DESX, + DETERMINISTIC, + DHCP, + DIAGNOSTICS, + DIALOG, + DIFFERENTIAL, + DIRECTORY_NAME, + DISABLE, + DISABLE_BROKER, + DISABLED, + DISTRIBUTION, + DOCUMENT, + DROP_EXISTING, + DROPCLEANBUFFERS, + DTC_SUPPORT, + DYNAMIC, + ELEMENTS, + EMERGENCY, + EMPTY, + ENABLE, + ENABLE_BROKER, + ENABLED, + ENCRYPTED, + ENCRYPTED_VALUE, + ENCRYPTION, + ENCRYPTION_TYPE, + ENDPOINT, + ENDPOINT_URL, + ERROR, + ERROR_BROKER_CONVERSATIONS, + ESTIMATEONLY, + EVENT, + EVENT_RETENTION_MODE, + EXCLUSIVE, + EXECUTABLE, + EXECUTABLE_FILE, + EXPIREDATE, + EXPIRY_DATE, + EXPLICIT, + EXTENDED_LOGICAL_CHECKS, + EXTENSION, + EXTERNAL_ACCESS, + FAIL_OPERATION, + FAILOVER, + FAILOVER_MODE, + FAILURE, + FAILURE_CONDITION_LEVEL, + FAILURECONDITIONLEVEL, + FAN_IN, + FAST_FORWARD, + FILE_SNAPSHOT, + FILEGROUP, + FILEGROWTH, + FILENAME, + FILEPATH, + FILESTREAM, + FILESTREAM_ON, + FILTER, + FIRST, + FMTONLY, + FOLLOWING, + FOR, + FORCE, + FORCE_FAILOVER_ALLOW_DATA_LOSS, + FORCE_SERVICE_ALLOW_DATA_LOSS, + FORCEPLAN, + FORCESCAN, + FORCESEEK, + FORMAT, + FORWARD_ONLY, + FREE, + FULLSCAN, + FULLTEXT, + GB, + GENERATED, + GET, + GETROOT, + GLOBAL, + GO, + GOVERNOR, + GROUP_MAX_REQUESTS, + GROUPING, + HADR, + HASH, + HASHED, + HEALTH_CHECK_TIMEOUT, + HEALTHCHECKTIMEOUT, + HEAP, + HIDDEN_KEYWORD, + HIERARCHYID, + HIGH, + HONOR_BROKER_PRIORITY, + HOURS, + IDENTITY_VALUE, + IGNORE_CONSTRAINTS, + IGNORE_DUP_KEY, + IGNORE_REPLICATED_TABLE_CACHE, + IGNORE_TRIGGERS, + IIF, + IMMEDIATE, + IMPERSONATE, + IMPLICIT_TRANSACTIONS, + IMPORTANCE, + INCLUDE, + INCLUDE_NULL_VALUES, + INCREMENT, + INCREMENTAL, + INFINITE, + INIT, + INITIATOR, + INPUT, + INSENSITIVE, + INSERTED, + INSTEAD, + IO, + IP, + ISOLATION, + JOB, + JSON, + JSON_ARRAY, + JSON_OBJECT, + KB, + KEEPDEFAULTS, + KEEPIDENTITY, + KERBEROS, + KEY_PATH, + KEY_SOURCE, + KEY_STORE_PROVIDER_NAME, + KEYS, + KEYSET, + LANGUAGE, + LAST, + LEVEL, + LIBRARY, + LIFETIME, + LINKED, + LINUX, + LIST, + LISTENER, + LISTENER_IP, + LISTENER_PORT, + LISTENER_URL, + LOB_COMPACTION, + LOCAL, + LOCAL_SERVICE_NAME, + LOCATION, + LOCK, + LOCK_ESCALATION, + LOGIN, + LOOP, + LOW, + MANUAL, + MARK, + MASK, + MASKED, + MASTER, + MATCHED, + MATERIALIZED, + MAX, + MAX_CPU_PERCENT, + MAX_DISPATCH_LATENCY, + MAX_DOP, + MAX_DURATION, + MAX_EVENT_SIZE, + MAX_FILES, + MAX_IOPS_PER_VOLUME, + MAX_MEMORY, + MAX_MEMORY_PERCENT, + MAX_OUTSTANDING_IO_PER_VOLUME, + MAX_PROCESSES, + MAX_QUEUE_READERS, + MAX_ROLLOVER_FILES, + MAX_SIZE, + MAXSIZE, + MAXTRANSFER, + MAXVALUE, + MB, + MEDIADESCRIPTION, + MEDIANAME, + MEDIUM, + MEMBER, + MEMORY_OPTIMIZED_DATA, + MEMORY_PARTITION_MODE, + MESSAGE, + MESSAGE_FORWARD_SIZE, + MESSAGE_FORWARDING, + MIN_CPU_PERCENT, + MIN_IOPS_PER_VOLUME, + MIN_MEMORY_PERCENT, + MINUTES, + MINVALUE, + MIRROR, + MIRROR_ADDRESS, + MIXED_PAGE_ALLOCATION, + MODE, + MODIFY, + MOVE, + MULTI_USER, + 
MUST_CHANGE, + NAME, + NESTED_TRIGGERS, + NEW_ACCOUNT, + NEW_BROKER, + NEW_PASSWORD, + NEWNAME, + NEXT, + NO, + NO_CHECKSUM, + NO_COMPRESSION, + NO_EVENT_LOSS, + NO_INFOMSGS, + NO_QUERYSTORE, + NO_STATISTICS, + NO_TRUNCATE, + NOCOUNT, + NODES, + NOEXEC, + NOEXPAND, + NOFORMAT, + NOINDEX, + NOINIT, + NOLOCK, + NON_TRANSACTED_ACCESS, + NONE, + NORECOMPUTE, + NORECOVERY, + NOREWIND, + NOSKIP, + NOTIFICATION, + NOTIFICATIONS, + NOUNLOAD, + NTILE, + NTLM, + NUMANODE, + NUMERIC_ROUNDABORT, + OBJECT, + OFFLINE, + OFFSET, + OLD_ACCOUNT, + OLD_PASSWORD, + ON_FAILURE, + ON, + OFF, + ONLINE, + ONLY, + OPEN_EXISTING, + OPENJSON, + OPERATIONS, + OPTIMISTIC, + OUT, + OUTPUT, + OVERRIDE, + OWNER, + OWNERSHIP, + PAD_INDEX, + PAGE, + PAGE_VERIFY, + PAGECOUNT, + PAGLOCK, + PARAM_NODE, + PARAMETERIZATION, + PARSEONLY, + PARTIAL, + PARTITION, + PARTITIONS, + PARTNER, + PASSWORD, + PATH, + PAUSE, + PDW_SHOWSPACEUSED, + PER_CPU, + PER_DB, + PER_NODE, + PERMISSION_SET, + PERSIST_SAMPLE_PERCENT, + PERSISTED, + PHYSICAL_ONLY, + PLATFORM, + POISON_MESSAGE_HANDLING, + POLICY, + POOL, + PORT, + PRECEDING, + PREDICATE, + PRIMARY_ROLE, + PRIOR, + PRIORITY, + PRIORITY_LEVEL, + PRIVATE, + PRIVATE_KEY, + PRIVILEGES, + PROCCACHE, + PROCEDURE_NAME, + PROCESS, + PROFILE, + PROPERTY, + PROVIDER, + PROVIDER_KEY_NAME, + QUERY, + QUEUE, + QUEUE_DELAY, + QUOTED_IDENTIFIER, + RANDOMIZED, + RANGE, + RC2, + RC4, + RC4_128, + READ_COMMITTED_SNAPSHOT, + READ_ONLY, + READ_ONLY_ROUTING_LIST, + READ_WRITE, + READ_WRITE_FILEGROUPS, + READCOMMITTED, + READCOMMITTEDLOCK, + READONLY, + READPAST, + READUNCOMMITTED, + READWRITE, + REBUILD, + RECEIVE, + RECOVERY, + RECURSIVE_TRIGGERS, + REGENERATE, + RELATED_CONVERSATION, + RELATED_CONVERSATION_GROUP, + RELATIVE, + REMOTE, + REMOTE_PROC_TRANSACTIONS, + REMOTE_SERVICE_NAME, + REMOVE, + REORGANIZE, + REPAIR_ALLOW_DATA_LOSS, + REPAIR_FAST, + REPAIR_REBUILD, + REPEATABLE, + REPEATABLEREAD, + REPLACE, + REPLICA, + REQUEST_MAX_CPU_TIME_SEC, + REQUEST_MAX_MEMORY_GRANT_PERCENT, + REQUEST_MEMORY_GRANT_TIMEOUT_SEC, + REQUIRED, + REQUIRED_SYNCHRONIZED_SECONDARIES_TO_COMMIT, + RESAMPLE, + RESERVE_DISK_SPACE, + RESET, + RESOURCE, + RESOURCE_MANAGER_LOCATION, + RESOURCES, + RESTART, + RESTRICTED_USER, + RESUMABLE, + RESUME, + RETAINDAYS, + RETENTION, + RETURNS, + REWIND, + ROLE, + ROOT, + ROUND_ROBIN, + ROUTE, + ROW, + ROWGUID, + ROWLOCK, + ROWS, + RSA_512, + RSA_1024, + RSA_2048, + RSA_3072, + RSA_4096, + SAFE, + SAFETY, + SAMPLE, + SCHEDULER, + SCHEMABINDING, + SCHEME, + SCOPED, + SCRIPT, + SCROLL, + SCROLL_LOCKS, + SEARCH, + SECONDARY, + SECONDARY_ONLY, + SECONDARY_ROLE, + SECONDS, + SECRET, + SECURABLES, + SECURITY, + SECURITY_LOG, + SEEDING_MODE, + SELF, + SEMI_SENSITIVE, + SEND, + SENT, + SEQUENCE, + SEQUENCE_NUMBER, + SERIALIZABLE, + SERVER, + SERVICE, + SERVICE_BROKER, + SERVICE_NAME, + SERVICEBROKER, + SESSION, + SESSION_TIMEOUT, + SETTINGS, + SHARE, + SHARED, + SHOWCONTIG, + SHOWPLAN, + SHOWPLAN_ALL, + SHOWPLAN_TEXT, + SHOWPLAN_XML, + SHRINKLOG, + SID, + SIGNATURE, + SINGLE_USER, + SIZE, + KWSKIP, + SNAPSHOT, + SOFTNUMA, + SORT_IN_TEMPDB, + SOURCE, + SP_EXECUTESQL, + SPARSE, + SPATIAL_WINDOW_MAX_CELLS, + SPECIFICATION, + SPLIT, + SQLDUMPERFLAGS, + SQLDUMPERPATH, + SQLDUMPERTIMEOUT, + STANDBY, + START, + START_DATE, + STARTED, + STARTUP_STATE, + STATE, + STATIC, + STATISTICS_INCREMENTAL, + STATISTICS_NORECOMPUTE, + STATS, + STATS_STREAM, + STATUS, + STATUSONLY, + STOP, + STOP_ON_ERROR, + STOPLIST, + STOPPED, + SUBJECT, + SUBSCRIBE, + SUBSCRIPTION, + SUPPORTED, + SUSPEND, + SWITCH, + SYMMETRIC, + 
SYNCHRONOUS_COMMIT, + SYNONYM, + SYSTEM, + TABLE, + TABLERESULTS, + TABLOCK, + TABLOCKX, + TAKE, + TAPE, + TARGET, + TARGET_RECOVERY_TIME, + TB, + TCP, + TEXTIMAGE_ON, + THROW, + TIES, + TIMEOUT, + TIMER, + TORN_PAGE_DETECTION, + TOSTRING, + TRACE, + TRACK_CAUSALITY, + TRACKING, + TRANSACTION_ID, + TRANSFER, + TRANSFORM_NOISE_WORDS, + TRIPLE_DES, + TRIPLE_DES_3KEY, + TRUSTWORTHY, + TRY, + TRY_CAST, + TSQL, + TWO_DIGIT_YEAR_CUTOFF, + TYPE, + TYPE_WARNING, + UNBOUNDED, + UNCHECKED, + UNCOMMITTED, + UNLIMITED, + UNLOCK, + UNMASK, + UNSAFE, + UOW, + UPDLOCK, + URL, + USED, + USING, + VALID_XML, + VALIDATION, + VALUE, + VAR, + VERBOSELOGGING, + VERIFY_CLONEDB, + VERSION, + VIEW_METADATA, + VISIBILITY, + WAIT, + WAIT_AT_LOW_PRIORITY, + WELL_FORMED_XML, + WINDOWS, + WITHOUT, + WITHOUT_ARRAY_WRAPPER, + WITNESS, + WORK, + WORKLOAD, + XACT_ABORT, + XLOCK, + XML, + XML_COMPRESSION, + XMLDATA, + XMLNAMESPACES, + XMLSCHEMA, + XSINIL, + ZONE) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionBuilder.scala new file mode 100644 index 0000000000..1369bb9de7 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionBuilder.scala @@ -0,0 +1,813 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.tsql.TSqlParser.{StringContext => _, _} +import com.databricks.labs.remorph.parsers.{ParserCommon, XmlFunction, tsql} +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime.Token +import org.antlr.v4.runtime.tree.Trees + +import scala.collection.JavaConverters._ + +class TSqlExpressionBuilder(override val vc: TSqlVisitorCoordinator) + extends TSqlParserBaseVisitor[ir.Expression] + with ParserCommon[ir.Expression] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.Expression = { + ir.UnresolvedExpression(ruleText = ruleText, message = message) + } + + // Concrete visitors.. + + override def visitOptionClause(ctx: TSqlParser.OptionClauseContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // we gather the options given to use by the original query, though at the moment, we do nothing + // with them. + val opts = vc.optionBuilder.buildOptionList(ctx.lparenOptionList().optionList().genericOption().asScala) + ir.Options(opts.expressionOpts, opts.stringOpts, opts.boolFlags, opts.autoFlags) + } + + override def visitUpdateElemCol(ctx: TSqlParser.UpdateElemColContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val value = ctx.expression().accept(this) + val target1 = Option(ctx.l2) + .map(l2 => ir.Identifier(l2.getText, isQuoted = false)) + .getOrElse(ctx.fullColumnName().accept(this)) + val a1 = buildAssign(target1, value, ctx.op) + Option(ctx.l1).map(l1 => ir.Assign(ir.Identifier(l1.getText, isQuoted = false), a1)).getOrElse(a1) + } + + override def visitUpdateElemUdt(ctx: TSqlParser.UpdateElemUdtContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val args = ctx.expressionList().expression().asScala.map(_.accept(this)) + val fName = ctx.id(0).getText + "." 
+ ctx.id(1).getText + val functionResult = vc.functionBuilder.buildFunction(fName, args) + + functionResult match { + case unresolvedFunction: ir.UnresolvedFunction => + unresolvedFunction.copy(is_user_defined_function = true) + case _ => functionResult + } + } + + override def visitUpdateWhereClause(ctx: UpdateWhereClauseContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.searchCondition().accept(this) + // TODO: TSQL also supports updates via cursor traversal, which is not supported in Databricks SQL + // generate UnresolvedExpression + } + + private[tsql] def buildSelectList(ctx: TSqlParser.SelectListContext): Seq[ir.Expression] = { + val first = buildSelectListElem(ctx.selectListElem()) + val rest = ctx.selectElemTempl().asScala.flatMap(buildSelectListElemTempl) + first ++ rest + } + + private def buildSelectListElemTempl(ctx: TSqlParser.SelectElemTemplContext): Seq[ir.Expression] = { + val templ = Option(ctx.jinjaTemplate()).map(_.accept(this)).toSeq + val elem = Option(ctx.selectListElem()).toSeq.flatMap(buildSelectListElem) + errorCheck(ctx).toSeq ++ templ ++ elem + } + + private[tsql] def buildSelectListElem(ctx: TSqlParser.SelectListElemContext): Seq[ir.Expression] = { + + // If this node has an error such as an extra comma, then don't discard it, prefix it with the errorNode + val errors = errorCheck(ctx) + val elem = ctx match { + case c if c.asterisk() != null => c.asterisk().accept(this) + case c if c.LOCAL_ID() != null => vc.expressionBuilder.buildLocalAssign(ctx) + case c if c.expressionElem() != null => ctx.expressionElem().accept(this) + case _ => + ir.UnresolvedExpression( + ruleText = contextText(ctx), + message = s"Unsupported select list element", + ruleName = "expression", + tokenName = Some(tokenName(ctx.getStart))) + } + errors match { + case Some(errorResult) => Seq(errorResult, elem) + case None => Seq(elem) + } + } + + /** + * Build a local variable assignment from a column source + * + * @param ctx + * the parse tree containing the assignment + */ + private def buildLocalAssign(ctx: TSqlParser.SelectListElemContext): ir.Expression = { + val localId = ir.Identifier(ctx.LOCAL_ID().getText, isQuoted = false) + val expression = ctx.expression().accept(this) + buildAssign(localId, expression, ctx.op) + } + + private def buildAssign(target: ir.Expression, value: ir.Expression, op: Token): ir.Expression = { + op.getType match { + case EQ => ir.Assign(target, value) + case PE => ir.Assign(target, ir.Add(target, value)) + case ME => ir.Assign(target, ir.Subtract(target, value)) + case SE => ir.Assign(target, ir.Multiply(target, value)) + case DE => ir.Assign(target, ir.Divide(target, value)) + case MEA => ir.Assign(target, ir.Mod(target, value)) + case AND_ASSIGN => ir.Assign(target, ir.BitwiseAnd(target, value)) + case OR_ASSIGN => ir.Assign(target, ir.BitwiseOr(target, value)) + case XOR_ASSIGN => ir.Assign(target, ir.BitwiseXor(target, value)) + // We can only reach here if the grammar is changed to add more operators and this function is not updated + case _ => + ir.UnresolvedExpression( + ruleText = op.getText, + message = s"Unexpected operator ${op.getText} in assignment", + ruleName = "expression", + tokenName = Some(tokenName(op)) + ) // Handle unexpected operation types + } + } + + private def buildTableName(ctx: TableNameContext): ir.ObjectReference = { + val linkedServer = Option(ctx.linkedServer).map(buildId) + val ids = ctx.ids.asScala.map(buildId) + val allIds = linkedServer.fold(ids)(ser => ser +: 
ids) + ir.ObjectReference(allIds.head, allIds.tail: _*) + } + + override def visitExprId(ctx: ExprIdContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Column(None, buildId(ctx.id())) + } + + override def visitFullColumnName(ctx: FullColumnNameContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val columnName = buildId(ctx.id) + val tableName = Option(ctx.tableName()).map(buildTableName) + ir.Column(tableName, columnName) + } + + /** + * Handles * used in column expressions. + * + * This can be used in things like SELECT * FROM table + * + * @param ctx + * the parse tree + */ + override def visitAsterisk(ctx: AsteriskContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case _ if ctx.tableName() != null => + val objectName = Option(ctx.tableName()).map(buildTableName) + ir.Star(objectName) + case _ if ctx.INSERTED() != null => Inserted(ir.Star(None)) + case _ if ctx.DELETED() != null => Deleted(ir.Star(None)) + case _ => ir.Star(None) + } + } + + /** + * Expression precedence as defined by parenthesis + * + * @param ctx + * the ExprPrecedenceContext to visit, which contains the expression to which precedence is applied + * @return + * the visited expression in IR + * + * Note that precedence COULD be explicitly placed in the AST here. If we wish to construct an exact replication of + * expression source code from the AST, we need to know that the () were there. Redundant parens are otherwise elided + * and the generated code may seem to be incorrect in the eyes of the customer, even though it will be logically + * equivalent. + */ + override def visitExprPrecedence(ctx: ExprPrecedenceContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.expression().accept(this) + } + + override def visitExprBitNot(ctx: ExprBitNotContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.BitwiseNot(ctx.expression().accept(this)) + } + + // Note that while we could evaluate the unary expression if it is a numeric + // constant, it is usually better to be explicit about the unary operation as + // if people use -+-42 then maybe they have a reason. 
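+ // For example, -+-42 is kept as UMinus(UPlus(UMinus(42))) in the IR rather than being folded to the constant 42.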
+ override def visitExprUnary(ctx: ExprUnaryContext): ir.Expression = { + val expr = ctx.expression().accept(this) + ctx.op.getType match { + case MINUS => ir.UMinus(expr) + case PLUS => ir.UPlus(expr) + } + } + + override def visitExprOpPrec1(ctx: ExprOpPrec1Context): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildBinaryExpression(ctx.expression(0).accept(this), ctx.expression(1).accept(this), ctx.op) + } + + override def visitExprOpPrec2(ctx: ExprOpPrec2Context): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildBinaryExpression(ctx.expression(0).accept(this), ctx.expression(1).accept(this), ctx.op) + } + + override def visitExprOpPrec3(ctx: ExprOpPrec3Context): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildBinaryExpression(ctx.expression(0).accept(this), ctx.expression(1).accept(this), ctx.op) + } + + override def visitExprOpPrec4(ctx: ExprOpPrec4Context): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildBinaryExpression(ctx.expression(0).accept(this), ctx.expression(1).accept(this), ctx.op) + } + + /** + * Note that the dot operator is considerably more complex than the simple case of a.b. It can also have constructs + * such as Function().value etc. This is a simple implementation that assumes that we are building a string for a + * column or table name in contexts where we cannot specifically know that. + * + * TODO: Expand this to handle more complex cases + * + * @param ctx + * the parse tree + */ + override def visitExprDot(ctx: ExprDotContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val right = ctx.expression(1).accept(this) + (left, right) match { + case (c1: ir.Column, c2: ir.Column) => + val path = c1.columnName +: c2.tableNameOrAlias.map(ref => ref.head +: ref.tail).getOrElse(Nil) + ir.Column(Some(ir.ObjectReference(path.head, path.tail: _*)), c2.columnName) + case (_: ir.Column, c2: ir.CallFunction) => + vc.functionBuilder.functionType(c2.function_name) match { + case XmlFunction => tsql.TsqlXmlFunction(c2, left) + case _ => ir.Dot(left, right) + } + // Other cases + case _ => ir.Dot(left, right) + } + } + + override def visitExprCase(ctx: ExprCaseContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.caseExpression().accept(this) + } + + override def visitCaseExpression(ctx: CaseExpressionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val caseExpr = if (ctx.caseExpr != null) Option(ctx.caseExpr.accept(this)) else None + val elseExpr = if (ctx.elseExpr != null) Option(ctx.elseExpr.accept(this)) else None + val whenThenPairs: Seq[ir.WhenBranch] = ctx + .switchSection() + .asScala + .map(buildWhen) + + ir.Case(caseExpr, whenThenPairs, elseExpr) + } + + private def buildWhen(ctx: SwitchSectionContext): ir.WhenBranch = + ir.WhenBranch(ctx.searchCondition.accept(this), ctx.expression().accept(this)) + + override def visitExprFunc(ctx: ExprFuncContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.functionCall.accept(this) + } + + override def visitExprDollar(ctx: ExprDollarContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.DollarAction + } + 
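+ // A bare * in an expression position, as in COUNT(*), is represented as a Star with no table qualifier.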
+ override def visitExprStar(ctx: ExprStarContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Star(None) + } + + override def visitExprFuncVal(ctx: ExprFuncValContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + vc.functionBuilder.buildFunction(ctx.getText, Seq.empty) + } + + override def visitExprPrimitive(ctx: ExprPrimitiveContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.primitiveExpression().accept(this) + } + + override def visitExprCollate(ctx: ExprCollateContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Collate(ctx.expression.accept(this), removeQuotes(ctx.id.getText)) + } + + override def visitPrimitiveExpression(ctx: PrimitiveExpressionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + Option(ctx.op).map(buildPrimitive).getOrElse(ctx.constant().accept(this)) + } + + override def visitConstant(ctx: TSqlParser.ConstantContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildPrimitive(ctx.con) + } + + override def visitExprSubquery(ctx: ExprSubqueryContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.ScalarSubquery(ctx.selectStatement().accept(vc.relationBuilder)) + } + + override def visitExprTz(ctx: ExprTzContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val expression = ctx.expression().accept(this) + val timezone = ctx.timeZone.expression().accept(this) + ir.Timezone(expression, timezone) + } + + override def visitScNot(ctx: TSqlParser.ScNotContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Not(ctx.searchCondition().accept(this)) + } + + override def visitScAnd(ctx: TSqlParser.ScAndContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.And(ctx.searchCondition(0).accept(this), ctx.searchCondition(1).accept(this)) + } + + override def visitScOr(ctx: TSqlParser.ScOrContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Or(ctx.searchCondition(0).accept(this), ctx.searchCondition(1).accept(this)) + } + + override def visitScPred(ctx: TSqlParser.ScPredContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.predicate().accept(this) + } + + override def visitScPrec(ctx: TSqlParser.ScPrecContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.searchCondition.accept(this) + } + + override def visitPredExists(ctx: PredExistsContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.Exists(ctx.selectStatement().accept(vc.relationBuilder)) + } + + override def visitPredFreetext(ctx: PredFreetextContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // TODO: build FREETEXT ? 
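+ // Until FREETEXT support exists, we surface the predicate as an unresolved expression so that the transpiler + // reports it instead of silently dropping it.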
+ ir.UnresolvedExpression( + ruleText = contextText(ctx), + message = s"Freetext predicates are unsupported", + ruleName = "expression", + tokenName = Some(tokenName(ctx.getStart))) + } + + override def visitPredBinop(ctx: PredBinopContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val right = ctx.expression(1).accept(this) + ctx.comparisonOperator match { + case op if op.LT != null && op.EQ != null => ir.LessThanOrEqual(left, right) + case op if op.GT != null && op.EQ != null => ir.GreaterThanOrEqual(left, right) + case op if op.LT != null && op.GT != null => ir.NotEquals(left, right) + case op if op.BANG != null && op.GT != null => ir.LessThanOrEqual(left, right) + case op if op.BANG != null && op.LT != null => ir.GreaterThanOrEqual(left, right) + case op if op.BANG != null && op.EQ != null => ir.NotEquals(left, right) + case op if op.EQ != null => ir.Equals(left, right) + case op if op.GT != null => ir.GreaterThan(left, right) + case op if op.LT != null => ir.LessThan(left, right) + } + } + + override def visitPredASA(ctx: PredASAContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // TODO: build ASA + ir.UnresolvedExpression( + ruleText = contextText(ctx), + message = s"ALL | SOME | ANY predicate not yet supported", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + + override def visitPredBetween(ctx: PredBetweenContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val lowerBound = ctx.expression(1).accept(this) + val upperBound = ctx.expression(2).accept(this) + val expression = ctx.expression(0).accept(this) + val between = ir.Between(expression, lowerBound, upperBound) + Option(ctx.NOT()).fold[ir.Expression](between)(_ => ir.Not(between)) + } + + override def visitPredIn(ctx: PredInContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val in = if (ctx.selectStatement() != null) { + // In the result of a sub query + ir.In(ctx.expression().accept(this), Seq(ir.ScalarSubquery(ctx.selectStatement().accept(vc.relationBuilder)))) + } else { + // In a list of expressions + ir.In(ctx.expression().accept(this), ctx.expressionList().expression().asScala.map(_.accept(this))) + } + Option(ctx.NOT()).fold[ir.Expression](in)(_ => ir.Not(in)) + } + + override def visitPredLike(ctx: PredLikeContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.expression(0).accept(this) + val right = ctx.expression(1).accept(this) + // NB: The escape character is a complete expression that evaluates to a single char at runtime + // and not a single char at parse time. 
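+ // For example, in col LIKE '50!%%' ESCAPE '!' the escape argument is parsed as a general expression, so we + // accept it as-is here rather than trying to reduce it to a single character.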
+ val escape = Option(ctx.expression(2)) + .map(_.accept(this)) + val like = ir.Like(left, right, escape) + Option(ctx.NOT()).fold[ir.Expression](like)(_ => ir.Not(like)) + } + + override def visitPredIsNull(ctx: PredIsNullContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val expression = ctx.expression().accept(this) + if (ctx.NOT() != null) ir.IsNotNull(expression) else ir.IsNull(expression) + } + + override def visitPredExpression(ctx: PredExpressionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.expression().accept(this) + } + + override def visitFunctionCall(ctx: FunctionCallContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case b if b.builtInFunctions() != null => b.builtInFunctions().accept(this) + case s if s.standardFunction() != null => s.standardFunction().accept(this) + case f if f.freetextFunction() != null => f.freetextFunction().accept(this) + case p if p.partitionFunction() != null => p.partitionFunction().accept(this) + case h if h.hierarchyidStaticMethod() != null => h.hierarchyidStaticMethod().accept(this) + } + } + + override def visitId(ctx: IdContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + buildId(ctx) + } + + private[tsql] def buildId(ctx: IdContext): ir.Id = + ctx match { + case c if c.ID() != null => ir.Id(ctx.getText) + case c if c.TEMP_ID() != null => ir.Id(ctx.getText) + case c if c.DOUBLE_QUOTE_ID() != null => + ir.Id(ctx.getText.trim.stripPrefix("\"").stripSuffix("\""), caseSensitive = true) + case c if c.SQUARE_BRACKET_ID() != null => + ir.Id(ctx.getText.trim.stripPrefix("[").stripSuffix("]"), caseSensitive = true) + case c if c.RAW() != null => ir.Id(ctx.getText) + case _ => ir.Id(removeQuotes(ctx.getText)) + } + + override def visitJinjaTemplate(ctx: TSqlParser.JinjaTemplateContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ir.JinjaAsExpression(ctx.getText) + } + + private[tsql] def removeQuotes(str: String): String = { + str.stripPrefix("'").stripSuffix("'") + } + + private def buildBinaryExpression(left: ir.Expression, right: ir.Expression, operator: Token): ir.Expression = + operator.getType match { + case STAR => ir.Multiply(left, right) + case DIV => ir.Divide(left, right) + case MOD => ir.Mod(left, right) + case PLUS => ir.Add(left, right) + case MINUS => ir.Subtract(left, right) + case BIT_AND => ir.BitwiseAnd(left, right) + case BIT_XOR => ir.BitwiseXor(left, right) + case BIT_OR => ir.BitwiseOr(left, right) + case DOUBLE_BAR => ir.Concat(Seq(left, right)) + } + + private def buildPrimitive(con: Token): ir.Expression = con.getType match { + case DEFAULT => Default() + case LOCAL_ID => ir.Identifier(con.getText, isQuoted = false) + case STRING => ir.Literal(removeQuotes(con.getText)) + case NULL => ir.Literal.Null + case HEX => ir.Literal(con.getText) // Preserve format + case MONEY => Money(ir.StringLiteral(con.getText)) + case INT | REAL | FLOAT => ir.NumericLiteral(con.getText) + } + + override def visitStandardFunction(ctx: StandardFunctionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val name = ctx.funcId.getText + val args = Option(ctx.expression()).map(_.asScala.map(_.accept(this))).getOrElse(Seq.empty) + vc.functionBuilder.buildFunction(name, args) + } + + // Note that this visitor 
is made complicated and difficult because the built-in ir does not use options. + // So we build placeholder values for the optional values. They also do not extend expression + // so we can't build them logically with visit and accept. Maybe replace them with + // extensions that do this? + override def visitExprOver(ctx: ExprOverContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // The OVER clause is used to accept the IGNORE nulls clause that can be specified after certain + // windowing functions such as LAG or LEAD, so that the clause is manifest here. The syntax allows + // 'IGNORE NULLS' and 'RESPECT NULLS', but 'RESPECT NULLS' is the default behavior. + val windowFunction = + buildWindowingFunction(ctx.expression().accept(this)) + val partitionByExpressions = + Option(ctx.overClause().expression()).map(_.asScala.toList.map(_.accept(this))).getOrElse(List.empty) + val orderByExpressions = Option(ctx.overClause().orderByClause()) + .map(buildOrderBy) + .getOrElse(List.empty) + val windowFrame = Option(ctx.overClause().rowOrRangeClause()) + .map(buildWindowFrame) + + ir.Window( + windowFunction, + partitionByExpressions, + orderByExpressions, + windowFrame, + ctx.overClause().IGNORE() != null) + } + + // Some functions need to be converted to Databricks equivalent Windowing functions for the OVER clause + private def buildWindowingFunction(expression: ir.Expression): ir.Expression = expression match { + case ir.CallFunction("MONOTONICALLY_INCREASING_ID", args) => ir.CallFunction("ROW_NUMBER", args) + case _ => expression + } + + private def buildOrderBy(ctx: OrderByClauseContext): Seq[ir.SortOrder] = + ctx.orderByExpression().asScala.map { orderByExpr => + val expression = orderByExpr.expression(0).accept(this) + val sortOrder = + orderByExpr match { + case o if o.DESC() != null => ir.Descending + case o if o.ASC() != null => ir.Ascending + case _ => ir.UnspecifiedSortDirection + } + ir.SortOrder(expression, sortOrder, ir.SortNullsUnspecified) + } + + private def buildWindowFrame(ctx: RowOrRangeClauseContext): ir.WindowFrame = { + val frameType = buildFrameType(ctx) + val bounds = Trees + .findAllRuleNodes(ctx, TSqlParser.RULE_windowFrameBound) + .asScala + .collect { case wfb: WindowFrameBoundContext => wfb } + .map(buildFrame) + + val frameStart = bounds.head // Safe due to the nature of window frames always having at least a start bound + val frameEnd = + bounds.tail.headOption.getOrElse(ir.NoBoundary) + + ir.WindowFrame(frameType, frameStart, frameEnd) + } + + private def buildFrameType(ctx: RowOrRangeClauseContext): ir.FrameType = { + if (Option(ctx.ROWS()).isDefined) ir.RowsFrame + else ir.RangeFrame + } + + private[tsql] def buildFrame(ctx: WindowFrameBoundContext): ir.FrameBoundary = + ctx match { + case c if c.UNBOUNDED() != null && c.PRECEDING() != null => ir.UnboundedPreceding + case c if c.UNBOUNDED() != null && c.FOLLOWING() != null => ir.UnboundedFollowing + case c if c.CURRENT() != null => ir.CurrentRow + case c if c.INT() != null && c.PRECEDING() != null => + ir.PrecedingN(ir.Literal(c.INT().getText.toInt, ir.IntegerType)) + case c if c.INT() != null && c.FOLLOWING() != null => + ir.FollowingN(ir.Literal(c.INT().getText.toInt, ir.IntegerType)) + } + + override def visitExpressionElem(ctx: ExpressionElemContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val columnDef = ctx.expression().accept(this) + val aliasOption = Option(ctx.columnAlias()).map { alias => + 
val name = vc.relationBuilder.buildColumnAlias(alias) + ir.Alias(columnDef, name) + } + aliasOption.getOrElse(columnDef) + } + + override def visitExprWithinGroup(ctx: ExprWithinGroupContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val expression = ctx.expression().accept(this) + val orderByExpressions = buildOrderBy(ctx.withinGroup().orderByClause()) + ir.WithinGroup(expression, orderByExpressions) + } + + override def visitExprDistinct(ctx: ExprDistinctContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // Support for functions such as COUNT(DISTINCT column), which is an expression not a child + ir.Distinct(ctx.expression().accept(this)) + } + + override def visitExprJinja(ctx: ExprJinjaContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.jinjaTemplate().accept(this) + } + + override def visitExprAll(ctx: ExprAllContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // Support for functions such as COUNT(ALL column), which is an expression not a child. + // ALL has no actual effect on the result so we just pass the expression as is. If we wish to + // reproduce existing annotations like this, then we would need to add IR. + ctx.expression().accept(this) + } + + override def visitPartitionFunction(ctx: PartitionFunctionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // $$PARTITION is not supported in Databricks SQL, so we will report it is not supported + vc.functionBuilder.buildFunction(s"$$PARTITION", List.empty) + } + + /** + * Handles the NEXT VALUE FOR function in SQL Server, which has a special syntax. + * + * @param ctx + * the parse tree + */ + override def visitNextValueFor(ctx: NextValueForContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val sequenceName = buildTableName(ctx.tableName()) + vc.functionBuilder.buildFunction("NEXTVALUEFOR", Seq(sequenceName)) + } + + override def visitCast(ctx: CastContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val expression = ctx.expression().accept(this) + val dataType = vc.dataTypeBuilder.build(ctx.dataType()) + ir.Cast(expression, dataType, returnNullOnError = ctx.TRY_CAST() != null) + } + + override def visitJsonArray(ctx: JsonArrayContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val elements = buildExpressionList(Option(ctx.expressionList())) + val absentOnNull = checkAbsentNull(ctx.jsonNullClause()) + buildJsonArray(elements, absentOnNull) + } + + override def visitJsonObject(ctx: JsonObjectContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val jsonKeyValues = Option(ctx.jsonKeyValue()).map(_.asScala).getOrElse(Nil) + val namedStruct = buildNamedStruct(jsonKeyValues) + val absentOnNull = checkAbsentNull(ctx.jsonNullClause()) + buildJsonObject(namedStruct, absentOnNull) + } + + override def visitFreetextFunction(ctx: FreetextFunctionContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // Databricks SQL does not support FREETEXT functions, so there is no point in trying to convert these + // functions. 
We do need to generate IR that indicates that this is a function that is not supported. + vc.functionBuilder.buildFunction(ctx.f.getText, List.empty) + } + + override def visitHierarchyidStaticMethod(ctx: HierarchyidStaticMethodContext): ir.Expression = errorCheck( + ctx) match { + case Some(errorResult) => errorResult + case None => + // Databricks SQL does not support HIERARCHYID functions, so there is no point in trying to convert these + // functions. We do need to generate IR that indicates that this is a function that is not supported. + vc.functionBuilder.buildFunction("HIERARCHYID", List.empty) + } + + override def visitOutputDmlListElem(ctx: OutputDmlListElemContext): ir.Expression = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val columnDef = Option(ctx.expression()).map(_.accept(this)).getOrElse(ctx.asterisk().accept(this)) + val aliasOption = Option(ctx.columnAlias()).map { alias => + val name = vc.relationBuilder.buildColumnAlias(alias) + ir.Alias(columnDef, name) + } + aliasOption.getOrElse(columnDef) + } + + /** + * Check if the ABSENT ON NULL clause is present in the JSON clause. The behavior is as follows: + *
    + *
+   *   - If the clause does not exist, the ABSENT ON NULL is assumed - so true
+   *   - If the clause exists and ABSENT ON NULL - true
+   *   - If the clause exists and NULL ON NULL - false
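+   *
+   * For example (illustrative): `JSON_ARRAY(a, b)` with no clause and `JSON_ARRAY(a, b ABSENT ON NULL)` both
+   * yield true here, while `JSON_ARRAY(a, b NULL ON NULL)` yields false.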
+ * + * @param ctx + * null clause parser context + * @return + */ + private def checkAbsentNull(ctx: JsonNullClauseContext): Boolean = { + Option(ctx).forall(_.loseNulls != null) + } + + private def buildNamedStruct(ctx: Seq[JsonKeyValueContext]): ir.NamedStruct = { + val (keys, values) = ctx.map { keyValueContext => + val expressions = keyValueContext.expression().asScala.toList + (expressions.head.accept(this), expressions(1).accept(this)) + }.unzip + + ir.NamedStruct(keys, values) + } + + private def buildExpressionList(ctx: Option[ExpressionListContext]): Seq[ir.Expression] = { + ctx.map(_.expression().asScala.map(_.accept(this))).getOrElse(Seq.empty) + } + + /** + * Databricks SQL does not have a native JSON_ARRAY function, so we use a Lambda filter and TO_JSON instead, but have + * to cater for the case where an expression is NULL and the TSql option ABSENT ON NULL is set. When ABSENT ON NULL is + * set, then any NULL expressions are left out of the JSON array. + * + * @param args + * the list of expressions yield JSON values + * @param absentOnNull + * whether we should remove NULL values from the JSON array + * @return + * IR for the JSON_ARRAY function + */ + private[tsql] def buildJsonArray(args: Seq[ir.Expression], absentOnNull: Boolean): ir.Expression = { + if (absentOnNull) { + val lambdaVariable = ir.UnresolvedNamedLambdaVariable(Seq("x")) + val lambdaBody = ir.Not(ir.IsNull(lambdaVariable)) + val lambdaFunction = ir.LambdaFunction(lambdaBody, Seq(lambdaVariable)) + val filter = ir.FilterExpr(args, lambdaFunction) + ir.CallFunction("TO_JSON", Seq(ir.ValueArray(Seq(filter)))) + } else { + ir.CallFunction("TO_JSON", Seq(ir.ValueArray(args))) + } + } + + /** + * Databricks SQL does not have a native JSON_OBJECT function, so we use a Lambda filter and TO_JSON instead, but have + * to cater for the case where an expression is NULL and the TSql option ABSENT ON NULL is set. When ABSENT ON NULL is + * set, then any NULL expressions are left out of the JSON object. 
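+   * For example (an illustrative rendering, not taken from the generator): `JSON_OBJECT('k': v ABSENT ON NULL)`
+   * is expressed as a `TO_JSON` call over the named struct filtered with a lambda such as `(k, v) -> v IS NOT NULL`.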
+ * + * @param namedStruct + * the named struct of key-value pairs + * @param absentOnNull + * whether we should remove NULL values from the JSON object + * @return + * IR for the JSON_OBJECT function + */ + // TODO: This is not likely the correct way to handle this, but it is a start + // maybe needs external function at runtime + private[tsql] def buildJsonObject(namedStruct: ir.NamedStruct, absentOnNull: Boolean): ir.Expression = { + if (absentOnNull) { + val lambdaVariables = ir.UnresolvedNamedLambdaVariable(Seq("k", "v")) + val valueVariable = ir.UnresolvedNamedLambdaVariable(Seq("v")) + val lambdaBody = ir.Not(ir.IsNull(valueVariable)) + val lambdaFunction = ir.LambdaFunction(lambdaBody, Seq(lambdaVariables)) + val filter = ir.FilterStruct(namedStruct, lambdaFunction) + ir.CallFunction("TO_JSON", Seq(filter)) + } else { + ir.CallFunction("TO_JSON", Seq(namedStruct)) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionBuilder.scala new file mode 100644 index 0000000000..b4fe507923 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionBuilder.scala @@ -0,0 +1,81 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.{FunctionBuilder, FunctionDefinition, StringConverter} +import com.databricks.labs.remorph.{intermediate => ir} + +class TSqlFunctionBuilder extends FunctionBuilder with StringConverter { + + private[this] val tSqlFunctionDefinitionPf: PartialFunction[String, FunctionDefinition] = { + case """$PARTITION""" => FunctionDefinition.notConvertible(0) + case "@@CURSOR_ROWS" => FunctionDefinition.notConvertible(0) + case "@@DBTS" => FunctionDefinition.notConvertible(0) + case "@@FETCH_STATUS" => FunctionDefinition.notConvertible(0) + case "@@LANGID" => FunctionDefinition.notConvertible(0) + case "@@LANGUAGE" => FunctionDefinition.notConvertible(0) + case "@@LOCKTIMEOUT" => FunctionDefinition.notConvertible(0) + case "@@MAX_CONNECTIONS" => FunctionDefinition.notConvertible(0) + case "@@MAX_PRECISION" => FunctionDefinition.notConvertible(0) + case "@@NESTLEVEL" => FunctionDefinition.notConvertible(0) + case "@@OPTIONS" => FunctionDefinition.notConvertible(0) + case "@@REMSERVER" => FunctionDefinition.notConvertible(0) + case "@@SERVERNAME" => FunctionDefinition.notConvertible(0) + case "@@SERVICENAME" => FunctionDefinition.notConvertible(0) + case "@@SPID" => FunctionDefinition.notConvertible(0) + case "@@TEXTSIZE" => FunctionDefinition.notConvertible(0) + case "@@VERSION" => FunctionDefinition.notConvertible(0) + case "COLLATIONPROPERTY" => FunctionDefinition.notConvertible(2) + case "CONTAINSTABLE" => FunctionDefinition.notConvertible(0) + case "CUBE" => FunctionDefinition.standard(1, Int.MaxValue) // Snowflake hard codes this + case "FREETEXTTABLE" => FunctionDefinition.notConvertible(0) + case "GET_BIT" => FunctionDefinition.standard(2).withConversionStrategy(rename) + case "HIERARCHYID" => FunctionDefinition.notConvertible(0) + case "ISNULL" => FunctionDefinition.standard(2).withConversionStrategy(rename) + case "LEFT_SHIFT" => FunctionDefinition.standard(2).withConversionStrategy(rename) + case "MODIFY" => FunctionDefinition.xml(1) + case "NEXTVALUEFOR" => FunctionDefinition.standard(1).withConversionStrategy(nextValueFor) + case "RIGHT_SHIFT" => FunctionDefinition.standard(2).withConversionStrategy(rename) + case "ROLLUP" => FunctionDefinition.standard(1, 
Int.MaxValue) // Snowflake hard codes this + case "SEMANTICKEYPHRASETABLE" => FunctionDefinition.notConvertible(0) + case "SEMANTICSIMILARITYDETAILSTABLE" => FunctionDefinition.notConvertible(0) + case "SEMANTICSSIMILARITYTABLE" => FunctionDefinition.notConvertible(0) + case "SET_BIT" => FunctionDefinition.standard(2, 3).withConversionStrategy(rename) + } + + override def functionDefinition(name: String): Option[FunctionDefinition] = + // If not found, check common functions + tSqlFunctionDefinitionPf.lift(name.toUpperCase()).orElse(super.functionDefinition(name)) + + def applyConversionStrategy( + functionArity: FunctionDefinition, + args: Seq[ir.Expression], + irName: String): ir.Expression = { + functionArity.conversionStrategy match { + case Some(strategy) => strategy.convert(irName, args) + case _ => ir.CallFunction(irName, args) + } + } + + // TSql specific function converters + // + private[tsql] def nextValueFor(irName: String, args: Seq[ir.Expression]): ir.Expression = { + // Note that this conversion assumes that the CREATE SEQUENCE it references was an increment in ascending order. + // We may run across instances where this is not the case, and will have to handle that as a special case, perhaps + // with external procedures or functions in Java/Scala, or even python. + // For instance a SequenceHandler supplied by the user. + // + // Given this, then we use this converter rather than just the simple Rename converter. + // TODO: Implement external SequenceHandler? + ir.CallFunction("MONOTONICALLY_INCREASING_ID", List.empty) + } + + private[tsql] def rename(irName: String, args: Seq[ir.Expression]): ir.Expression = { + irName.toUpperCase() match { + case "ISNULL" => ir.CallFunction(convertString(irName, "IFNULL"), args) + case "GET_BIT" => ir.CallFunction(convertString(irName, "GETBIT"), args) + case "LEFT_SHIFT" => ir.CallFunction(convertString(irName, "LEFTSHIFT"), args) + case "RIGHT_SHIFT" => ir.CallFunction(convertString(irName, "RIGHTSHIFT"), args) + case _ => ir.CallFunction(irName, args) + } + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlPlanParser.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlPlanParser.scala new file mode 100644 index 0000000000..11c7dd843f --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlPlanParser.scala @@ -0,0 +1,29 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.parsers.tsql.rules.{PullLimitUpwards, TSqlCallMapper, TopPercentToLimitSubquery, TrapInsertDefaultsAction} +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime._ + +class TSqlPlanParser extends PlanParser[TSqlParser] { + + val vc = new TSqlVisitorCoordinator(TSqlParser.VOCABULARY, TSqlParser.ruleNames) + + override protected def createLexer(input: CharStream): Lexer = new TSqlLexer(input) + override protected def createParser(stream: TokenStream): TSqlParser = new TSqlParser(stream) + override protected def createTree(parser: TSqlParser): ParserRuleContext = parser.tSqlFile() + override protected def createPlan(tree: ParserRuleContext): ir.LogicalPlan = vc.astBuilder.visit(tree) + override protected def addErrorStrategy(parser: TSqlParser): Unit = parser.setErrorHandler(new TSqlErrorStrategy) + def dialect: String = "tsql" + + // TODO: Note that this is not the correct place for the optimizer, but it is here for now + override protected def createOptimizer: 
ir.Rules[ir.LogicalPlan] = { + ir.Rules( + new TSqlCallMapper, + ir.AlwaysUpperNameForCallFunction, + PullLimitUpwards, + new TopPercentToLimitSubquery, + TrapInsertDefaultsAction) + } + +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlRelationBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlRelationBuilder.scala new file mode 100644 index 0000000000..180fbefa8f --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlRelationBuilder.scala @@ -0,0 +1,618 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.ParserCommon +import com.databricks.labs.remorph.parsers.tsql.TSqlParser.{StringContext => _, _} +import com.databricks.labs.remorph.parsers.tsql.rules.{InsertDefaultsAction, TopPercent} +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime.ParserRuleContext + +import scala.collection.JavaConverters.asScalaBufferConverter + +class TSqlRelationBuilder(override val vc: TSqlVisitorCoordinator) + extends TSqlParserBaseVisitor[ir.LogicalPlan] + with ParserCommon[ir.LogicalPlan] { + + // The default result is returned when there is no visitor implemented, and we produce an unresolved + // object to represent the input that we have no visitor for. + protected override def unresolved(ruleText: String, message: String): ir.LogicalPlan = + ir.UnresolvedRelation(ruleText = ruleText, message = message) + + // Concrete visitors + + override def visitCommonTableExpression(ctx: CommonTableExpressionContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableName = vc.expressionBuilder.buildId(ctx.id()) + // Column list can be empty if the select specifies distinct column names + val columns = + Option(ctx.columnNameList()) + .map(_.id().asScala.map(vc.expressionBuilder.buildId)) + .getOrElse(Seq.empty) + val query = ctx.selectStatement().accept(this) + ir.SubqueryAlias(query, tableName, columns) + } + + override def visitSelectStatementStandalone(ctx: TSqlParser.SelectStatementStandaloneContext): ir.LogicalPlan = + errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val query = ctx.selectStatement().accept(this) + Option(ctx.withExpression()) + .map { withExpression => + val ctes = withExpression.commonTableExpression().asScala.map(_.accept(this)) + ir.WithCTE(ctes, query) + } + .getOrElse(query) + } + + override def visitSelectStatement(ctx: TSqlParser.SelectStatementContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + // TODO: The FOR clause of TSQL is not supported in Databricks SQL as XML and JSON are not supported + // we need to create an UnresolvedRelation for it + + // We visit the OptionClause because in the future, we may be able to glean information from it + // as an aid to migration, however the clause is not used in the AST or translation. 
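+      // For example (illustrative): `SELECT name FROM t OPTION (MAXDOP 1)` - the OPTION clause is retained by
+      // wrapping the query in ir.WithOptions below, so the information is preserved even though it is not translated.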
+ val query = ctx.queryExpression.accept(this) + Option(ctx.optionClause) match { + case Some(optionClause) => ir.WithOptions(query, optionClause.accept(vc.expressionBuilder)) + case None => query + } + } + + override def visitQueryInParenthesis(ctx: TSqlParser.QueryInParenthesisContext): ir.LogicalPlan = { + errorCheck(ctx).getOrElse(visit(ctx.queryExpression())) + } + + override def visitQueryUnion(ctx: TSqlParser.QueryUnionContext): ir.LogicalPlan = errorCheck(ctx).getOrElse { + val Seq(lhs, rhs) = ctx.queryExpression().asScala.map(visit) + val setOp = ctx match { + case u if u.UNION() != null => ir.UnionSetOp + case e if e.EXCEPT() != null => ir.ExceptSetOp + } + val isAll = ctx.ALL() != null + ir.SetOperation(lhs, rhs, setOp, is_all = isAll, by_name = false, allow_missing_columns = false) + } + + override def visitQueryIntersect(ctx: TSqlParser.QueryIntersectContext): ir.LogicalPlan = errorCheck(ctx).getOrElse { + val Seq(lhs, rhs) = ctx.queryExpression().asScala.map(visit) + ir.SetOperation(lhs, rhs, ir.IntersectSetOp, is_all = false, by_name = false, allow_missing_columns = false) + } + + override def visitQuerySimple(ctx: TSqlParser.QuerySimpleContext): ir.LogicalPlan = errorCheck(ctx).getOrElse { + visitQuerySpecification(ctx.querySpecification()) + } + + override def visitQuerySpecification(ctx: TSqlParser.QuerySpecificationContext): ir.LogicalPlan = errorCheck( + ctx) match { + case Some(errorResult) => errorResult + case None => + // TODO: Check the logic here for all the elements of a query specification + val select = ctx.selectOptionalClauses().accept(this) + + // A single column definition could also hold an ErrorNode that it recovered from so we collect all of them + val columns: Seq[ir.Expression] = + vc.expressionBuilder.buildSelectList(ctx.selectList()) + // Note that ALL is the default so we don't need to check for it + ctx match { + case c if c.DISTINCT() != null => + ir.Project(buildTop(Option(ctx.topClause()), buildDistinct(select, columns)), columns) + case _ => + ir.Project(buildTop(Option(ctx.topClause()), select), columns) + } + } + + private[tsql] def buildTop(ctxOpt: Option[TSqlParser.TopClauseContext], input: ir.LogicalPlan): ir.LogicalPlan = + ctxOpt.fold(input) { top => + val limit = top.expression().accept(vc.expressionBuilder) + if (top.PERCENT() != null) { + TopPercent(input, limit, with_ties = top.TIES() != null) + } else { + ir.Limit(input, limit) + } + } + + override def visitSelectOptionalClauses(ctx: SelectOptionalClausesContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val from = Option(ctx.fromClause()).map(_.accept(this)).getOrElse(ir.NoTable) + buildOrderBy( + ctx.selectOrderByClause(), + buildHaving(ctx.havingClause(), buildGroupBy(ctx.groupByClause(), buildWhere(ctx.whereClause(), from)))) + } + + private def buildFilter[A](ctx: A, conditionRule: A => ParserRuleContext, input: ir.LogicalPlan): ir.LogicalPlan = + Option(ctx).fold(input) { c => + ir.Filter(input, conditionRule(c).accept(vc.expressionBuilder)) + } + + private def buildHaving(ctx: HavingClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = + buildFilter[HavingClauseContext](ctx, _.searchCondition(), input) + + private def buildWhere(ctx: WhereClauseContext, from: ir.LogicalPlan): ir.LogicalPlan = + buildFilter[WhereClauseContext](ctx, _.searchCondition(), from) + + // TODO: We are not catering for GROUPING SETS here, or in Snowflake + private def buildGroupBy(ctx: GroupByClauseContext, input: ir.LogicalPlan): 
ir.LogicalPlan = { + Option(ctx).fold(input) { c => + val groupingExpressions = + c.expression() + .asScala + .map(_.accept(vc.expressionBuilder)) + ir.Aggregate(child = input, group_type = ir.GroupBy, grouping_expressions = groupingExpressions, pivot = None) + } + } + + private def buildOrderBy(ctx: SelectOrderByClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = { + Option(ctx).fold(input) { c => + val sortOrders = c.orderByClause().orderByExpression().asScala.map { orderItem => + val expression = orderItem.expression(0).accept(vc.expressionBuilder) + // orderItem.expression(1) is COLLATE - we will not support that, but should either add a comment in the + // translated source or raise some kind of linting alert. + if (orderItem.DESC() == null) { + ir.SortOrder(expression, ir.Ascending, ir.SortNullsUnspecified) + } else { + ir.SortOrder(expression, ir.Descending, ir.SortNullsUnspecified) + } + } + val sorted = ir.Sort(input, sortOrders, is_global = false) + + // Having created the IR for ORDER BY, we now need to apply any OFFSET, and then any FETCH + if (ctx.OFFSET() != null) { + val offset = ir.Offset(sorted, ctx.expression(0).accept(vc.expressionBuilder)) + if (ctx.FETCH() != null) { + ir.Limit(offset, ctx.expression(1).accept(vc.expressionBuilder)) + } else { + offset + } + } else { + sorted + } + } + } + + override def visitFromClause(ctx: FromClauseContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tableSources = ctx.tableSources().tableSource().asScala.map(_.accept(this)) + // The tableSources seq cannot be empty (as empty FROM clauses are not allowed + tableSources match { + case Seq(tableSource) => tableSource + case sources => + sources.reduce( + ir.Join(_, _, None, ir.CrossJoin, Seq(), ir.JoinDataType(is_left_struct = false, is_right_struct = false))) + } + } + + private def buildDistinct(from: ir.LogicalPlan, columns: Seq[ir.Expression]): ir.LogicalPlan = { + val columnNames = columns.collect { + case ir.Column(_, c) => c + case ir.Alias(_, a) => a + // Note that the ir.Star(None) is not matched so that we set all_columns_as_keys to true + } + ir.Deduplicate(from, columnNames, all_columns_as_keys = columnNames.isEmpty, within_watermark = false) + } + + override def visitTableName(ctx: TableNameContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val linkedServer = Option(ctx.linkedServer).map(_.getText) + val ids = ctx.ids.asScala.map(_.getText).mkString(".") + val fullName = linkedServer.fold(ids)(ls => s"$ls..$ids") + ir.NamedTable(fullName, Map.empty, is_streaming = false) + } + + override def visitTableSource(ctx: TableSourceContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val left = ctx.tableSourceItem().accept(this) + ctx match { + case c if c.joinPart() != null => c.joinPart().asScala.foldLeft(left)(buildJoinPart) + } + } + + override def visitTableSourceItem(ctx: TableSourceItemContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val tsiElement = ctx.tsiElement().accept(this) + + // Assemble any table hints, though we do nothing with them for now + val hints = buildTableHints(Option(ctx.withTableHints())) + + // If we have column aliases, they are applied here first + val tsiElementWithAliases = Option(ctx.columnAliasList()) + .map { aliasList => + val aliases = aliasList.columnAlias().asScala.map(id => buildColumnAlias(id)) + 
ColumnAliases(tsiElement, aliases) + } + .getOrElse(tsiElement) + + val relation = if (hints.nonEmpty) { + ir.TableWithHints(tsiElementWithAliases, hints) + } else { + tsiElementWithAliases + } + + // Then any table alias is applied to the source + Option(ctx.asTableAlias()) + .map(alias => ir.TableAlias(relation, alias.id.getText)) + .getOrElse(relation) + } + + // Table hints arrive syntactically as a () delimited list of options and, in the + // case of deprecated hint syntax, as a list of generic options without (). Here, + // we build a single map from both sources, either or both of which may be empty. + // In true TSQL style, some of the hints have non-orthodox syntax, and must be handled + // directly. + private[tsql] def buildTableHints(ctx: Option[WithTableHintsContext]): Seq[ir.TableHint] = { + ctx.map(_.tableHint().asScala.map(buildHint).toList).getOrElse(Seq.empty) + } + + private def buildHint(ctx: TableHintContext): ir.TableHint = { + ctx match { + case index if index.INDEX() != null => + ir.IndexHint(index.expressionList().expression().asScala.map { expr => + expr.accept(vc.expressionBuilder) match { + case column: ir.Column => column.columnName + case other => other + } + }) + case force if force.FORCESEEK() != null => + val name = Option(force.expression()).map(_.accept(vc.expressionBuilder)) + val columns = Option(force.columnNameList()).map(_.id().asScala.map(_.accept(vc.expressionBuilder))) + ir.ForceSeekHint(name, columns) + case _ => + val option = vc.optionBuilder.buildOption(ctx.genericOption()) + ir.FlagHint(option.id) + } + } + + private[tsql] def buildColumnAlias(ctx: TSqlParser.ColumnAliasContext): ir.Id = { + ctx match { + case c if c.id() != null => vc.expressionBuilder.buildId(c.id()) + case t if t.jinjaTemplate() != null => ir.Id(vc.expressionBuilder.removeQuotes(ctx.jinjaTemplate.getText)) + case _ => ir.Id(vc.expressionBuilder.removeQuotes(ctx.getText)) + } + } + + override def visitTsiNamedTable(ctx: TsiNamedTableContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.tableName().accept(this) + } + + override def visitTsiDerivedTable(ctx: TsiDerivedTableContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx.derivedTable().accept(this) + } + + override def visitDerivedTable(ctx: DerivedTableContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val result = if (ctx.tableValueConstructor() != null) { + ctx.tableValueConstructor().accept(this) + } else { + ctx.selectStatement().accept(this) + } + result + } + + override def visitTsiJinja(ctx: TsiJinjaContext): ir.LogicalPlan = + errorCheck(ctx).getOrElse(ctx.jinjaTemplate().accept(this)) + + override def visitJinjaTemplate(ctx: TSqlParser.JinjaTemplateContext): ir.LogicalPlan = + errorCheck(ctx).getOrElse(ir.JinjaAsStatement(ctx.getText)) + + override def visitTableValueConstructor(ctx: TableValueConstructorContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val rows = ctx.tableValueRow().asScala.map(buildValueRow) + DerivedRows(rows) + } + + private def buildValueRow(ctx: TableValueRowContext): Seq[ir.Expression] = { + ctx.expressionList().expression().asScala.map(_.accept(vc.expressionBuilder)) + } + + override def visitMerge(ctx: MergeContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val targetPlan = ctx.ddlObject().accept(this) 
+ val hints = buildTableHints(Option(ctx.withTableHints())) + val finalTarget = if (hints.nonEmpty) { + ir.TableWithHints(targetPlan, hints) + } else { + targetPlan + } + + val mergeCondition = ctx.searchCondition().accept(vc.expressionBuilder) + val tableSourcesPlan = ctx.tableSources().tableSource().asScala.map(_.accept(this)) + // Reduce is safe: Grammar rule for tableSources ensures that there is always at least one tableSource. + val sourcePlan = tableSourcesPlan.reduceLeft( + ir.Join(_, _, None, ir.CrossJoin, Seq(), ir.JoinDataType(is_left_struct = false, is_right_struct = false))) + + // We may have a number of when clauses, each with a condition and an action. We keep the ANTLR syntax compact + // and lean and determine which of the three types of action we have in the whenMatch method based on + // the presence or absence of syntactical elements NOT and SOURCE as SOURCE can only be used with NOT + val (matchedActions, notMatchedActions, notMatchedBySourceActions) = Option(ctx.whenMatch()) + .map(_.asScala.foldLeft((List.empty[ir.MergeAction], List.empty[ir.MergeAction], List.empty[ir.MergeAction])) { + case ((matched, notMatched, notMatchedBySource), m) => + val action = buildWhenMatch(m) + (m.NOT(), m.SOURCE()) match { + case (null, _) => (action :: matched, notMatched, notMatchedBySource) + case (_, null) => (matched, action :: notMatched, notMatchedBySource) + case _ => (matched, notMatched, action :: notMatchedBySource) + } + }) + .getOrElse((List.empty, List.empty, List.empty)) + + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + val outputClause = Option(ctx.outputClause()).map(_.accept(this)) + + val mergeIntoTable = ir.MergeIntoTable( + finalTarget, + sourcePlan, + mergeCondition, + matchedActions, + notMatchedActions, + notMatchedBySourceActions) + + val withOptions = optionClause match { + case Some(option) => ir.WithOptions(mergeIntoTable, option) + case None => mergeIntoTable + } + + outputClause match { + case Some(output) => WithOutputClause(withOptions, output) + case None => withOptions + } + } + + private def buildWhenMatch(ctx: WhenMatchContext): ir.MergeAction = { + val condition = Option(ctx.searchCondition()).map(_.accept(vc.expressionBuilder)) + ctx.mergeAction() match { + case action if action.DELETE() != null => ir.DeleteAction(condition) + case action if action.UPDATE() != null => buildUpdateAction(action, condition) + case action if action.INSERT() != null => buildInsertAction(action, condition) + } + } + + private def buildInsertAction(ctx: MergeActionContext, condition: Option[ir.Expression]): ir.MergeAction = { + + ctx match { + case action if action.DEFAULT() != null => InsertDefaultsAction(condition) + case _ => + val assignments = + (ctx.cols + .expression() + .asScala + .map(_.accept(vc.expressionBuilder)) zip ctx.vals.expression().asScala.map(_.accept(vc.expressionBuilder))) + .map { case (col, value) => + ir.Assign(col, value) + } + ir.InsertAction(condition, assignments) + } + } + + private def buildUpdateAction(ctx: MergeActionContext, condition: Option[ir.Expression]): ir.UpdateAction = { + val setElements = ctx.updateElem().asScala.collect { case elem => + elem.accept(vc.expressionBuilder) match { + case assign: ir.Assign => assign + } + } + ir.UpdateAction(condition, setElements) + } + + override def visitUpdate(ctx: UpdateContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.ddlObject().accept(this) + val hints = 
buildTableHints(Option(ctx.withTableHints())) + val hintTarget = if (hints.nonEmpty) { + ir.TableWithHints(target, hints) + } else { + target + } + + val finalTarget = buildTop(Option(ctx.topClause()), hintTarget) + val output = Option(ctx.outputClause()).map(_.accept(this)) + val setElements = ctx.updateElem().asScala.map(_.accept(vc.expressionBuilder)) + + val sourceRelation = buildTableSourcesPlan(Option(ctx.tableSources())) + val where = Option(ctx.updateWhereClause()) map (_.accept(vc.expressionBuilder)) + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + ir.UpdateTable(finalTarget, sourceRelation, setElements, where, output, optionClause) + } + + override def visitDelete(ctx: DeleteContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.ddlObject().accept(this) + val hints = buildTableHints(Option(ctx.withTableHints())) + val finalTarget = if (hints.nonEmpty) { + ir.TableWithHints(target, hints) + } else { + target + } + + val output = Option(ctx.outputClause()).map(_.accept(this)) + val sourceRelation = buildTableSourcesPlan(Option(ctx.tableSources())) + val where = Option(ctx.updateWhereClause()) map (_.accept(vc.expressionBuilder)) + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + ir.DeleteFromTable(finalTarget, sourceRelation, where, output, optionClause) + } + + private[this] def buildTableSourcesPlan(tableSources: Option[TableSourcesContext]): Option[ir.LogicalPlan] = { + val sources = tableSources + .map(_.tableSource().asScala) + .getOrElse(Seq()) + .map(_.accept(vc.relationBuilder)) + sources.reduceLeftOption( + ir.Join(_, _, None, ir.CrossJoin, Seq(), ir.JoinDataType(is_left_struct = false, is_right_struct = false))) + } + + override def visitInsert(ctx: InsertContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val target = ctx.ddlObject().accept(this) + val hints = buildTableHints(Option(ctx.withTableHints())) + val finalTarget = if (hints.nonEmpty) { + ir.TableWithHints(target, hints) + } else { + target + } + + val columns = Option(ctx.expressionList()) + .map(_.expression().asScala.map(_.accept(vc.expressionBuilder)).collect { case col: ir.Column => + col.columnName + }) + + val output = Option(ctx.outputClause()).map(_.accept(this)) + val values = ctx.insertStatementValue().accept(this) + val optionClause = Option(ctx.optionClause).map(_.accept(vc.expressionBuilder)) + ir.InsertIntoTable(finalTarget, columns, values, output, optionClause, overwrite = false) + } + + override def visitInsertStatementValue(ctx: InsertStatementValueContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + Option(ctx) match { + case Some(context) if context.derivedTable() != null => context.derivedTable().accept(this) + case Some(context) if context.VALUES() != null => DefaultValues() + case Some(context) => context.executeStatement().accept(this) + } + } + + override def visitOutputClause(ctx: OutputClauseContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + val outputs = ctx.outputDmlListElem().asScala.map(_.accept(vc.expressionBuilder)) + val target = Option(ctx.ddlObject()).map(_.accept(this)) + val columns = + Option(ctx.columnNameList()) + .map(_.id().asScala.map(id => ir.Column(None, vc.expressionBuilder.buildId(id)))) + + // Databricks SQL does not support the OUTPUT clause, but we may be able 
to translate + // the clause to SELECT statements executed before or after the INSERT/DELETE/UPDATE/MERGE + // is executed + Output(target, outputs, columns) + } + + override def visitDdlObject(ctx: DdlObjectContext): ir.LogicalPlan = errorCheck(ctx) match { + case Some(errorResult) => errorResult + case None => + ctx match { + case tableName if tableName.tableName() != null => tableName.tableName().accept(this) + case localId if localId.LOCAL_ID() != null => ir.LocalVarTable(ir.Id(localId.LOCAL_ID().getText)) + // TODO: OPENROWSET and OPENQUERY + case _ => + ir.UnresolvedRelation( + ruleText = contextText(ctx), + message = s"Unknown DDL object type ${ctx.getStart.getText} in TSqlRelationBuilder.visitDdlObject", + ruleName = vc.ruleName(ctx), + tokenName = Some(tokenName(ctx.getStart))) + } + } + + private def buildJoinPart(left: ir.LogicalPlan, ctx: JoinPartContext): ir.LogicalPlan = { + ctx match { + case c if c.joinOn() != null => buildJoinOn(left, c.joinOn()) + case c if c.crossJoin() != null => buildCrossJoin(left, c.crossJoin()) + case c if c.apply() != null => buildApply(left, c.apply()) + case c if c.pivot() != null => buildPivot(left, c.pivot()) + case _ => buildUnpivot(left, ctx.unpivot()) // Only case left + } + } + + private def buildUnpivot(left: ir.LogicalPlan, ctx: UnpivotContext): ir.LogicalPlan = { + val unpivotColumns = ctx + .unpivotClause() + .fullColumnNameList() + .fullColumnName() + .asScala + .map(_.accept(vc.expressionBuilder)) + val variableColumnName = vc.expressionBuilder.buildId(ctx.unpivotClause().id(0)) + val valueColumnName = vc.expressionBuilder.buildId(ctx.unpivotClause().id(1)) + ir.Unpivot( + child = left, + ids = unpivotColumns, + values = None, + variable_column_name = variableColumnName, + value_column_name = valueColumnName) + } + + private def buildPivot(left: ir.LogicalPlan, ctx: PivotContext): ir.LogicalPlan = { + // Though the pivotClause allows expression, it must be a function call and we require + // correct source code to be given to remorph. + val aggregateFunction = ctx.pivotClause().expression().accept(vc.expressionBuilder) + val column = ctx.pivotClause().fullColumnName().accept(vc.expressionBuilder) + // TODO: All other aliases are handled as ir.Id, but here we use ir.Literal - should it change? 
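+    // For example (illustrative): in `PIVOT (SUM(amount) FOR category IN ([books], [games]))` the IN list entries
+    // arrive here as column aliases and are captured as the literal pivot values 'books' and 'games'.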
+ val values = ctx.pivotClause().columnAliasList().columnAlias().asScala.map(c => buildLiteral(c.getText)) + ir.Aggregate( + child = left, + group_type = ir.Pivot, + grouping_expressions = Seq(aggregateFunction), + pivot = Some(ir.Pivot(column, values))) + } + + private def buildLiteral(str: String): ir.Expression = ir.Literal(removeQuotesAndBrackets(str)) + + private def buildApply(left: ir.LogicalPlan, ctx: ApplyContext): ir.LogicalPlan = { + val rightRelation = ctx.tableSourceItem().accept(this) + ir.Join( + left, + rightRelation, + None, + if (ctx.CROSS() != null) ir.CrossApply else ir.OuterApply, + Seq.empty, + ir.JoinDataType(is_left_struct = false, is_right_struct = false)) + } + + private def buildCrossJoin(left: ir.LogicalPlan, ctx: CrossJoinContext): ir.LogicalPlan = { + val rightRelation = ctx.tableSourceItem().accept(this) + ir.Join( + left, + rightRelation, + None, + ir.CrossJoin, + Seq.empty, + ir.JoinDataType(is_left_struct = false, is_right_struct = false)) + } + + private def buildJoinOn(left: ir.LogicalPlan, ctx: JoinOnContext): ir.Join = { + val rightRelation = ctx.tableSource().accept(this) + val joinCondition = ctx.searchCondition().accept(vc.expressionBuilder) + + ir.Join( + left, + rightRelation, + Some(joinCondition), + translateJoinType(ctx), + Seq.empty, + ir.JoinDataType(is_left_struct = false, is_right_struct = false)) + } + + private[tsql] def translateJoinType(ctx: JoinOnContext): ir.JoinType = ctx.joinType() match { + case jt if jt == null || jt.outerJoin() == null || jt.INNER() != null => ir.InnerJoin + case jt if jt.outerJoin().LEFT() != null => ir.LeftOuterJoin + case jt if jt.outerJoin().RIGHT() != null => ir.RightOuterJoin + case jt if jt.outerJoin().FULL() != null => ir.FullOuterJoin + case _ => ir.UnspecifiedJoin + } + + private def removeQuotesAndBrackets(str: String): String = { + val quotations = Map('\'' -> "'", '"' -> "\"", '[' -> "]", '\\' -> "\\") + str match { + case s if s.length < 2 => s + case s => + quotations.get(s.head).fold(s) { closingQuote => + if (s.endsWith(closingQuote)) { + s.substring(1, s.length - 1) + } else { + s + } + } + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlVisitorCoordinator.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlVisitorCoordinator.scala new file mode 100644 index 0000000000..e1048e999e --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlVisitorCoordinator.scala @@ -0,0 +1,18 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.VisitorCoordinator +import org.antlr.v4.runtime.Vocabulary + +class TSqlVisitorCoordinator(parserVocab: Vocabulary, ruleNames: Array[String]) + extends VisitorCoordinator(parserVocab, ruleNames) { + + val astBuilder = new TSqlAstBuilder(this) + val relationBuilder = new TSqlRelationBuilder(this) + val expressionBuilder = new TSqlExpressionBuilder(this) + val dmlBuilder = new TSqlDMLBuilder(this) + val ddlBuilder = new TSqlDDLBuilder(this) + val functionBuilder = new TSqlFunctionBuilder + + // TSQL extension + val optionBuilder = new OptionBuilder(this) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/expressions.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/expressions.scala new file mode 100644 index 0000000000..27cf001bf9 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/expressions.scala @@ -0,0 +1,25 @@ +package com.databricks.labs.remorph.parsers.tsql + 
+import com.databricks.labs.remorph.intermediate._ + +// Specialized function calls, such as XML functions that usually apply to columns +case class TsqlXmlFunction(function: CallFunction, column: Expression) extends Binary(function, column) { + override def dataType: DataType = UnresolvedType +} + +case class Money(value: Literal) extends Unary(value) { + override def dataType: DataType = UnresolvedType +} + +case class Deleted(selection: Expression) extends Unary(selection) { + override def dataType: DataType = selection.dataType +} + +case class Inserted(selection: Expression) extends Unary(selection) { + override def dataType: DataType = selection.dataType +} + +// The default case for the expression parser needs to be explicitly defined to distinguish [DEFAULT] +case class Default() extends LeafExpression { + override def dataType: DataType = UnresolvedType +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/relations.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/relations.scala new file mode 100644 index 0000000000..8159d8bd46 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/relations.scala @@ -0,0 +1,35 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.intermediate._ + +case class DerivedRows(rows: Seq[Seq[Expression]]) extends LeafNode { + override def output: Seq[Attribute] = rows.flatten.map(e => AttributeReference(e.toString, e.dataType)) +} + +case class Output(target: Option[LogicalPlan], outputs: Seq[Expression], columns: Option[Seq[Column]]) + extends RelationCommon { + override def output: Seq[Attribute] = outputs.map(e => AttributeReference(e.toString, e.dataType)) + override def children: Seq[LogicalPlan] = Seq(target.getOrElse(NoopNode)) +} + +case class WithOutputClause(input: LogicalPlan, target: LogicalPlan) extends Modification { + override def output: Seq[Attribute] = target.output + override def children: Seq[LogicalPlan] = Seq(input, target) +} + +case class BackupDatabase( + databaseName: String, + disks: Seq[String], + flags: Map[String, Boolean], + autoFlags: Seq[String], + values: Map[String, Expression]) + extends Catalog {} + +case class ColumnAliases(input: LogicalPlan, aliases: Seq[Id]) extends RelationCommon { + override def output: Seq[Attribute] = aliases.map(a => AttributeReference(a.id, StringType)) + override def children: Seq[LogicalPlan] = Seq(input) +} + +case class DefaultValues() extends LeafNode { + override def output: Seq[Attribute] = Seq.empty +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/PullLimitUpwards.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/PullLimitUpwards.scala new file mode 100644 index 0000000000..f0cc9e941a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/PullLimitUpwards.scala @@ -0,0 +1,17 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.intermediate._ + +// TSQL has "SELECT TOP N * FROM .." vs "SELECT * FROM .. 
LIMIT N", so we fix it here +object PullLimitUpwards extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case Project(Limit(child, limit), exprs) => + Limit(Project(child, exprs), limit) + case Filter(Limit(child, limit), cond) => + Limit(Filter(child, cond), limit) + case Sort(Limit(child, limit), order, global) => + Limit(Sort(child, order, global), limit) + case Offset(Limit(child, limit), offset) => + Limit(Offset(child, offset), limit) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TSqlCallMapper.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TSqlCallMapper.scala new file mode 100644 index 0000000000..afcc44df6a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TSqlCallMapper.scala @@ -0,0 +1,119 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.intermediate._ + +class TSqlCallMapper extends CallMapper { + + override def convert(call: Fn): Expression = { + call match { + case CallFunction("DATEADD", args) => + processDateAdd(args) + case CallFunction("GET_BIT", args) => BitwiseGet(args.head, args(1)) + case CallFunction("SET_BIT", args) => genBitSet(args) + case CallFunction("CHECKSUM_AGG", args) => checksumAgg(args) + + // No special case for TSQl, so we ask the main mapper to look at this one + case cf: CallFunction => super.convert(cf) + + // As well as CallFunctions, we can receive concrete functions, which are already resolved, + // and don't need to be converted + case x: Fn => x + } + } + + /** + * Converts a CHECKSUM_ARG call to a MD5 function call sequence + * @param args + * @return + */ + private def checksumAgg(args: Seq[Expression]): Expression = { + Md5(ConcatWs(Seq(Literal(","), CollectList(args.head)))) + } + + private def genBitSet(args: Seq[Expression]): Expression = { + val x = args.head + val n = args(1) + val v = if (args.length > 2) { args(2) } + else { Literal(1) } + + // There is no direct way to achieve a bit set in Databricks SQL, so + // we use standard bitwise operations to achieve the same effect. If we have + // literals for the bit the bit Value, we can generate less complicated + // code, but if we have columns or other expressions, we have to generate a longer sequence + // as we don't know if we are setting or clearing a bit until runtime. + v match { + case lit: Literal if lit == Literal(1) => BitwiseOr(x, ShiftLeft(Literal(1), n)) + case _ => + BitwiseOr(BitwiseAnd(x, BitwiseXor(Literal(-1), ShiftLeft(Literal(1), n))), ShiftRight(v, n)) + } + } + + private def processDateAdd(args: Seq[Expression]): Expression = { + + // The first argument of the TSQL DATEADD function is the interval type, which is one of way too + // many strings and aliases for "day", "month", "year", etc. We need to extract this string and + // perform the translation based on what we get + val interval = args.head match { + case Column(_, Id(id, _)) => id.toLowerCase() + case _ => + throw new IllegalArgumentException("DATEADD interval type is not valid. Should be 'day', 'month', 'year', etc.") + } + + // The value is how many units, type indicated by interval, to add to the date + val value = args(1) + + // And this is the thing we are going to add the value to + val objectReference = args(2) + + // The interval type names are all over the place in TSQL, some of them having names that + // belie their actual function. 
+ interval match { + + // Days are all that Spark DATE_ADD operates on, but the arguments are transposed from TSQL + // despite the fact that 'dayofyear' implies the number of the day in the year, it is in fact the + // same as day, as is `weekday` + case "day" | "dayofyear" | "dd" | "d" | "dy" | "y" | "weekday" | "dw" | "w" => + DateAdd(objectReference, value) + + // Months are handled by the MonthAdd function, with arguments transposed from TSQL + case "month" | "mm" | "m" => AddMonths(objectReference, value) + + // There is no equivalent to quarter in Spark, so we have to use the MonthAdd function and multiply by 3 + case "quarter" | "qq" | "q" => AddMonths(objectReference, Multiply(value, Literal(3))) + + // There is no equivalent to year in Spark SQL, but we use months and multiply by 12 + case "year" | "yyyy" | "yy" => AddMonths(objectReference, Multiply(value, Literal(12))) + + // Weeks are not supported in Spark SQL, but we can multiply by 7 to get the same effect with DATE_ADD + case "week" | "wk" | "ww" => DateAdd(objectReference, Multiply(value, Literal(7))) + + // Hours are not supported in Spark SQL, but we can use the number of hours to create an INTERVAL + // and add it to the object reference + case "hour" | "hh" => Add(objectReference, KnownInterval(value, HOUR_INTERVAL)) + + // Minutes are not supported in Spark SQL, but we can use the number of minutes to create an INTERVAL + // and add it to the object reference + case "minute" | "mi" | "n" => Add(objectReference, KnownInterval(value, MINUTE_INTERVAL)) + + // Seconds are not supported in Spark SQL, but we can use the number of seconds to create an INTERVAL + // and add it to the object reference + case "second" | "ss" | "s" => Add(objectReference, KnownInterval(value, SECOND_INTERVAL)) + + // Milliseconds are not supported in Spark SQL, but we can use the number of milliseconds to create an INTERVAL + // and add it to the object reference + case "millisecond" | "ms" => Add(objectReference, KnownInterval(value, MILLISECOND_INTERVAL)) + + // Microseconds are not supported in Spark SQL, but we can use the number of microseconds to create an INTERVAL + // and add it to the object reference + case "microsecond" | "mcs" => Add(objectReference, KnownInterval(value, MICROSECOND_INTERVAL)) + + // Nanoseconds are not supported in Spark SQL, but we can use the number of nanoseconds to create an INTERVAL + // and add it to the object reference + case "nanosecond" | "ns" => Add(objectReference, KnownInterval(value, NANOSECOND_INTERVAL)) + + case _ => + throw new IllegalArgumentException( + s"DATEADD interval type '${interval}' is not valid. 
Should be 'day', 'month', 'year', etc.") + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TopPercentToLimitSubquery.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TopPercentToLimitSubquery.scala new file mode 100644 index 0000000000..b33af6a420 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TopPercentToLimitSubquery.scala @@ -0,0 +1,96 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.intermediate._ + +import java.util.concurrent.atomic.AtomicLong + +case class TopPercent(child: LogicalPlan, percentage: Expression, with_ties: Boolean = false) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +class TopPercentToLimitSubquery extends Rule[LogicalPlan] { + private[this] val counter = new AtomicLong() + override def apply(plan: LogicalPlan): LogicalPlan = normalize(plan) transformUp { + case TopPercent(child, percentage, withTies) => + if (withTies) { + withPercentiles(child, percentage) + } else { + viaTotalCount(child, percentage) + } + } + + /** See [[PullLimitUpwards]] */ + private def normalize(plan: LogicalPlan): LogicalPlan = plan transformUp { + case Project(TopPercent(child, limit, withTies), exprs) => + TopPercent(Project(child, exprs), limit, withTies) + case Filter(TopPercent(child, limit, withTies), cond) => + TopPercent(Filter(child, cond), limit, withTies) + case Sort(TopPercent(child, limit, withTies), order, global) => + TopPercent(Sort(child, order, global), limit, withTies) + case Offset(TopPercent(child, limit, withTies), offset) => + TopPercent(Offset(child, offset), limit, withTies) + } + + private def withPercentiles(child: LogicalPlan, percentage: Expression) = { + val cteSuffix = counter.incrementAndGet() + val originalCteName = s"_limited$cteSuffix" + val withPercentileCteName = s"_with_percentile$cteSuffix" + val percentileColName = s"_percentile$cteSuffix" + child match { + case Sort(child, order, _) => + // TODO: this is (temporary) hack due to the lack of star resolution. 
otherwise child.output is fine + val reProject = child.find(_.isInstanceOf[Project]).map(_.asInstanceOf[Project]) match { + case Some(Project(_, expressions)) => expressions + case None => + throw new IllegalArgumentException("Cannot find a projection") + } + WithCTE( + Seq( + SubqueryAlias(child, Id(originalCteName)), + SubqueryAlias( + Project( + UnresolvedRelation(originalCteName, message = s"Unresolved $originalCteName"), + reProject ++ Seq(Alias(Window(NTile(Literal(100)), sort_order = order), Id(percentileColName)))), + Id(withPercentileCteName))), + Filter( + Project( + UnresolvedRelation(withPercentileCteName, message = s"Unresolved $withPercentileCteName"), + reProject), + LessThanOrEqual( + UnresolvedAttribute( + percentileColName, + ruleText = percentileColName, + message = s"Unresolved $percentileColName"), + Divide(percentage, Literal(100))))) + case _ => + // TODO: (jimidle) figure out cases when this is not true + throw new IllegalArgumentException("TopPercent with ties requires a Sort node") + } + } + + private def viaTotalCount(child: LogicalPlan, percentage: Expression) = { + val cteSuffix = counter.incrementAndGet() + val originalCteName = s"_limited$cteSuffix" + val countedCteName = s"_counted$cteSuffix" + WithCTE( + Seq( + SubqueryAlias(child, Id(originalCteName)), + SubqueryAlias( + Project( + UnresolvedRelation(ruleText = originalCteName, message = s"Unresolved relation $originalCteName"), + Seq(Alias(Count(Seq(Star())), Id("count")))), + Id(countedCteName))), + Limit( + Project( + UnresolvedRelation(ruleText = originalCteName, message = s"Unresolved relation $originalCteName"), + Seq(Star())), + ScalarSubquery( + Project( + UnresolvedRelation( + ruleText = countedCteName, + message = s"Unresolved relation $countedCteName", + ruleName = "N/A", + tokenName = Some("N/A")), + Seq(Cast(Multiply(Divide(Id("count"), percentage), Literal(100)), LongType)))))) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TrapInsertDefaultsAction.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TrapInsertDefaultsAction.scala new file mode 100644 index 0000000000..4e763e1f58 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TrapInsertDefaultsAction.scala @@ -0,0 +1,20 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.intermediate._ + +case class InsertDefaultsAction(condition: Option[Expression]) extends MergeAction { + override def children: Seq[Expression] = condition.toSeq +} + +object TrapInsertDefaultsAction extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case merge @ MergeIntoTable(_, _, _, _, notMatchedActions, _) => + notMatchedActions.collectFirst { case InsertDefaultsAction(_) => + throw new IllegalArgumentException( + "The MERGE action 'INSERT DEFAULT VALUES' is not supported in Databricks SQL") + } + merge + case _ => plan + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/preprocessors/Processor.scala b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/Processor.scala new file mode 100644 index 0000000000..2159f72dc8 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/Processor.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.preprocessors + +import com.databricks.labs.remorph._ +import com.databricks.labs.remorph.intermediate.IncoherentState +import org.antlr.v4.runtime.{CharStream, Lexer} + +trait 
Processor extends TransformationConstructors { + protected def createLexer(input: CharStream): Lexer + final def pre: Transformation[Unit] = { + getCurrentPhase.flatMap { + case p: PreProcessing => + preprocess(p.source).flatMap { preprocessedString => + setPhase(Parsing(preprocessedString, p.filename, Some(p))) + } + case other => ko(WorkflowStage.PARSE, IncoherentState(other, classOf[PreProcessing])) + } + } + + def preprocess(input: String): Transformation[String] + def post(input: String): Transformation[String] +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/preprocessors/README.md b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/README.md new file mode 100644 index 0000000000..d57f8e2619 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/README.md @@ -0,0 +1,348 @@ +# A Processor for Jinja Templates, and DBT Projects + +| | | +|:-----------------------|:-----------------| +| **Author** | Jim Idle | +| **Engineering Owner** | Serge Smertin | +| **Product Owner** | Guenia Izquierdo | +| **Status** | Draft for review | + +Major Changes: + + - 2024-11 Initial draft + - 2024-11 Modifed solution descriptions after prototyping/experimentation + - 2024-11 Added workflow diagrams + - 2024-11 Added DBT project section + - 2024-11 Added placeholders section and explanations + +# Table of Contents +- [A preprocessor for macros, parameters, Jinja templates, and DBT](#a-preprocessor-for-macros-parameters-jinja-templates-and-dbt) + - [Motivation](#motivation) + - [Definitions](#definitions) + - [Sample Template](#sample-template) + - [Complications](#complications) + - [Approaches to handling templates](#approaches-to-handling-templates) + - [Lexing and Placeholding](#lexing-and-placeholding) + - [In-place parsing](#in-place-parsing) + - [Translate on the fly](#translate-on-the-fly) + - [The preprocessor](#the-preprocessor) + - [Placeholders](#placeholders) + - [Workflow](#workflow) + - [DBT Projects](#dbt-projects) + +## Motivation + +Many customers converting from a particular SQL dialect to Databricks SQL will have adopted +Jinja templates, usually in the form of DBT projects. These templates will contain SQL code +in the source dialect, which we must transpile to Databricks SQL, while preserving the original +template structure. + +Here, we define the approach to handling these templates within the existing +remorph framework, and approach to handling DBT project conversion. + +## Definitions + + - Template - a template is a piece of text that contains placeholders for variables, loops, and + conditional statements in a format defined by the Jinja template engine. + - Template Element The templates contain templating elements enclosed in `{{ }}`, `{# #}`, + and `{% %}` blocks. + - Statement Element - a statement element is a template element that contains Jinja processing + directives such as a loop or conditional statement. They are enclosed in `{% %}` blocks and + quite often consist of matched pais of `{% %}` blocks for things like `if` `end`. + - Expression Element - an expression element is a template element that contains a variable + reference or a function call. They are enclosed in `{{ }}` blocks. + - Text Element - a text element is a piece of text that is not a template element - it is just + free text. It is passed through the Jinja processor unchanged. In our case the text will be + SQL code in a particular dialect and we must translate it to Databricks SQL. 
+
+## Proposal sketches
+
+In the following diagrams, existing components are shown in blue, modified components in magenta,
+and new components are shown in green.
+
+Firstly, we introduce a new pre-processing phase in the toolchain. This phase is based upon an ANTLR lexer and supporting
+Scala code. Raw text is passed through, but template elements are located and replaced with placeholders (defined later
+in this document). The resulting text can then be used in the existing ANTLR based toolchain.
+
+In the process of capturing the template elements, we will note whether the template element was preceded or followed by
+whitespace. This is important when regenerating the template elements in the post-processing phase.
+
+```mermaid
+---
+title: Jinja Preprocessor Workflow
+config:
+  theme: dark
+---
+block-beta
+  block
+    columns 3
+    Template(("Template Source Text")) ba1<[" "]>(right) Tokenizer
+    space:2 ba2<[" "]>(down)
+    Replace["Generate Placeholder"] ba3<[" "]>(left) Capture["Capture Templates\nand Text"]
+    ba4<[" "]>(down) space:2
+    Text(("Text with Placeholders"))
+  end
+
+  classDef new fill:#696,stroke:#333;
+  classDef exist fill:#969,stroke:#333;
+  classDef modified fill:#a752ba,stroke:#333;
+  class Template,Replace,Tokenizer,Capture,Text new
+```
+
+The placeholders are then used in the existing ANTLR based toolchain, with some modifications:
+
+```mermaid
+---
+config:
+  theme: dark
+---
+block-beta
+  block
+    columns 3
+    Placeholder(("Text with Placeholders")) ba1<[" "]>(right) Tokenizer["Common Lexer"]
+    space:2 ba2<[" "]>(down)
+    Ir["Logical Plan with\nnew IR for template\nelements."] ba3<[" "]>(left) Parser["Dialect Parser/Visitor"]
+    ba4<[" "]>(down) space:2
+    Generator["SQL Generator caters\nfor template IR"] ba5<[" "]>(right) Text["Translated Text"]
+  end
+  classDef new fill:#696,stroke:#333;
+  classDef exist fill:#969,stroke:#333;
+  classDef modified fill:#a752ba,stroke:#333;
+  class Placeholder,Text new
+  class Tokenizer,Ir,Parser,Generator modified
+```
+
+There is then an additional post-processing phase that will replace the placeholders with the template element
+text. If the original template element was not preceded or followed by whitespace, then the placeholders will elide preceding
+or following whitespace accordingly. This allows the code generator to be as simple as possible. Note that this phase must
+be applied before any SQL formatting is applied.
+
+```mermaid
+---
+config:
+  theme: dark
+---
+block-beta
+  block
+    columns 3
+    Placeholder(("Translated Text")) ba1<[" "]>(right) PostProcessor["Replace Placeholders\n(with whitespace handling)"]
+    space:2 ba2<[" "]>(down)
+    space:2 Text["Reformed Template Text"]
+  end
+  classDef new fill:#696,stroke:#333;
+  classDef exist fill:#969,stroke:#333;
+  classDef modified fill:#a752ba,stroke:#333;
+  class Placeholder,Text new
+  class Tokenizer,Ir,Parser,Generator modified
+```
+
+## Sample Template
+
+```sql
+WITH order_payments AS (
+    SELECT
+        order_id,
+        {# payment_methods is a list of strings - this is a comment #}
+        {% for payment_method in payment_methods -%}
+        SUM(CASE WHEN payment_method = '{{ payment_method }}'
+            THEN amount else 0 end) AS {{ payment_method }}_amount,
+        {% endfor -%}
+        SUM(amount) AS total_amount
+    FROM {{ ref('payments') }}
+    GROUP BY order_id
+)
+SELECT * FROM order_payments
+```
+
+The example shows a query embedded in a DBT/Jinja template and shows the various ways in which
+macros and template references are used.
We see that:
+
+  - literal strings can contain template/parameter references: `'{{ payment_method }}'`
+  - `{% -%}` templates can contain templating code: `{% for payment_method in payment_methods -%}`
+  - `{{ ref('payments') }}` is a macro, which invokes a function. In this case `ref` is a function that
+    returns a string stored in a variable called `payments`.
+
+Of note, we see that a template use such as `{{ payment_method }}_amount` will generate text that
+creates a single Identifier, and so whitespace needs to be accounted for in the generated code, as
+there is a clear difference in generated template output depending upon whether whitespace
+is present or not:
+
+```sql
+  SELECT {{ x }} xalias FROM table
+  SELECT {{ x }}_not_alias FROM table
+```
+
+### Complications
+
+  - Jinja allows the user to change the delimiters for the templates from the default `{{ }}` to anything else. Hence
+    lexical tricks are used such that we can still use an ANTLR based lexer as the basis of the preprocessor.
+  - In many cases the templates will be used in place of, say, _expressions_, and therefore we can just accept a
+    special token: `JINJAEXPR: '_!Jinja' [0-9]+ ;`.
+  - However, we are going to find both statement and expression
+    templates located in places where the current SQL parser will not expect them. In the example above, the statement
+    template `{% for payment_method in payment_methods -%}` is located in the middle of a SQL statement. In this case
+    we would need to allow templates to occur anywhere in the parse in violation of the normal syntax rules.
+  - Jinja allows line statements, also with the prefix being configurable. Hence, we need to be able to handle
+    them too. Typically, they would start with a single prefix such as `# for item in seq`, and the entire line
+    is a Jinja statement.
+  - There is nothing stopping a user from stuffing variables with actual SQL statements. We will probably draw a
+    line at supporting that, although when translating, say, an entire DBT project, we will likely find the definition
+    of the variables and they would then be translated naturally.
+  - Macros can contain bits of SQL code, which we may be able to attempt to translate by trying different entry points
+    into the SQL parser. However, this is not guaranteed to be successful as literally anything can be put anywhere
+    in macros.
+  - There are macros that do things like add or not add commas at the ends of loops of text generation. We can cater
+    for this using the generally accepted practice of allowing trailing commas in sequences, so we can accept partial
+    lists of columns, for example.
+
+## Major alternatives to handling templates
+
+We should first note that our template handler/preprocessor is not expanding the templates, but locating them and
+translating embedded SQL. The templates themselves will be left in place, and the translated SQL will be adorned
+with them exactly as it was before transpilation.
+
+In other words, we are writing a transpiler from XXX SQL plus DBT to Databricks SQL plus DBT, such that users will then
+maintain the Databricks SQL alongside their templates.
+
+We therefore have the following approaches:
+
+### Lexing and Placeholding
+
+This is the proposed approach.
+
+As there are few constructs in the templates, we can lex the raw input, store the gathered template elements, and replace
+them with a single token identifying the template element.
+
+The SQL grammars then need only look for simple template references and not have to worry about the template syntax itself.
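+
+For illustration, the `TemplateManager` added later in this change is what hands out the placeholder
+keys and remembers each captured element for the post-processing pass. A rough sketch of that
+interaction (with the `Origin` positions elided and the regex simplified for brevity) might look like:
+
+```scala
+import com.databricks.labs.remorph.intermediate.Origin
+import com.databricks.labs.remorph.preprocessors.jinja.{ExpressionElement, TemplateManager}
+
+// Rough sketch only: capture `{{ ref('payments') }}` and substitute a generated key for it.
+val manager = TemplateManager()
+val key = manager.nextKey // "_!Jinja0001" - keys are numbered sequentially
+val captured = ExpressionElement(
+  Origin(None, None, None, None, None), // token positions elided; the real preprocessor records them
+  text = "{{ ref('payments') }}",
+  regex = key) // the real regex also elides surrounding whitespace where required
+val updated = manager.add(key, captured)
+// The text handed to the common SQL lexer now reads `FROM _!Jinja0001`, and
+// `updated.get(key)` recovers the original template text for the post-processing step.
+```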
+
+However, the grammars for the source dialects will need to be modified to accept the template tokens at strategic points
+in the parse. We note that it may not be possible to cover 100% of the esoteric uses of templates, but we can cover the
+majority. The existing system for handling parsing errors will generate output that shows why we could not parse what looked
+like SQL when a template was used in a place where it is impossible to parse it.
+
+#### Advantages
+
+  - The existing lexers and parsers are not burdened with the template syntax
+  - The lexers, parsers and code generator can be used with a few small modifications
+
+#### Disadvantages
+
+  - Templates can be placed anywhere, so the grammar will need to be strategically adorned with valid
+    places that templates can occur. This is not a trivial task; however, it is doable if we accept that
+    we will not be able to process some esoteric uses of templates, which are generally abuses of the concept.
+    For instance `SELECT A {{ , }} B {{ , }} C` is not a reasonable construct. It would need to be dealt
+    with manually.
+
+### In-place parsing
+
+While similar to the above, we could capture and manage the actual templates in the dialect grammars. This
+removes the need for a pre-processor but means that we would pollute the existing pipeline with template handling,
+and it smacks of poor separation of responsibilities/functionality. Because the delimiters for template elements
+can be changed by the user, the common lexer would need to know about DBT configs and so on.
+
+#### Advantages
+
+  - No need for a pre-processor as such
+
+#### Disadvantages
+
+  - We end up with code dealing with templates intermingled with the dialect parsing code
+  - We still need to track the templates, so we are not saving any coding effort, just moving it around
+
+### Translate on the fly
+
+One other approach would be to defer the translation of the templates until the SQL is produced by
+DBT, then convert on the fly.
+
+#### Advantages
+
+  - No need to modify the dialect grammars
+  - No need to track templates
+
+#### Disadvantages
+
+  - We do not guarantee that we can perform a 100% conversion of the incoming source dialect
+  - Users would then be maintaining code in the source dialect, which is not the goal of the project
+
+# Part II - Implementation Details
+
+## The Processor
+
+We therefore conclude that the best approach is to use a preprocessor that will locate and track the templates
+and replace them in the pre-processed output with a placeholder.
+
+As the templates are just text with the macro types above sprinkled in, we can create a preprocessor
+that will always run against the given input text, even if, with un-templated SQL, the preprocessor
+will merely pass through what it sees. However, there may well be other functionality for the preprocessor
+phase to provide in the future, such as parameter tracking/processing, or daisy-chaining of
+pre-processors that perform specific tasks. The optimizer phase of the toolchain shows the value of being
+able to apply multiple transformations in any particular transpilation phase.
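+
+For orientation, the `Processor` trait introduced alongside this README captures that contract. A purely
+illustrative pass-through implementation - effectively what un-templated SQL experiences - would look like:
+
+```scala
+import com.databricks.labs.remorph.Transformation
+import com.databricks.labs.remorph.preprocessors.Processor
+import org.antlr.v4.runtime.{CharStream, Lexer}
+
+// Illustrative only: a preprocessor that changes nothing. A real implementation, such as the
+// JinjaProcessor added in this change, replaces template elements with placeholders in
+// `preprocess` and restores them in `post`.
+class PassThroughProcessor(lexerFactory: CharStream => Lexer) extends Processor {
+  override protected def createLexer(input: CharStream): Lexer = lexerFactory(input)
+  override def preprocess(input: String): Transformation[String] = ok(input)
+  override def post(input: String): Transformation[String] = ok(input)
+}
+```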
+
+Hence the preprocessor will be a simple text processor that will:
+
+  - find and replace all `{{ }}`, `{# #}` and `{% %}` blocks with a placeholder
+  - track the unique placeholders and their original text
+  - pass through the placeholders to the common SQL lexer
+
+To process the preprocessor output, we will:
+
+  - modify the dialect grammars such that the placeholders are accepted at strategic points in the parse
+  - attempt to convert the SQL within the text where the macros are replaced with placeholders
+  - have the code generator pass the placeholders through unchanged, so that the template processor can replace them
+    with the original template elements, which may also require some form of processing, which can now be handled
+    outside of the toolchain.
+
+Note that Jinja template elements cannot contain random text and so at this time, we see no need to
+process them in any way. However, should the need arise, we can now handle such processing in this
+dedicated pre/post-processing phase, outside of the main toolchain.
+
+### Placeholders
+
+The placeholders will be simple strings that are unique to the template. A template manager will
+generate placeholder names in the form: `_!JinjaNNNN` where `NNNN` is a unique number, generated
+sequentially. This allows the common lexer to be used for all dialects, as the placeholders will
+always be of the same form. At this point we do not need to distinguish between the types of template
+element (statement, expression, etc.) as they will be replaced via a post-processing step. The IR
+generator will create a new IR node for each placeholder, holding its text value, and the code
+generator will merely replace the placeholders with the text value. After code generation completes,
+the placeholders will be replaced with the template definitions using a post-processing step.
+
+## DBT Projects
+
+As well as the template processing, we will need a DBT processor to bring in the DBT configuration and associated fluff that
+goes with a complete DBT project layout, and translate every SQL-bearing template.
+
+This includes processing the .yml configuration files and using anything from the configuration that is required to
+generate the Databricks DBT project. This will be a separate utility which will coordinate with the transpiler to ensure
+that the correct configuration is used and that the translated templates are stored in the appropriate location for the DBT project.
+
+A conversion will create a new DBT project (leaving the original input intact) with the same structure as the original,
+but with the SQL code translated to Databricks SQL.
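+
+A very rough sketch of how such a conversion might enumerate the SQL-bearing templates of a project
+(the project path and file filter are hypothetical, a `PreProcessing(source, filename)` construction
+is assumed, and configuration processing and output writing are deliberately omitted):
+
+```scala
+import java.nio.file.{Files, Paths}
+import scala.collection.JavaConverters._
+import com.databricks.labs.remorph.PreProcessing
+
+// Hypothetical sketch: walk a DBT project's models directory and build the PreProcessing
+// inputs that the transpiler's `transpile` entry point expects, one per template file.
+val projectRoot = Paths.get("my_dbt_project") // hypothetical input project
+val templates: Seq[PreProcessing] = Files
+  .walk(projectRoot.resolve("models"))
+  .iterator()
+  .asScala
+  .filter(p => Files.isRegularFile(p) && p.toString.endsWith(".sql"))
+  .map(p => PreProcessing(new String(Files.readAllBytes(p)), projectRoot.relativize(p).toString))
+  .toSeq
+// Each PreProcessing value is then fed through the transpiler and the translated template is
+// written to the same relative path in the new, output DBT project.
+```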
+
+A simplified sequence diagram of the process is shown below:
+
+```mermaid
+sequenceDiagram
+    participant User
+    participant DBTProcessor
+    participant ConfigLoader
+    participant Processor
+    participant Template Directory
+    participant OutputWriter
+
+    User->>DBTProcessor: Load DBT Project
+    activate DBTProcessor
+    DBTProcessor->>ConfigLoader: Load Configuration
+    activate ConfigLoader
+    ConfigLoader-->>DBTProcessor: Configuration Loaded
+    deactivate ConfigLoader
+    DBTProcessor->>DBTProcessor: Adjust Configuration
+    DBTProcessor->>Processor: Process Template
+    activate Processor
+    Processor-->>DBTProcessor: Translated Template
+    deactivate Processor
+    DBTProcessor->>Template Directory: Store Translated Template
+    DBTProcessor->>OutputWriter: Write Translated Project
+    activate OutputWriter
+    OutputWriter-->>User: Translated DBT Project
+    deactivate OutputWriter
+    deactivate DBTProcessor
+```
diff --git a/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/JinjaProcessor.scala b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/JinjaProcessor.scala
new file mode 100644
index 0000000000..092a3afbd5
--- /dev/null
+++ b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/JinjaProcessor.scala
@@ -0,0 +1,226 @@
+package com.databricks.labs.remorph.preprocessors.jinja
+
+import com.databricks.labs.remorph._
+import com.databricks.labs.remorph.intermediate.{IncoherentState, Origin, PreParsingError}
+import com.databricks.labs.remorph.parsers.preprocessor.DBTPreprocessorLexer
+import com.databricks.labs.remorph.preprocessors.Processor
+import org.antlr.v4.runtime._
+
+class JinjaProcessor extends Processor {
+
+  override protected def createLexer(input: CharStream): Lexer = new DBTPreprocessorLexer(input)
+
+  override def preprocess(input: String): Transformation[String] = {
+
+    val inputString = CharStreams.fromString(input)
+    val tokenizer = createLexer(inputString)
+    val tokenStream = new CommonTokenStream(tokenizer)
+
+    updatePhase { case p: PreProcessing =>
+      p.copy(tokenStream = Some(tokenStream))
+    }.flatMap(_ => loop)
+
+  }
+
+  def loop: Transformation[String] = {
+    getCurrentPhase.flatMap {
+      case PreProcessing(_, _, _, Some(tokenStream), preprocessedSoFar) =>
+        if (tokenStream.LA(1) == Token.EOF) {
+          ok(preprocessedSoFar)
+        } else {
+          val token = tokenStream.LT(1)
+
+          // TODO: Line statements and comments
+          token.getType match {
+            case DBTPreprocessorLexer.STATEMENT =>
+              loopStep(tokenStream, DBTPreprocessorLexer.STATEMENT_END, preprocessedSoFar)
+            case DBTPreprocessorLexer.EXPRESSION =>
+              loopStep(tokenStream, DBTPreprocessorLexer.EXPRESSION_END, preprocessedSoFar)
+            case DBTPreprocessorLexer.COMMENT =>
+              loopStep(tokenStream, DBTPreprocessorLexer.COMMENT_END, preprocessedSoFar)
+            case DBTPreprocessorLexer.C | DBTPreprocessorLexer.WS =>
+              updatePhase { case p: PreProcessing =>
+                val sb = new StringBuilder()
+                var tok = token
+                while (tok.getType == DBTPreprocessorLexer.C || tok.getType == DBTPreprocessorLexer.WS) {
+                  tokenStream.consume()
+                  sb.append(tok.getText)
+                  tok = tokenStream.LT(1)
+                }
+                p.copy(preprocessedInputSoFar = preprocessedSoFar + sb.toString())
+              }.flatMap(_ => loop)
+            case _ =>
+              lift(
+                PartialResult(
+                  preprocessedSoFar,
+                  PreParsingError(
+                    token.getLine,
+                    token.getCharPositionInLine,
+                    token.getText,
+                    "Unexpected or malformed template element")))
+          }
+        }
+      case other => ko(WorkflowStage.PARSE, IncoherentState(other, classOf[PreProcessing]))
+    }
+  }
+
+  def loopStep(tokenStream: CommonTokenStream,
token: Int, preprocessedSoFar: String): Transformation[String] = + buildElement(tokenStream, token) + .flatMap { elem => + updatePhase { case p: PreProcessing => + p.copy(preprocessedInputSoFar = preprocessedSoFar + elem) + } + } + .flatMap(_ => loop) + + def post(input: String): Transformation[String] = { + getTemplateManager.map(tm => tm.rebuild(input)) + } + + /** + * Accumulates tokens from the token stream into the template element, while building a regex to match the template element. + * It handles preceding and trailing whitespace, and optionally elides trailing commas. + * An accumulated template definition is added to the template manager, and it returns the placeholder name + * of the template to be used instead of the raw template text. + * + * @param tokenStream the token stream to process + * @param endType the token type that signifies the end of the template element + */ + private def buildElement(tokenStream: CommonTokenStream, endType: Int): Transformation[String] = { + + getTemplateManager + .flatMap { templateManager => + // Builds the regex that matches the template element + val regex = new StringBuilder + + // Was there any preceding whitespace? We need to know if this template element was specified like this: + // sometext_{{ expression }} + // or like this: + // sometext_ {{ expression }} + // + // So that our regular expression can elide any whitespace that was inserted by the SQL generator + if (!hasSpace(tokenStream, -1)) { + regex.append("[\t\f ]*") + } + + // Preserve new lines etc in the template text as it is much easier than doing this at replacement time + val preText = preFix(tokenStream, -1) + + val start = tokenStream.LT(1) + var token = start + do { + tokenStream.consume() + token = tokenStream.LT(1) + } while (token.getType != endType) + tokenStream.consume() + + // What is the next template placeholder? + val templateKey = templateManager.nextKey + regex.append(templateKey) + + // If there is no trailing space following the template element definition, then we need to elide any + // that are inserted by the SQL generator + if (!hasSpace(tokenStream, 1)) { + regex.append("[\t\f ]*") + } + + // If there is no trailing comma after the template element definition, then we need to elide any + // that are automatically inserted by the SQL generator - we therefore match any whitespace and newlines + // and just delete them, because the postfix will accumulate the original whitespace and newlines in the + // template text + if (!hasTrailingComma(tokenStream, 1)) { + regex.append("[\n\t\f ]*[,]?[ ]?") + } + + // Preserve new lines and space in the template text as it is much easier than doing this at replacement time + val text = preText + tokenStream.getText(start, token) + postFix(tokenStream, 1) + + val origin = + Origin( + Some(start.getLine), + Some(start.getCharPositionInLine), + Some(start.getStartIndex), + Some(token.getStopIndex), + Some(text)) + val template = endType match { + case DBTPreprocessorLexer.STATEMENT_END => + StatementElement(origin, text, regex.toString()) + case DBTPreprocessorLexer.EXPRESSION_END => + ExpressionElement(origin, text, regex.toString()) + case DBTPreprocessorLexer.COMMENT_END => + CommentElement(origin, text, regex.toString()) + } + updateTemplateManager(_.add(templateKey, template)).map(_ => templateKey) + } + } + + /** + * Checks if the token at the specified index in the token stream is a whitespace token. 
+ * + * @param tokenStream the token stream to check + * @param index the index of the token to check + * @return true if the token at the specified index is a whitespace token, false otherwise + */ + private def hasSpace(tokenStream: CommonTokenStream, index: Int): Boolean = + Option(tokenStream.LT(index)) match { + case None => false + case Some(s) if s.getType == DBTPreprocessorLexer.WS => true + case _ => false + } + + /** + * Accumulates preceding whitespace and newline tokens from the given index in the token stream. + * + * @param tokenStream the token stream to search backwards from (inclusive) + * @param index the starting index in the token stream + * @return a string containing the accumulated whitespace and newline tokens + */ + private def preFix(tokenStream: CommonTokenStream, index: Int): String = { + val builder = new StringBuilder + var token = tokenStream.LT(index) + var i = 1 + while (token != null && (token.getType == DBTPreprocessorLexer.WS || token.getText == "\n")) { + builder.insert(0, token.getText) + token = tokenStream.LT(index - i) + i += 1 + } + + // We do not accumulate the prefix if the immediately preceding context was another + // template element as that template will have accumulated the whitespace etc in its + // postfix + if (token != null && (token.getType == DBTPreprocessorLexer.STATEMENT_END || + token.getType == DBTPreprocessorLexer.EXPRESSION_END || + token.getType == DBTPreprocessorLexer.COMMENT_END)) { + "" + } else { + builder.toString() + } + } + + /** + * Accumulates trailing whitespace and newline tokens from the given index in the token stream. + * + * @param tokenStream the token stream to search forwards from (inclusive) + * @param index the starting index in the token stream + * @return a string containing the accumulated whitespace and newline tokens + */ + private def postFix(tokenStream: CommonTokenStream, index: Int): String = { + val builder = new StringBuilder + var token = tokenStream.LT(index) + while (token != null && (token.getType == DBTPreprocessorLexer.WS || token.getText == "\n")) { + builder.append(token.getText) + token = tokenStream.LT(index + builder.length) + } + builder.toString() + } + + private def hasTrailingComma(tokenStream: CommonTokenStream, index: Int): Boolean = { + var token = tokenStream.LT(index) + var i = 1 + while (token != null && (token.getType == DBTPreprocessorLexer.WS || token.getText == "\n")) { + token = tokenStream.LT(index + i) + i += 1 + } + token != null && token.getText == "," + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/TemplateElement.scala b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/TemplateElement.scala new file mode 100644 index 0000000000..f3ede790f0 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/TemplateElement.scala @@ -0,0 +1,35 @@ +package com.databricks.labs.remorph.preprocessors.jinja + +import com.databricks.labs.remorph.intermediate.Origin + +sealed trait TemplateElement { + def origin: Origin + def text: String + def regex: String + + def appendText(text: String): TemplateElement = { + this match { + case s: StatementElement => s.copy(text = this.text + text) + case e: ExpressionElement => e.copy(text = this.text + text) + case c: CommentElement => c.copy(text = this.text + text) + case l: LineElement => l.copy(text = this.text + text) + case ls: LineStatementElement => ls.copy(text = this.text + text) + case lc: LineCommentElement => lc.copy(text = this.text + text) + } 
+ } +} + +case class StatementElement(origin: Origin, text: String, regex: String) extends TemplateElement + +case class ExpressionElement(origin: Origin, text: String, regex: String) extends TemplateElement + +case class CommentElement(origin: Origin, text: String, regex: String) extends TemplateElement + +// TODO: We don't support # line elements yet +case class LineElement(origin: Origin, text: String, regex: String) extends TemplateElement + +// TODO: We don't support # line statements yet +case class LineStatementElement(origin: Origin, text: String, regex: String) extends TemplateElement + +// TODO: We don't support # line comment elements yet +case class LineCommentElement(origin: Origin, text: String, regex: String) extends TemplateElement diff --git a/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/TemplateManager.scala b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/TemplateManager.scala new file mode 100644 index 0000000000..4751706200 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/preprocessors/jinja/TemplateManager.scala @@ -0,0 +1,44 @@ +package com.databricks.labs.remorph.preprocessors.jinja + +import com.databricks.labs.remorph.utils.Sed + +case class TemplateManager(templates: Map[String, TemplateElement] = Map.empty) { + + def nextKey: String = { + f"_!Jinja${templates.size + 1}%04d" + } + + def add(key: String, template: TemplateElement): TemplateManager = { + copy(templates = templates + (key -> template)) + } + + def get(key: String): Option[TemplateElement] = { + templates.get(key) + } + + /** + * Replaces all the placeholders in the transpiler generated output with the actual + * template values, taking care to observe whether the original template was space separated + * or not - extra spaces generated by the transpiler will be removed when necessary. + * + * @param generated the output string we wish to rebuild with the original templates + * @return the rebuilt output string + */ + def rebuild(generated: String): String = { + + // Use the Sed utility to replace all the placeholders with the actual template values. We build + // a regexp and a replacement string (which is just the template.text) for each template, the regex + // is told to match and replace any preceding spaces if the template element was originally + // not preceded by white space, and will match and replace trailing whitespace if it was not originally + // followed by a whitespace. We use the resulting map to call the Sed utility to apply the + // replacements. Note that we do not replace newlines as we need to preserve them so we + // can preserve at least the rudimentary formatting of the generator output. 
+ val replacements: Seq[(String, String)] = templates.toSeq.map { case (key, template) => + (template.regex, template.text) + } ++ + // In certain situations, these replacements will orphan trailing whitespace, which we remove with this regex + Seq(("[ ]+\n", "\n")) + val replace = new Sed(replacements: _*) + replace(generated) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/queries/ExampleSource.scala b/core/src/main/scala/com/databricks/labs/remorph/queries/ExampleSource.scala new file mode 100644 index 0000000000..f1a830a760 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/queries/ExampleSource.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.remorph.queries + +import java.io.File +import java.nio.file.{Files, Path, Paths} +import scala.annotation.tailrec +import scala.collection.JavaConverters._ + +case class AcceptanceTest(testName: String, inputFile: File) + +trait ExampleSource { + def listTests: Seq[AcceptanceTest] +} + +object NestedFiles { + def projectRoot: String = checkProjectRoot(Paths.get(".").toAbsolutePath).toString + + @tailrec private def checkProjectRoot(current: Path): Path = { + // check if labs.yml exists in the current folder + if (Files.exists(current.resolve("labs.yml"))) { + current + } else if (current.getParent == null) { + throw new RuntimeException("Could not find project root") + } else { + checkProjectRoot(current.getParent) + } + } +} + +class NestedFiles(root: Path) extends ExampleSource { + def listTests: Seq[AcceptanceTest] = { + val files = + Files + .walk(root) + .iterator() + .asScala + .filter(f => Files.isRegularFile(f)) + .toSeq + + val sqlFiles = files.filter(_.getFileName.toString.endsWith(".sql")) + sqlFiles.sorted.map(p => AcceptanceTest(root.relativize(p).toString, p.toFile)) + } +} + +class TestFile(path: Path) extends ExampleSource { + def listTests: Seq[AcceptanceTest] = { + Array(AcceptanceTest(path.getFileName.toString, path.toFile)) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/queries/QueryExtractor.scala b/core/src/main/scala/com/databricks/labs/remorph/queries/QueryExtractor.scala new file mode 100644 index 0000000000..b10f66f615 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/queries/QueryExtractor.scala @@ -0,0 +1,77 @@ +package com.databricks.labs.remorph.queries + +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.{Parsing, PartialResult, TranspilerState} +import com.typesafe.scalalogging.LazyLogging + +import java.io.File +import scala.io.Source + +trait QueryExtractor { + def extractQuery(file: File): Option[ExampleQuery] +} + +case class ExampleQuery(query: String, expectedTranslation: Option[String], shouldFormat: Boolean = true) + +class WholeFileQueryExtractor extends QueryExtractor { + override def extractQuery(file: File): Option[ExampleQuery] = { + val fileContent = Source.fromFile(file) + val shouldFormat = !file.getName.contains("nofmt") + Some(ExampleQuery(fileContent.getLines().mkString("\n"), None, shouldFormat)) + } +} + +class CommentBasedQueryExtractor(inputDialect: String, targetDialect: String) extends QueryExtractor { + + private[this] val markerCommentPattern = "--\\s*(\\S+)\\s+sql:".r + + override def extractQuery(file: File): Option[ExampleQuery] = { + val source = Source.fromFile(file) + val shouldFormat = !file.getName.contains("nofmt") + val linesByDialect = source + .getLines() + .foldLeft((Option.empty[String], Map.empty[String, Seq[String]])) { + case ((currentDialect, 
dialectToLines), line) => + markerCommentPattern.findFirstMatchIn(line) match { + case Some(m) => (Some(m.group(1)), dialectToLines) + case None => + if (currentDialect.isDefined) { + ( + currentDialect, + dialectToLines.updated( + currentDialect.get, + dialectToLines.getOrElse(currentDialect.get, Seq()) :+ line)) + } else { + (currentDialect, dialectToLines) + } + } + } + ._2 + + linesByDialect.get(inputDialect).map { linesForInputDialect => + ExampleQuery( + linesForInputDialect.mkString("\n"), + linesByDialect.get(targetDialect).map(_.mkString("\n")), + shouldFormat) + } + } +} + +class ExampleDebugger(parser: PlanParser[_], prettyPrinter: Any => Unit, dialect: String) extends LazyLogging { + def debugExample(name: String): Unit = { + val extractor = new CommentBasedQueryExtractor(dialect, "databricks") + extractor.extractQuery(new File(name)) match { + case Some(ExampleQuery(query, _, _)) => + parser.parse.flatMap(parser.visit).run(TranspilerState(Parsing(query))) match { + case com.databricks.labs.remorph.KoResult(_, error) => + logger.error(s"Failed to parse query: $query ${error.msg}") + case PartialResult((_, plan), error) => + logger.warn(s"Errors occurred while parsing query: $query ${error.msg}") + prettyPrinter(plan) + case com.databricks.labs.remorph.OkResult((_, plan)) => + prettyPrinter(plan) + } + case None => throw new IllegalArgumentException(s"Example $name not found") + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/support/ConnectionFactory.scala b/core/src/main/scala/com/databricks/labs/remorph/support/ConnectionFactory.scala new file mode 100644 index 0000000000..91be999a46 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/support/ConnectionFactory.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.support + +import java.sql.Connection + +trait ConnectionFactory { + def newConnection(): Connection +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/support/SupportContext.scala b/core/src/main/scala/com/databricks/labs/remorph/support/SupportContext.scala new file mode 100644 index 0000000000..58fb5a8295 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/support/SupportContext.scala @@ -0,0 +1,11 @@ +package com.databricks.labs.remorph.support + +import com.databricks.labs.remorph.discovery.QueryHistoryProvider +import com.databricks.labs.remorph.parsers.PlanParser + +trait SupportContext { + def name: String + def planParser: PlanParser[_] + def connectionFactory: ConnectionFactory + def remoteQueryHistory: QueryHistoryProvider +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/support/snowflake/SnowflakeConnectionFactory.scala b/core/src/main/scala/com/databricks/labs/remorph/support/snowflake/SnowflakeConnectionFactory.scala new file mode 100644 index 0000000000..80d1145a0a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/support/snowflake/SnowflakeConnectionFactory.scala @@ -0,0 +1,42 @@ +package com.databricks.labs.remorph.support.snowflake + +import com.databricks.labs.remorph.coverage.runners.EnvGetter +import com.databricks.labs.remorph.support.ConnectionFactory +import net.snowflake.client.jdbc.internal.org.bouncycastle.jce.provider.BouncyCastleProvider + +import java.security.spec.PKCS8EncodedKeySpec +import java.security.{KeyFactory, PrivateKey, Security} +import java.sql.{Connection, DriverManager} +import java.util.{Base64, Properties} + +// TODO: This is not how we will handle connections in the future + +class SnowflakeConnectionFactory(env: 
EnvGetter) extends ConnectionFactory { + // scalastyle:off + Class.forName("net.snowflake.client.jdbc.SnowflakeDriver") + // scalastyle:on + + private[this] val url = env.get("TEST_SNOWFLAKE_JDBC") + private[this] val privateKeyPEM = env.get("TEST_SNOWFLAKE_PRIVATE_KEY") + + private def privateKey: PrivateKey = { + Security.addProvider(new BouncyCastleProvider()) + val keySpecPKCS8 = new PKCS8EncodedKeySpec( + Base64.getDecoder.decode( + privateKeyPEM + .split("\n") + .drop(1) + .dropRight(1) + .mkString)) + val kf = KeyFactory.getInstance("RSA") + kf.generatePrivate(keySpecPKCS8) + } + + private[this] val props = { + val p = new Properties() + p.put("privateKey", privateKey) + p + } + + def newConnection(): Connection = DriverManager.getConnection(url, props) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/support/snowflake/SnowflakeContext.scala b/core/src/main/scala/com/databricks/labs/remorph/support/snowflake/SnowflakeContext.scala new file mode 100644 index 0000000000..bf9637ea75 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/support/snowflake/SnowflakeContext.scala @@ -0,0 +1,16 @@ +package com.databricks.labs.remorph.support.snowflake + +import com.databricks.labs.remorph.coverage.runners.EnvGetter +import com.databricks.labs.remorph.support.{ConnectionFactory, SupportContext} +import com.databricks.labs.remorph.discovery.{QueryHistoryProvider, SnowflakeQueryHistory} +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.parsers.snowflake.SnowflakePlanParser + +class SnowflakeContext(private[this] val envGetter: EnvGetter) extends SupportContext { + override def name: String = "snowflake" + override def planParser: PlanParser[_] = new SnowflakePlanParser + override lazy val connectionFactory: ConnectionFactory = new SnowflakeConnectionFactory(envGetter) + override lazy val remoteQueryHistory: QueryHistoryProvider = { + new SnowflakeQueryHistory(connectionFactory.newConnection()) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/support/tsql/SqlServerConnectionFactory.scala b/core/src/main/scala/com/databricks/labs/remorph/support/tsql/SqlServerConnectionFactory.scala new file mode 100644 index 0000000000..7653e2394a --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/support/tsql/SqlServerConnectionFactory.scala @@ -0,0 +1,18 @@ +package com.databricks.labs.remorph.support.tsql + +import com.databricks.labs.remorph.coverage.runners.EnvGetter +import com.databricks.labs.remorph.support.ConnectionFactory + +import java.sql.{Connection, DriverManager} + +class SqlServerConnectionFactory(env: EnvGetter) extends ConnectionFactory { + // scalastyle:off + Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver") + // scalastyle:on + + private[this] val url = env.get("TEST_TSQL_JDBC") + private[this] val user = env.get("TEST_TSQL_USER") + private[this] val pass = env.get("TEST_TSQL_PASS") + + override def newConnection(): Connection = DriverManager.getConnection(url, user, pass) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/support/tsql/TSqlContext.scala b/core/src/main/scala/com/databricks/labs/remorph/support/tsql/TSqlContext.scala new file mode 100644 index 0000000000..24b61ebb27 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/support/tsql/TSqlContext.scala @@ -0,0 +1,16 @@ +package com.databricks.labs.remorph.support.tsql + +import com.databricks.labs.remorph.coverage.runners.EnvGetter +import 
com.databricks.labs.remorph.discovery.QueryHistoryProvider +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.parsers.tsql.TSqlPlanParser +import com.databricks.labs.remorph.support.{ConnectionFactory, SupportContext} + +class TSqlContext(private[this] val envGetter: EnvGetter) extends SupportContext { + override def name: String = "tsql" + override def planParser: PlanParser[_] = new TSqlPlanParser + override lazy val connectionFactory: ConnectionFactory = new SqlServerConnectionFactory(envGetter) + override def remoteQueryHistory: QueryHistoryProvider = { + throw new IllegalArgumentException("query history for SQLServer is not yet implemented") + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/PySparkGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/PySparkGenerator.scala new file mode 100644 index 0000000000..ccdc54dfb3 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/PySparkGenerator.scala @@ -0,0 +1,30 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.{KoResult, TransformationConstructors, WorkflowStage, intermediate => ir} +import com.databricks.labs.remorph.generators.py +import com.databricks.labs.remorph.generators.py.rules.{AndOrToBitwise, DotToFCol, ImportClasses, PySparkExpressions, PySparkStatements} +import org.json4s.{Formats, NoTypeHints} +import org.json4s.jackson.Serialization + +import scala.util.control.NonFatal + +class PySparkGenerator extends TransformationConstructors { + private[this] val exprGenerator = new py.ExpressionGenerator + private[this] val stmtGenerator = new py.StatementGenerator(exprGenerator) + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + private[this] val expressionRules = ir.Rules(new DotToFCol, new PySparkExpressions, new AndOrToBitwise) + private[this] val statementRules = ir.Rules(new PySparkStatements(expressionRules), new ImportClasses) + + def generate(optimizedLogicalPlan: ir.LogicalPlan): py.Python = { + try { + val withShims = PySparkStatements(optimizedLogicalPlan) + val statements = statementRules(withShims) + stmtGenerator.generate(statements) + } catch { + case NonFatal(e) => + lift(KoResult(WorkflowStage.GENERATE, ir.UncaughtException(e))) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspiler.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspiler.scala new file mode 100644 index 0000000000..051dd1a874 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspiler.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.parsers.snowflake.SnowflakePlanParser + +class SnowflakeToDatabricksTranspiler extends BaseTranspiler { + override val planParser = new SnowflakePlanParser +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspiler.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspiler.scala new file mode 100644 index 0000000000..96c1c33634 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspiler.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.transpilers +import com.databricks.labs.remorph.{Generating, Optimizing} +import com.databricks.labs.remorph.generators.GeneratorContext +import 
com.databricks.labs.remorph.generators.py.{LogicalPlanGenerator, Python} +import com.databricks.labs.remorph.intermediate.LogicalPlan + +class SnowflakeToPySparkTranspiler extends SnowflakeToDatabricksTranspiler { + val generator = new PySparkGenerator() + + override protected def generate(optimized: LogicalPlan): Python = + updatePhase { + case o: Optimizing => + Generating( + optimizedPlan = optimized, + currentNode = optimized, + ctx = GeneratorContext(new LogicalPlanGenerator), + previousPhase = Some(o)) + case _ => + Generating(optimizedPlan = optimized, currentNode = optimized, ctx = GeneratorContext(new LogicalPlanGenerator)) + }.flatMap(_ => generator.generate(optimized)) +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/Source.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/Source.scala new file mode 100644 index 0000000000..7dc52ed95d --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/Source.scala @@ -0,0 +1,33 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.Parsing + +import java.nio.file.{Files, Path, Paths} +import scala.io.Source.fromFile +import scala.collection.JavaConverters._ + +trait Source extends Iterator[Parsing] + +class DirectorySource(root: String, fileFilter: Option[Path => Boolean] = None) extends Source { + private[this] val files = + Files + .walk(Paths.get(root)) + .iterator() + .asScala + .filter(f => Files.isRegularFile(f) && fileFilter.forall(filter => filter(f))) + .toSeq + .iterator + + override def hasNext: Boolean = files.hasNext + + override def next(): Parsing = { + if (!hasNext) throw new NoSuchElementException("No more source entities") + val file = files.next() + val source = fromFile(file.toFile) + try { + Parsing(source.mkString, file.getFileName.toString) + } finally { + source.close() + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/SqlGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/SqlGenerator.scala new file mode 100644 index 0000000000..16f9c0c953 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/SqlGenerator.scala @@ -0,0 +1,31 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.generators.GeneratorContext +import com.databricks.labs.remorph.{KoResult, TransformationConstructors, WorkflowStage, intermediate => ir} +import com.databricks.labs.remorph.generators.sql.{ExpressionGenerator, LogicalPlanGenerator, OptionGenerator, SQL} +import org.json4s.jackson.Serialization +import org.json4s.{Formats, NoTypeHints} + +import scala.util.control.NonFatal + +// TODO: This should not be under transpilers but we have not refactored generation out of the transpiler yet +// and it may need changes before it is considered finished anyway, such as implementing a trait +class SqlGenerator extends TransformationConstructors { + + private[this] val exprGenerator = new ExpressionGenerator + private[this] val optionGenerator = new OptionGenerator(exprGenerator) + private[this] val generator = new LogicalPlanGenerator(exprGenerator, optionGenerator) + + def initialGeneratorContext: GeneratorContext = GeneratorContext(generator) + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + def generate(optimizedLogicalPlan: ir.LogicalPlan): SQL = { + try { + generator.generate(optimizedLogicalPlan) + } catch { + case NonFatal(e) => + lift(KoResult(WorkflowStage.GENERATE, ir.UncaughtException(e))) + } + 
} +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/TSqlToDatabricksTranspiler.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/TSqlToDatabricksTranspiler.scala new file mode 100644 index 0000000000..fa70c64b62 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/TSqlToDatabricksTranspiler.scala @@ -0,0 +1,7 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.parsers.tsql.TSqlPlanParser + +class TSqlToDatabricksTranspiler extends BaseTranspiler { + override val planParser = new TSqlPlanParser +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/TranspileException.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/TranspileException.scala new file mode 100644 index 0000000000..992da57abe --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/TranspileException.scala @@ -0,0 +1,5 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.intermediate.RemorphError + +case class TranspileException(err: RemorphError) extends RuntimeException(err.msg) diff --git a/core/src/main/scala/com/databricks/labs/remorph/transpilers/Transpiler.scala b/core/src/main/scala/com/databricks/labs/remorph/transpilers/Transpiler.scala new file mode 100644 index 0000000000..19c03be2be --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/transpilers/Transpiler.scala @@ -0,0 +1,87 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.preprocessors.jinja.JinjaProcessor +import com.databricks.labs.remorph.utils.Sed +import com.databricks.labs.remorph.{Generating, Optimizing, PreProcessing, Transformation, TransformationConstructors, intermediate => ir} +import com.github.vertical_blank.sqlformatter.SqlFormatter +import com.github.vertical_blank.sqlformatter.core.FormatConfig +import com.github.vertical_blank.sqlformatter.languages.Dialect +import org.antlr.v4.runtime.ParserRuleContext +import org.json4s.jackson.Serialization +import org.json4s.{Formats, NoTypeHints} + +trait Transpiler { + // TODO: get rid of the parameter, the initial phase should be provided by `run` or `runAndDiscardState` + def transpile(input: PreProcessing): Transformation[String] +} + +trait Formatter { + private[this] val sqlFormat = FormatConfig + .builder() + .indent(" ") + .uppercase(true) + .maxColumnLength(100) + .build() + + private[this] val formatter = SqlFormatter.of(Dialect.SparkSql) + + // sometimes we cannot just ignore legacy SQLGlot formatter and have to add hacks + private[this] val hacks = new Sed("EXISTS\\(" -> s"EXISTS (") + + def format(input: String): String = { + val pretty = formatter.format(input, sqlFormat) + hacks(pretty) + } +} + +abstract class BaseTranspiler extends Transpiler with Formatter with TransformationConstructors { + + protected val planParser: PlanParser[_] + private[this] val generator = new SqlGenerator + private[this] val jinjaProcessor = new JinjaProcessor + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + protected def pre: Transformation[Unit] = jinjaProcessor.pre + + protected def parse: Transformation[ParserRuleContext] = planParser.parse + + protected def visit(tree: ParserRuleContext): Transformation[ir.LogicalPlan] = planParser.visit(tree) + + // TODO: optimizer really should be its own thing and not part of PlanParser + // I have put it here for now until we 
discuss^h^h^h^h^h^h^hSerge dictates where it should go ;) + protected def optimize(logicalPlan: ir.LogicalPlan): Transformation[ir.LogicalPlan] = + planParser.optimize(logicalPlan) + + protected def generate(optimizedLogicalPlan: ir.LogicalPlan): Transformation[String] = { + updatePhase { + case o: Optimizing => + Generating( + optimizedPlan = optimizedLogicalPlan, + currentNode = optimizedLogicalPlan, + ctx = generator.initialGeneratorContext, + previousPhase = Some(o)) + case _ => + Generating( + optimizedPlan = optimizedLogicalPlan, + currentNode = optimizedLogicalPlan, + ctx = generator.initialGeneratorContext, + previousPhase = None) + }.flatMap { _ => + generator.generate(optimizedLogicalPlan) + } + } + + protected def post(input: String): Transformation[String] = jinjaProcessor.post(input) + + override def transpile(input: PreProcessing): Transformation[String] = { + setPhase(input) + .flatMap(_ => pre) + .flatMap(_ => parse) + .flatMap(visit) + .flatMap(optimize) + .flatMap(generate) + .flatMap(post) + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/utils/StandardInputPythonSubprocess.scala b/core/src/main/scala/com/databricks/labs/remorph/utils/StandardInputPythonSubprocess.scala new file mode 100644 index 0000000000..67b678eb66 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/utils/StandardInputPythonSubprocess.scala @@ -0,0 +1,78 @@ +package com.databricks.labs.remorph.utils + +import com.databricks.labs.remorph.intermediate.TranspileFailure +import com.databricks.labs.remorph.{KoResult, OkResult, Result, WorkflowStage} + +import java.io._ +import scala.annotation.tailrec +import scala.sys.process.{Process, ProcessIO} +import scala.util.control.NonFatal + +class StandardInputPythonSubprocess(passArgs: String) { + def apply(input: String): Result[String] = { + val process = Process(s"$getEffectivePythonBin -m $passArgs", None) + val output = new StringBuilder + val error = new StringBuilder + try { + val result = process.run(createIO(input, output, error)).exitValue() + if (result != 0) { + KoResult(WorkflowStage.FORMAT, new TranspileFailure(new IOException(error.toString))) + } else { + OkResult(output.toString) + } + } catch { + case e: IOException if e.getMessage.contains("Cannot run") => + val failure = new TranspileFailure(new IOException("Invalid $PYTHON_BIN environment variable")) + KoResult(WorkflowStage.FORMAT, failure) + case NonFatal(e) => + KoResult(WorkflowStage.FORMAT, new TranspileFailure(e)) + } + } + + private def createIO(input: String, output: StringBuilder, error: StringBuilder) = new ProcessIO( + stdin => { + stdin.write(input.getBytes) + stdin.close() + }, + stdout => { + val reader = new BufferedReader(new InputStreamReader(stdout)) + var line: String = reader.readLine() + while (line != null) { + output.append(s"$line\n") + line = reader.readLine() + } + reader.close() + }, + stderr => { + val reader = new BufferedReader(new InputStreamReader(stderr)) + var line: String = reader.readLine() + while (line != null) { + error.append(s"$line\n") + line = reader.readLine() + } + reader.close() + }) + + private def getEffectivePythonBin: String = { + sys.env.getOrElse( + "PYTHON_BIN", { + val projectRoot = findLabsYmlFolderIn(new File(System.getProperty("user.dir"))) + val venvPython = new File(projectRoot, ".venv/bin/python") + venvPython.getAbsolutePath + }) + } + + @tailrec private def findLabsYmlFolderIn(path: File): File = { + if (new File(path, "labs.yml").exists()) { + path + } else { + val parent = path.getParentFile + 
if (parent == null) { + throw new FileNotFoundException( + "labs.yml not found anywhere in the project hierarchy. " + + "Please set PYTHON_BIN environment variable to point to the correct Python binary.") + } + findLabsYmlFolderIn(parent) + } + } +} diff --git a/core/src/main/scala/com/databricks/labs/remorph/utils/Strings.scala b/core/src/main/scala/com/databricks/labs/remorph/utils/Strings.scala new file mode 100644 index 0000000000..386a932905 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/utils/Strings.scala @@ -0,0 +1,77 @@ +package com.databricks.labs.remorph.utils + +import java.io.{ByteArrayOutputStream, File, FileInputStream} +import java.nio.charset.Charset +import java.nio.charset.StandardCharsets.UTF_8 +import scala.util.matching.Regex + +/** + * This utility object is based on org.apache.spark.sql.catalyst.util + */ +object Strings { + def fileToString(file: File, encoding: Charset = UTF_8): String = { + val inStream = new FileInputStream(file) + val outStream = new ByteArrayOutputStream + try { + var reading = true + while (reading) { + inStream.read() match { + case -1 => reading = false + case c => outStream.write(c) + } + } + outStream.flush() + } finally { + inStream.close() + } + new String(outStream.toByteArray, encoding) + } + + def sideBySide(left: String, right: String): Seq[String] = { + sideBySide(left.split("\n"), right.split("\n")) + } + + def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { + val maxLeftSize = left.map(_.length).max + val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") + val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") + + leftPadded.zip(rightPadded).map { case (l, r) => + (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r + } + } + + /** Shorthand for calling truncatedString() without start or end strings. */ + def truncatedString[T](seq: Seq[T], sep: String, maxFields: Int): String = { + truncatedString(seq, "", sep, "", maxFields) + } + + /** + * Format a sequence with semantics similar to calling .mkString(). Any elements beyond maxNumToStringFields will be + * dropped and replaced by a "... N more fields" placeholder. + * + * @return + * the trimmed and formatted string. + */ + def truncatedString[T](seq: Seq[T], start: String, sep: String, end: String, maxFields: Int): String = { + if (seq.length > maxFields) { + val numFields = math.max(0, maxFields - 1) + seq.take(numFields).mkString(start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) + } else { + seq.mkString(start, sep, end) + } + } +} + +class Sed(rules: (String, String)*) { + private[this] val compiledRules: Seq[(Regex, String)] = rules.map { case (regex, replace) => + (regex.r, replace) + } + + def apply(src: String): String = { + compiledRules.foldLeft(src) { (currentSrc, rule) => + val (regex, replace) = rule + regex.replaceAllIn(currentSrc, replace) + } + } +} diff --git a/core/src/test/resources/toolchain/testsource/not_sql.md b/core/src/test/resources/toolchain/testsource/not_sql.md new file mode 100644 index 0000000000..509c052d0f --- /dev/null +++ b/core/src/test/resources/toolchain/testsource/not_sql.md @@ -0,0 +1,3 @@ +# Sample sql files for testing the toolchain interface + +This file is present so that it can be filtered out. 
diff --git a/core/src/test/resources/toolchain/testsource/test_1.sql b/core/src/test/resources/toolchain/testsource/test_1.sql new file mode 100644 index 0000000000..ebc4f3bdef --- /dev/null +++ b/core/src/test/resources/toolchain/testsource/test_1.sql @@ -0,0 +1,6 @@ +-- +-- The use of CTEs is generally the same in Databricks SQL as TSQL but there are some differences with +-- nesting CTE support. +-- +-- tsql sql: +WITH cte AS (SELECT * FROM t) SELECT * FROM cte diff --git a/core/src/test/resources/toolchain/testsource/test_2.sql b/core/src/test/resources/toolchain/testsource/test_2.sql new file mode 100644 index 0000000000..f151dafd74 --- /dev/null +++ b/core/src/test/resources/toolchain/testsource/test_2.sql @@ -0,0 +1,22 @@ +-- ## WITH cte SELECT +-- +-- The use of CTEs is generally the same in Databricks SQL as TSQL but there are some differences with +-- nesting CTE support. +-- +-- tsql sql: + +WITH cteTable1 (col1, col2, col3count) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table1 + ), + cteTable2 (colx, coly, colxcount) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table2 + ) +SELECT col2, col1, col3count, cteTable2.colx, cteTable2.coly, cteTable2.colxcount +FROM cteTable1 +GO diff --git a/core/src/test/resources/toolchain/testsource/test_3.sql b/core/src/test/resources/toolchain/testsource/test_3.sql new file mode 100644 index 0000000000..e73a8ceb05 --- /dev/null +++ b/core/src/test/resources/toolchain/testsource/test_3.sql @@ -0,0 +1,12 @@ +-- ## WITH XMLWORKSPACES +-- +-- Databricks SQL does not currently support XML workspaces, so for now, we cover the syntax without recommending +-- a translation. +-- +-- tsql sql: +WITH XMLNAMESPACES ('somereference' as namespace) +SELECT col1 as 'namespace:col1', + col2 as 'namespace:col2' +FROM Table1 +WHERE col2 = 'xyz' +FOR XML RAW ('namespace:namespace'), ELEMENTS; diff --git a/core/src/test/scala/com/databricks/labs/remorph/TransformationTest.scala b/core/src/test/scala/com/databricks/labs/remorph/TransformationTest.scala new file mode 100644 index 0000000000..2d69184859 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/TransformationTest.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.remorph + +import com.databricks.labs.remorph.intermediate.{RemorphErrors, UnexpectedNode} +import com.databricks.labs.remorph.preprocessors.jinja.TemplateManager +import com.databricks.labs.remorph.transpilers.Transpiler +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TransformationTest extends AnyWordSpec with Matchers with TransformationConstructors { + + val stubTranspiler = new Transpiler { + override def transpile(input: PreProcessing): Transformation[String] = + for { + _ <- setPhase(Parsing("foo")) + parsed <- lift(PartialResult("bar", UnexpectedNode("foo"))) + _ <- updatePhase { case p: Parsing => BuildingAst(null, Some(p)) } + ast <- lift(PartialResult("qux", UnexpectedNode(parsed))) + _ <- updatePhase { case b: BuildingAst => Optimizing(null, Some(b)) } + opt <- lift(PartialResult("zaz", UnexpectedNode(ast))) + _ <- updatePhase { case o: Optimizing => Generating(null, null, null, previousPhase = Some(o)) } + gen <- lift(PartialResult("nin", UnexpectedNode(opt))) + } yield gen + } + + "Transformation" should { + "collect errors in each phase" in { + val tm = new TemplateManager() + stubTranspiler.transpile(PreProcessing("foo")).run(TranspilerState(Init, tm)).map(_._1) shouldBe PartialResult( + TranspilerState( + Generating( + null, 
+ null, + null, + 0, + 0, + Some(Optimizing( + null, + Some(BuildingAst( + null, + Some(Parsing("foo", "-- test source --", None, List(UnexpectedNode("foo")))), + List(UnexpectedNode("bar")))), + List(UnexpectedNode("qux")))), + List()), + tm), + RemorphErrors(List(UnexpectedNode("foo"), UnexpectedNode("bar"), UnexpectedNode("qux"), UnexpectedNode("zaz")))) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/connections/EnvGetter.scala b/core/src/test/scala/com/databricks/labs/remorph/connections/EnvGetter.scala new file mode 100644 index 0000000000..3de1cd1317 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/connections/EnvGetter.scala @@ -0,0 +1,27 @@ +package com.databricks.labs.remorph.connections + +import com.databricks.labs.remorph.utils.Strings +import com.typesafe.scalalogging.LazyLogging +import org.scalatest.exceptions.TestCanceledException +import io.circe.jackson +import java.io.{File, FileNotFoundException} + +class EnvGetter extends LazyLogging { + private[this] val env = getDebugEnv + + def get(key: String): String = env.getOrElse(key, throw new TestCanceledException(s"not in env: $key", 3)) + + private def getDebugEnv: Map[String, String] = { + try { + val debugEnvFile = String.format("%s/.databricks/debug-env.json", System.getProperty("user.home")) + val contents = Strings.fileToString(new File(debugEnvFile)) + logger.debug(s"Found debug env file: $debugEnvFile") + + val raw = jackson.decode[Map[String, Map[String, String]]](contents).getOrElse(Map.empty) + val ucws = raw("ucws") + ucws + } catch { + case _: FileNotFoundException => sys.env + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/connections/SnowflakeConnectionFactory.scala b/core/src/test/scala/com/databricks/labs/remorph/connections/SnowflakeConnectionFactory.scala new file mode 100644 index 0000000000..a0e77bf3b9 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/connections/SnowflakeConnectionFactory.scala @@ -0,0 +1,38 @@ +package com.databricks.labs.remorph.connections + +import net.snowflake.client.jdbc.internal.org.bouncycastle.jce.provider.BouncyCastleProvider + +import java.security.spec.PKCS8EncodedKeySpec +import java.security.{KeyFactory, PrivateKey, Security} +import java.sql.{Connection, DriverManager} +import java.util.{Base64, Properties} + +class SnowflakeConnectionFactory(env: EnvGetter) { + // scalastyle:off + Class.forName("net.snowflake.client.jdbc.SnowflakeDriver") + // scalastyle:on + + private[this] val url = env.get("TEST_SNOWFLAKE_JDBC") + private[this] val privateKeyPEM = env.get("TEST_SNOWFLAKE_PRIVATE_KEY") + + private def privateKey: PrivateKey = { + Security.addProvider(new BouncyCastleProvider()) + val keySpecPKCS8 = new PKCS8EncodedKeySpec( + Base64.getDecoder.decode( + privateKeyPEM + .split("\n") + .drop(1) + .dropRight(1) + .mkString)) + val kf = KeyFactory.getInstance("RSA") + kf.generatePrivate(keySpecPKCS8) + } + + private[this] val props = { + val p = new Properties() + p.put("privateKey", privateKey) + p + } + + def newConnection(): Connection = DriverManager.getConnection(url, props) +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/connections/TSqlConnectionFactory.scala b/core/src/test/scala/com/databricks/labs/remorph/connections/TSqlConnectionFactory.scala new file mode 100644 index 0000000000..090c2230c7 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/connections/TSqlConnectionFactory.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.connections 
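+// Test-only factory for JDBC connections to the SQL Server instance; connection details come from
+// the TEST_TSQL_JDBC, TEST_TSQL_USER and TEST_TSQL_PASS entries resolved through EnvGetter.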
+ +import java.sql.{Connection, DriverManager} + +class TSqlConnectionFactory(env: EnvGetter) { + + Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver") + private[this] val jdbcUrl = env.get("TEST_TSQL_JDBC") + private[this] val username = env.get("TEST_TSQL_USER") + private[this] val password = env.get("TEST_TSQL_PASS") + + def newConnection(): Connection = DriverManager.getConnection(jdbcUrl, username, password) +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/coverage/AcceptanceSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/coverage/AcceptanceSpec.scala new file mode 100644 index 0000000000..363ced810c --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/coverage/AcceptanceSpec.scala @@ -0,0 +1,98 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.queries.{CommentBasedQueryExtractor, NestedFiles, TestFile} +import org.scalatest.Ignore +import org.scalatest.flatspec.AnyFlatSpec + +import java.nio.file.{Path, Paths} + +abstract class AcceptanceSpec(runner: AcceptanceTestRunner) extends AnyFlatSpec { + runner.foreachTest { test => + registerTest(test.testName) { + runner.runAcceptanceTest(test) match { + case None => pending + case Some(r) if r.isSuccess => succeed + case Some(r) if runner.shouldFailParse(test.testName) && r.failedParseOnly => succeed + case Some(report) => fail(report.errorMessage.getOrElse("")) + } + } + } +} + +object SnowflakeAcceptanceSuite { + + val rootPath: Path = Paths.get( + Option(System.getProperty("snowflake.test.resources.path")) + .getOrElse(s"${NestedFiles.projectRoot}/tests/resources/functional/snowflake")) +} + +class SnowflakeAcceptanceSuite + extends AcceptanceSpec( + new AcceptanceTestRunner( + AcceptanceTestConfig( + new NestedFiles(SnowflakeAcceptanceSuite.rootPath), + new CommentBasedQueryExtractor("snowflake", "databricks"), + new IsTranspiledFromSnowflakeQueryRunner, + ignoredTestNames = Set( + "aggregates/listagg/test_listagg_4.sql", + "cast/test_typecasts.sql", + "sqlglot-incorrect/test_uuid_string_2.sql", + "test_command/test_command_2.sql", + "test_command/test_command_3.sql", + "test_skip_unsupported_operations/test_skip_unsupported_operations_7.sql", + "test_skip_unsupported_operations/test_skip_unsupported_operations_9.sql", + "test_skip_unsupported_operations/test_skip_unsupported_operations_10.sql", + // TODO - Fix these tests as part of the lateral view + "arrays/test_array_construct_1.sql", + "arrays/test_array_construct_2.sql", + "functions/parse_json/test_parse_json_3.sql"), + shouldFailParse = Set( + "core_engine/test_invalid_syntax/syntax_error_1.sql", + "core_engine/test_invalid_syntax/syntax_error_2.sql", + "core_engine/test_invalid_syntax/syntax_error_3.sql")))) + +@Ignore // this is for debugging individual acceptance tests, simply uncomment to debug the test +class SnowflakeAcceptanceTest + extends AcceptanceSpec( + new AcceptanceTestRunner( + AcceptanceTestConfig( + new TestFile(Paths.get(SnowflakeAcceptanceSuite.rootPath.toString, "core_engine", "lca", "lca_homonym.sql")), + new CommentBasedQueryExtractor("snowflake", "databricks"), + new IsTranspiledFromSnowflakeQueryRunner))) + +class TSqlAcceptanceSuite + extends AcceptanceSpec( + new AcceptanceTestRunner( + AcceptanceTestConfig( + new NestedFiles(Paths.get(Option(System.getProperty("tsql.test.resources.path")) + .getOrElse(s"${NestedFiles.projectRoot}/tests/resources/functional/tsql"))), + new CommentBasedQueryExtractor("tsql", "databricks"), + new IsTranspiledFromTSqlQueryRunner, + 
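+        // Functional tests listed below are currently excluded from the TSql acceptance run: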
ignoredTestNames = Set( + "functions/test_aadbts_1.sql", + "functions/test_aalangid1.sql", + "functions/test_aalanguage_1.sql", + "functions/test_aalock_timeout_1.sql", + "functions/test_aamax_connections_1.sql", + "functions/test_aamax_precision_1.sql", + "functions/test_aaoptions_1.sql", + "functions/test_aaremserver_1.sql", + "functions/test_aaservername_1.sql", + "functions/test_aaservicename_1.sql", + "functions/test_aaspid_1.sql", + "functions/test_aatextsize_1.sql", + "functions/test_aaversion_1.sql", + "functions/test_approx_count_distinct.sql", + "functions/test_approx_percentile_cont_1.sql", + "functions/test_approx_percentile_disc_1.sql", + "functions/test_collationproperty_1.sql", + "functions/test_grouping_1.sql", + "functions/test_nestlevel_1.sql", + "functions/test_percent_rank_1.sql", + "functions/test_percentile_cont_1.sql", + "functions/test_percentile_disc_1.sql", + "select/test_cte_xml.sql"), + shouldFailParse = Set( + "core_engine/test_invalid_syntax/syntax_error_1.sql", + "core_engine/test_invalid_syntax/syntax_error_2.sql", + "core_engine/test_invalid_syntax/syntax_error_3.sql")))) diff --git a/core/src/test/scala/com/databricks/labs/remorph/coverage/AcceptanceTestRunner.scala b/core/src/test/scala/com/databricks/labs/remorph/coverage/AcceptanceTestRunner.scala new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/src/test/scala/com/databricks/labs/remorph/coverage/EstimationTest.scala b/core/src/test/scala/com/databricks/labs/remorph/coverage/EstimationTest.scala new file mode 100644 index 0000000000..6300cce743 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/coverage/EstimationTest.scala @@ -0,0 +1,87 @@ +package com.databricks.labs.remorph.coverage + +import com.databricks.labs.remorph.coverage.estimation.{EstimationAnalyzer, Estimator} +import com.databricks.labs.remorph.discovery.{ExecutedQuery, QueryHistory, QueryHistoryProvider} +import com.databricks.labs.remorph.parsers.snowflake.SnowflakePlanParser +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import java.sql.Timestamp +import java.time.Duration + +class EstimationTest extends AnyFlatSpec with Matchers with MockitoSugar { + + "Estimator" should "correctly process query history" in { + // Mock dependencies + val mockQueryHistoryProvider = mock[QueryHistoryProvider] + + // Real dependencies + val planParser = new SnowflakePlanParser + val analyzer = new EstimationAnalyzer + + // Mock query history + val mockHistory = QueryHistory( + Seq( + ExecutedQuery( + "id1", + "SELECT * FROM table1", + new Timestamp(1725032011000L), + Duration.ofMillis(300), + Some("user1")), + ExecutedQuery( + "id2", + "SELECT * FROM table2", + new Timestamp(1725032011000L), + Duration.ofMillis(300), + Some("user2")))) + when(mockQueryHistoryProvider.history()).thenReturn(mockHistory) + + // Create Estimator instance + val estimator = new Estimator(mockQueryHistoryProvider, planParser, analyzer) + + // Run the estimator + val report = estimator.run() + + // Verify the results + report.sampleSize should be(2) + report.uniqueSuccesses should be(2) + report.parseFailures should be(0) + report.transpileFailures should be(0) + } + + it should "handle parsing errors" in { + // Mock dependencies + val mockQueryHistoryProvider = mock[QueryHistoryProvider] + + // Real dependencies + val planParser = new SnowflakePlanParser + val analyzer = new EstimationAnalyzer + + // Mock query history + 
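+    // The deliberately unparseable statement below should be reported as a parse failure, not a transpile failure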
val mockHistory = QueryHistory( + Seq( + ExecutedQuery( + "id1", + "SOME GARBAGE STATEMENT", + new Timestamp(1725032011000L), + Duration.ofMillis(300), + Some("user1")))) + when(mockQueryHistoryProvider.history()).thenReturn(mockHistory) + + // Create Estimator instance + val estimator = new Estimator(mockQueryHistoryProvider, planParser, analyzer) + + // Run the estimator + val report = estimator.run() + + // Verify the results + report.sampleSize should be(1) + report.uniqueSuccesses should be(0) + report.parseFailures should be(1) + report.transpileFailures should be(0) + } + + // Add more test cases as needed +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/DatabricksSQLTest.scala b/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/DatabricksSQLTest.scala new file mode 100644 index 0000000000..486dce4424 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/DatabricksSQLTest.scala @@ -0,0 +1,12 @@ +package com.databricks.labs.remorph.coverage.runners + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class DatabricksSQLTest extends AnyWordSpec with Matchers { + "connectivity works" ignore { + val env = new EnvGetter() + val databricksSQL = new DatabricksSQL(env) + databricksSQL.spark should not be null + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/SnowflakeTest.scala b/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/SnowflakeTest.scala new file mode 100644 index 0000000000..3a0407f800 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/SnowflakeTest.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.coverage.runners + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeTest extends AnyWordSpec with Matchers { + "connectivity works" ignore { + val env = new EnvGetter() + val tsqlRunner = new SnowflakeRunner(env) + val res = tsqlRunner.queryToCSV("SHOW DATABASES") + res should include("REMORPH") + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/TSqlTest.scala b/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/TSqlTest.scala new file mode 100644 index 0000000000..2dd08d5916 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/coverage/runners/TSqlTest.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.remorph.coverage.runners + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TSqlTest extends AnyWordSpec with Matchers { + "connectivity works" ignore { + val env = new EnvGetter() + val tsqlRunner = new TSqlRunner(env) + val res = tsqlRunner.queryToCSV("SELECT name, database_id, create_date FROM sys.databases") + res should include("master") + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/discovery/AnonymizerTest.scala b/core/src/test/scala/com/databricks/labs/remorph/discovery/AnonymizerTest.scala new file mode 100644 index 0000000000..87531ff454 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/discovery/AnonymizerTest.scala @@ -0,0 +1,102 @@ +package com.databricks.labs.remorph.discovery + +import com.databricks.labs.remorph.parsers.snowflake.SnowflakePlanParser +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import java.sql.Timestamp +import java.time.Duration + +class AnonymizerTest extends AnyWordSpec with Matchers { 
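+  // Queries that differ only in literal values are expected to share a fingerprint; the
+  // "Fingerprints" block at the end of this suite verifies that behaviour.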
+ "Anonymizer" should { + "work in happy path" in { + val snow = new SnowflakePlanParser + val anonymizer = new Anonymizer(snow) + val query = ExecutedQuery( + "id", + "SELECT a, b FROM c WHERE d >= 300 AND e = 'foo'", + new Timestamp(1725032011000L), + Duration.ofMillis(300), + Some("foo")) + + anonymizer.fingerprint(query) should equal( + Fingerprint( + "id", + new Timestamp(1725032011000L), + "b0b00569bfa1fe3975afc221a4a24630a0ab4ec9", + Duration.ofMillis(300), + "foo", + WorkloadType.SQL_SERVING, + QueryType.DML)) + } + + "work in happy path with DDL" in { + val snow = new SnowflakePlanParser + val anonymizer = new Anonymizer(snow) + val query = + ExecutedQuery( + "id", + "CREATE TABLE foo (a INT, b STRING)", + new Timestamp(1725032011000L), + Duration.ofMillis(300), + Some("foo")) + + anonymizer.fingerprint(query) should equal( + Fingerprint( + "id", + new Timestamp(1725032011000L), + "828f7eb7d417310ab5c1673c96ec82c47f0231e4", + Duration.ofMillis(300), + "foo", + WorkloadType.ETL, + QueryType.DDL)) + } + + "trap an unknown query" in { + val snow = new SnowflakePlanParser + val anonymizer = new Anonymizer(snow) + val query = + ExecutedQuery("id", "THIS IS UNKNOWN;", new Timestamp(1725032011000L), Duration.ofMillis(300), Some("foo")) + + anonymizer.fingerprint(query) should equal( + Fingerprint( + "id", + new Timestamp(1725032011000L), + "290f4d72ca8faeb28873d8fff779ce93ed5cdb69", + Duration.ofMillis(300), + "foo", + WorkloadType.OTHER, + QueryType.OTHER)) + } + } + + "Fingerprints" should { + "work" in { + val snow = new SnowflakePlanParser + val anonymizer = new Anonymizer(snow) + val history = QueryHistory( + Seq( + ExecutedQuery( + "id", + "SELECT a, b FROM c WHERE d >= 300 AND e = 'foo'", + new Timestamp(1725032011000L), + Duration.ofMillis(300), + Some("foo")), + ExecutedQuery( + "id", + "SELECT a, b FROM c WHERE d >= 931 AND e = 'bar'", + new Timestamp(1725032011001L), + Duration.ofMillis(300), + Some("foo")), + ExecutedQuery( + "id", + "SELECT a, b FROM c WHERE d >= 234 AND e = 'something very different'", + new Timestamp(1725032011002L), + Duration.ofMillis(300), + Some("foo")))) + + val fingerprints = anonymizer.apply(history) + fingerprints.uniqueQueries should equal(1) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/discovery/FileQueryHistorySpec.scala b/core/src/test/scala/com/databricks/labs/remorph/discovery/FileQueryHistorySpec.scala new file mode 100644 index 0000000000..27465174d4 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/discovery/FileQueryHistorySpec.scala @@ -0,0 +1,39 @@ +package com.databricks.labs.remorph.discovery + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import java.nio.file.Files +import java.nio.file.StandardOpenOption._ +import scala.collection.JavaConverters._ + +class FileQueryHistorySpec extends AnyFlatSpec with Matchers { + + "FileQueryHistory" should "correctly extract queries from SQL files" in { + val tempDir = Files.createTempDirectory("test_sql_files") + + try { + val sqlFile = tempDir.resolve("test.sql") + val sqlContent = + """ + |SELECT * FROM table1; + |SELECT * FROM table2; + |""".stripMargin + Files.write(sqlFile, sqlContent.getBytes, CREATE, WRITE) + + val fileQueryHistory = new FileQueryHistory(tempDir) + + val queryHistory = fileQueryHistory.history() + + queryHistory.queries should have size 1 + queryHistory.queries.head.source shouldBe + """ + |SELECT * FROM table1; + |SELECT * FROM table2; + |""".stripMargin + + } finally { + // Clean up 
the temporary directory + Files.walk(tempDir).iterator().asScala.toSeq.reverse.foreach(Files.delete) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/discovery/SnowflakeTableDefinitionTest.scala b/core/src/test/scala/com/databricks/labs/remorph/discovery/SnowflakeTableDefinitionTest.scala new file mode 100644 index 0000000000..9a8e2dce7e --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/discovery/SnowflakeTableDefinitionTest.scala @@ -0,0 +1,76 @@ +package com.databricks.labs.remorph.discovery + +import java.sql.{Connection, ResultSet, Statement} + +import org.mockito.ArgumentMatchers._ +import org.mockito.Mockito._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeTableDefinitionTest extends AnyWordSpec with Matchers { + + "getTableDefinitions" should { + + "return table definitions for a all catalogs" in { + val mockConn = mock(classOf[Connection]) + val mockStmt = mock(classOf[Statement]) + val mockRs = mock(classOf[ResultSet]) + + when(mockConn.createStatement()).thenReturn(mockStmt) + when(mockStmt.executeQuery(anyString())).thenReturn(mockRs) + when(mockRs.next()).thenReturn(true, false) + when(mockRs.getString("TABLE_CATALOG")).thenReturn("CATALOG") + when(mockRs.getString("TABLE_SCHEMA")).thenReturn("SCHEMA") + when(mockRs.getString("TABLE_NAME")).thenReturn("TABLE") + when(mockRs.getString("DERIVED_SCHEMA")).thenReturn("col1§int§true§comment‡col2§string§false§comment") + when(mockRs.getString("LOCATION")).thenReturn(null) + when(mockRs.getString("FILE_FORMAT_NAME")).thenReturn(null) + when(mockRs.getString("VIEW_DEFINITION")).thenReturn(null) + when(mockRs.getInt("BYTES")).thenReturn(1024 * 1024 * 1024) + + // Mock behavior for getAllCatalogs + val mockCatalogResultSet = mock(classOf[ResultSet]) + when(mockStmt.executeQuery("SHOW DATABASES")).thenReturn(mockCatalogResultSet) + when(mockCatalogResultSet.next()).thenReturn(true, false) + when(mockCatalogResultSet.getString("name")).thenReturn("TEST_CATALOG") + + val snowflakeTableDefinitions = new SnowflakeTableDefinitions(mockConn) + val result = snowflakeTableDefinitions.getAllTableDefinitions + + result should have size 1 + result.head.columns should have size 2 + } + + "return all schemas for a valid catalog" in { + val mockConn = mock(classOf[Connection]) + val mockStmt = mock(classOf[Statement]) + val mockRs = mock(classOf[ResultSet]) + + when(mockConn.createStatement()).thenReturn(mockStmt) + when(mockStmt.executeQuery(anyString())).thenReturn(mockRs) + when(mockRs.next()).thenReturn(true, true, false) + when(mockRs.getString("name")).thenReturn("SCHEMA1", "SCHEMA2") + + val snowflakeTableDefinitions = new SnowflakeTableDefinitions(mockConn) + val result = snowflakeTableDefinitions.getAllSchemas("CATALOG") + + result should contain allOf ("SCHEMA1", "SCHEMA2") + } + + "return all catalogs" in { + val mockConn = mock(classOf[Connection]) + val mockStmt = mock(classOf[Statement]) + val mockRs = mock(classOf[ResultSet]) + + when(mockConn.createStatement()).thenReturn(mockStmt) + when(mockStmt.executeQuery(anyString())).thenReturn(mockRs) + when(mockRs.next()).thenReturn(true, true, false) + when(mockRs.getString("name")).thenReturn("CATALOG1", "CATALOG2") + + val snowflakeTableDefinitions = new SnowflakeTableDefinitions(mockConn) + val result = snowflakeTableDefinitions.getAllCatalogs + + result should contain allOf ("CATALOG1", "CATALOG2") + } + } +} diff --git 
a/core/src/test/scala/com/databricks/labs/remorph/discovery/TSqlTableDefinitionTest.scala b/core/src/test/scala/com/databricks/labs/remorph/discovery/TSqlTableDefinitionTest.scala new file mode 100644 index 0000000000..29cb353e50 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/discovery/TSqlTableDefinitionTest.scala @@ -0,0 +1,67 @@ +package com.databricks.labs.remorph.discovery + +import org.mockito.ArgumentMatchers.anyString +import org.mockito.Mockito._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar + +import java.sql.{Connection, ResultSet, Statement} + +class TSqlTableDefinitionTest extends AnyFlatSpec with Matchers with MockitoSugar { + + "getTableDefinitions" should "return correct table definitions for valid catalog name" in { + val conn = mock[Connection] + val stmt = mock[Statement] + val rs = mock[ResultSet] + val mockRs = mock[ResultSet] + when(conn.createStatement()).thenReturn(stmt) + when(stmt.executeQuery(anyString())).thenReturn(mockRs) + when(mockRs.next()).thenReturn(true, false) + when(mockRs.getString("TABLE_CATALOG")).thenReturn("CATALOG") + when(mockRs.getString("TABLE_SCHEMA")).thenReturn("SCHEMA") + when(mockRs.getString("TABLE_NAME")).thenReturn("TABLE") + when(mockRs.getString("DERIVED_SCHEMA")).thenReturn("col1§int§true§hi‡col2§string§false§hi") + when(mockRs.getString("LOCATION")).thenReturn(null) + when(mockRs.getString("FILE_FORMAT_NAME")).thenReturn(null) + when(mockRs.getString("VIEW_DEFINITION")).thenReturn(null) + when(mockRs.getInt("BYTES")).thenReturn(1024 * 1024 * 1024) + + when(stmt.executeQuery("SELECT NAME FROM sys.databases WHERE NAME != 'MASTER'")).thenReturn(rs) + when(rs.next()).thenReturn(true, false) + when(rs.getString("name")).thenReturn("catalog") + + val tSqlTableDefinitions = new TSqlTableDefinitions(conn) + val result = tSqlTableDefinitions.getAllTableDefinitions + result should not be empty + } + + "getAllSchemas" should "return all schemas for valid catalog name" in { + val conn = mock[Connection] + val stmt = mock[Statement] + val rs = mock[ResultSet] + when(conn.createStatement()).thenReturn(stmt) + when(stmt.executeQuery(anyString())).thenReturn(rs) + when(rs.next()).thenReturn(true, false) + when(rs.getString("SCHEMA_NAME")).thenReturn("schema") + + val tSqlTableDefinitions = new TSqlTableDefinitions(conn) + val result = tSqlTableDefinitions.getAllSchemas("catalog") + result should contain("schema") + } + + "getAllCatalogs" should "return all catalogs" in { + val conn = mock[Connection] + val stmt = mock[Statement] + val rs = mock[ResultSet] + when(conn.createStatement()).thenReturn(stmt) + when(stmt.executeQuery(anyString())).thenReturn(rs) + when(rs.next()).thenReturn(true, false) + when(rs.getString("name")).thenReturn("catalog") + + val tSqlTableDefinitions = new TSqlTableDefinitions(conn) + val result = tSqlTableDefinitions.getAllCatalogs + result should contain("catalog") + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/CodeInterpolatorSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/CodeInterpolatorSpec.scala new file mode 100644 index 0000000000..d701521ea7 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/generators/CodeInterpolatorSpec.scala @@ -0,0 +1,97 @@ +package com.databricks.labs.remorph.generators + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph._ +import org.scalatest.matchers.should.Matchers 
+import org.scalatest.wordspec.AnyWordSpec
+
+class CodeInterpolatorSpec extends AnyWordSpec with Matchers with TransformationConstructors {
+
+  "SQLInterpolator" should {
+
+    "interpolate the empty string" in {
+      code"".runAndDiscardState(TranspilerState()) shouldBe OkResult("")
+    }
+
+    "interpolate argument-less strings" in {
+      code"foo".runAndDiscardState(TranspilerState()) shouldBe OkResult("foo")
+    }
+
+    "interpolate argument-less strings with escape sequences" in {
+      code"\tbar\n".runAndDiscardState(TranspilerState()) shouldBe OkResult("\tbar\n")
+    }
+
+    "interpolate strings consisting only of a single String argument" in {
+      val arg = "FOO"
+      code"$arg".runAndDiscardState(TranspilerState()) shouldBe OkResult("FOO")
+    }
+
+    "interpolate strings consisting only of a single OkResult argument" in {
+      val arg = ok("FOO")
+      code"$arg".runAndDiscardState(TranspilerState()) shouldBe OkResult("FOO")
+    }
+
+    "interpolate strings consisting only of a single argument that is neither String nor OkResult" in {
+      val arg = 42
+      code"$arg".runAndDiscardState(TranspilerState()) shouldBe OkResult("42")
+    }
+
+    "interpolate strings with multiple arguments" in {
+      val arg1 = "foo"
+      val arg2 = ok("bar")
+      val arg3 = 42
+      code"arg1: $arg1, arg2: $arg2, arg3: $arg3".runAndDiscardState(TranspilerState()) shouldBe OkResult(
+        "arg1: foo, arg2: bar, arg3: 42")
+    }
+
+    "accumulate errors when some arguments are PartialResults" in {
+      val arg1 = lift(PartialResult("!!! error 1 !!!", UnexpectedNode(Noop.toString)))
+      val arg2 = "foo"
+      val arg3 = lift(PartialResult("!!! error 2 !!!", UnsupportedDataType(IntegerType.toString)))
+
+      code"SELECT $arg1 FROM $arg2 WHERE $arg3".runAndDiscardState(TranspilerState()) shouldBe PartialResult(
+        "SELECT !!! error 1 !!! FROM foo WHERE !!! error 2 !!!",
+        RemorphErrors(Seq(UnexpectedNode(Noop.toString), UnsupportedDataType(IntegerType.toString))))
+    }
+
+    "return a KoResult if any one of the arguments is a KoResult" in {
+      val arg1 = "foo"
+      val arg2 = lift(KoResult(WorkflowStage.GENERATE, UnexpectedNode(Noop.toString)))
+      val arg3 = 42
+      code"arg1: $arg1, arg2: $arg2, arg3: $arg3".runAndDiscardState(TranspilerState()) shouldBe KoResult(
+        WorkflowStage.GENERATE,
+        UnexpectedNode(Noop.toString))
+    }
+
+    "work nicely with mkCode" in {
+      val arg1 = "foo"
+      val arg2 = lift(PartialResult("!boom!", UnexpectedNode(Noop.toString)))
+      val arg3 = 42
+      Seq(code"arg1: $arg1", code"arg2: $arg2", code"arg3: $arg3")
+        .mkCode(", ")
+        .runAndDiscardState(TranspilerState()) shouldBe PartialResult(
+        "arg1: foo, arg2: !boom!, arg3: 42",
+        UnexpectedNode(Noop.toString))
+    }
+
+    "unfortunately, if evaluating one of the arguments throws an exception, " +
+      "it cannot be caught by the interpolator because arguments are evaluated eagerly" in {
+      def boom(): Unit = throw new RuntimeException("boom")
+      val three = ok("3")
+      val two = ok("2")
+      val one = ok("1")
+      val aftermath = ok("everything exploded")
+      a[RuntimeException] should be thrownBy code"$three...$two...$one...${boom()}...$aftermath"
+    }
+
+    "wrap 'normal' exceptions, such as invalid escapes, in a failure" in {
+      code"\D".runAndDiscardState(TranspilerState()) shouldBe an[KoResult]
+    }
+
+    "foo" in {
+      val foo = RemorphError
+      println(foo)
+      succeed
+    }
+  }
+}
diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/GeneratorTestCommon.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/GeneratorTestCommon.scala
new file mode 100644
index 0000000000..b7f84efca3
--- /dev/null
+++ b/core/src/test/scala/com/databricks/labs/remorph/generators/GeneratorTestCommon.scala
@@ -0,0 +1,24 @@
+package com.databricks.labs.remorph.generators
+
+import com.databricks.labs.remorph.{Generating, OkResult, TranspilerState, intermediate => ir}
+import org.scalatest.Assertion
+import org.scalatest.matchers.should.Matchers
+
+trait GeneratorTestCommon[T <: ir.TreeNode[T]] extends Matchers {
+
+  protected def generator: Generator[T, String]
+  protected def initialState(t: T): Generating
+
+  implicit class TestOps(t: T) {
+    def generates(expectedOutput: String): Assertion = {
+      generator.generate(t).runAndDiscardState(TranspilerState(initialState(t))) shouldBe OkResult(expectedOutput)
+    }
+
+    def doesNotTranspile: Assertion = {
+      generator
+        .generate(t)
+        .runAndDiscardState(TranspilerState(initialState(t)))
+        .isInstanceOf[OkResult[_]] shouldBe false
+    }
+  }
+}
diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/orchestration/FileSetGeneratorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/orchestration/FileSetGeneratorTest.scala
new file mode 100644
index 0000000000..dee70f3a08
--- /dev/null
+++ b/core/src/test/scala/com/databricks/labs/remorph/generators/orchestration/FileSetGeneratorTest.scala
@@ -0,0 +1,92 @@
+package com.databricks.labs.remorph.generators.orchestration
+
+import com.databricks.labs.remorph.discovery.{ExecutedQuery, QueryHistory}
+import com.databricks.labs.remorph.generators.orchestration.rules.history.RawMigration
+import com.databricks.labs.remorph.parsers.snowflake.SnowflakePlanParser
+import com.databricks.labs.remorph.transpilers.{PySparkGenerator, SqlGenerator}
+import com.databricks.labs.remorph.{TranspilerState, KoResult, OkResult, PartialResult}
+import org.scalatest.matchers.should.Matchers
+import
org.scalatest.wordspec.AnyWordSpec + +import java.sql.Timestamp +import java.time.Duration + +class FileSetGeneratorTest extends AnyWordSpec with Matchers { + private[this] val parser = new SnowflakePlanParser() + private[this] val sqlGen = new SqlGenerator + private[this] val pyGen = new PySparkGenerator + private[this] val queryHistory = QueryHistory( + Seq( + ExecutedQuery( + "query1", + s"""INSERT INTO schema1.table1 SELECT col1, col2 + |FROM schema2.table2 AS t2 INNER JOIN schema1.table3 AS t3 ON t2.id = t3.id + |""".stripMargin, + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(30), + Some("user1")), + ExecutedQuery( + "query3", + "SELECT * FROM schema2.table3 AS t3 JOIN schema2.table4 AS t4 ON t3.id = t4.id", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(60), + Some("user3")))) + + "FileSetGenerator" should { + "work" in { + new FileSetGenerator(parser, sqlGen, pyGen).generate(RawMigration(queryHistory)).run(TranspilerState()) match { + case OkResult((_, fileSet)) => + fileSet.getFile("notebooks/query1.sql").get shouldBe + s"""INSERT INTO + | schema1.table1 + |SELECT + | col1, + | col2 + |FROM + | schema2.table2 AS t2 + | INNER JOIN schema1.table3 AS t3 ON t2.id = t3.id;""".stripMargin + fileSet.getFile("notebooks/query3.sql").get shouldBe + s"""SELECT + | * + |FROM + | schema2.table3 AS t3 + | JOIN schema2.table4 AS t4 ON t3.id = t4.id;""".stripMargin + fileSet.getFile("databricks.yml").get shouldBe + s"""--- + |bundle: + | name: "remorphed" + |targets: + | dev: + | mode: "development" + | default: true + | prod: + | mode: "production" + |resources: + | jobs: + | migrated_via_remorph: + | name: "[$${bundle.target}] Migrated via Remorph" + | tags: + | generator: "remorph" + | tasks: + | - notebook_task: + | notebook_path: "notebooks/query1.sql" + | warehouse_id: "__DEFAULT_WAREHOUSE_ID__" + | task_key: "query1" + | - notebook_task: + | notebook_path: "notebooks/query3.sql" + | warehouse_id: "__DEFAULT_WAREHOUSE_ID__" + | task_key: "query3" + | schemas: + | schema2: + | catalog_name: "main" + | name: "schema2" + | schema1: + | catalog_name: "main" + | name: "schema1" + |""".stripMargin + case PartialResult(_, error) => fail(error.msg) + case KoResult(_, error) => fail(error.msg) + } + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/orchestration/rules/GenerateBundleFileTest.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/orchestration/rules/GenerateBundleFileTest.scala new file mode 100644 index 0000000000..45c057c7a5 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/generators/orchestration/rules/GenerateBundleFileTest.scala @@ -0,0 +1,47 @@ +package com.databricks.labs.remorph.generators.orchestration.rules + +import com.databricks.labs.remorph.generators.orchestration.rules.bundles.Schema +import com.databricks.labs.remorph.generators.orchestration.rules.converted.{CreatedFile, PythonNotebookTask} +import com.databricks.labs.remorph.generators.orchestration.rules.history.Migration +import com.databricks.labs.remorph.intermediate.workflows.jobs.JobSettings +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class GenerateBundleFileTest extends AnyWordSpec with Matchers { + private[this] val rule = new GenerateBundleFile + + "GenerateBundleFile" should { + "skip nulls" in { + val task = PythonNotebookTask(CreatedFile("notebooks/foo.py", "import foo")).toTask + val tree = rule.apply( + Migration(Seq(Schema("main", "foo"), Schema("main", 
"bar"), JobSettings("main workflow", Seq(task))))) + tree.find(_.isInstanceOf[CreatedFile]).get shouldBe CreatedFile( + "databricks.yml", + s"""--- + |bundle: + | name: "remorphed" + |targets: + | dev: + | mode: "development" + | default: true + | prod: + | mode: "production" + |resources: + | jobs: + | main_workflow: + | name: "[$${bundle.target}] main workflow" + | tasks: + | - notebook_task: + | notebook_path: "notebooks/foo.py" + | task_key: "foo" + | schemas: + | foo: + | catalog_name: "main" + | name: "foo" + | bar: + | catalog_name: "main" + | name: "bar" + |""".stripMargin) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/py/ExpressionGeneratorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/py/ExpressionGeneratorTest.scala new file mode 100644 index 0000000000..84369f5d47 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/generators/py/ExpressionGeneratorTest.scala @@ -0,0 +1,257 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.generators.{GeneratorContext, GeneratorTestCommon} +import com.databricks.labs.remorph.{Generating, intermediate => ir} +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class ExpressionGeneratorTest + extends AnyWordSpec + with GeneratorTestCommon[ir.Expression] + with MockitoSugar + with ir.IRHelpers { + + override protected val generator = new ExpressionGenerator + + private[this] val logical = new LogicalPlanGenerator + + override def initialState(expr: ir.Expression): Generating = + Generating(optimizedPlan = ir.Batch(Seq.empty), currentNode = expr, ctx = GeneratorContext(logical)) + + "name" in { + ir.Name("a") generates "a" + } + + "literals" should { + "generate string" in { + ir.Literal("a") generates "'a'" + } + "generate int" in { + ir.Literal(1) generates "1" + } + "generate float" in { + ir.DoubleLiteral(1.0) generates "1.0" + } + "generate boolean" in { + ir.Literal(true) generates "True" + ir.Literal(false) generates "False" + } + "generate null" in { + ir.Literal(null) generates "None" + } + } + + "predicates" should { + "a > b" in { + ir.GreaterThan(ir.Name("a"), ir.Name("b")) generates "a > b" + } + "a >= b" in { + ir.GreaterThanOrEqual(ir.Name("a"), ir.Name("b")) generates "a >= b" + } + "a < b" in { + ir.LessThan(ir.Name("a"), ir.Name("b")) generates "a < b" + } + "a <= b" in { + ir.LessThanOrEqual(ir.Name("a"), ir.Name("b")) generates "a <= b" + } + "a != b" in { + ir.NotEquals(ir.Name("a"), ir.Name("b")) generates "a != b" + } + "a == b" in { + ir.Equals(ir.Name("a"), ir.Name("b")) generates "a == b" + } + "~a" in { + ir.Not(ir.Name("a")) generates "~(a)" + } + "a or b" in { + ir.Or(ir.Name("a"), ir.Name("b")) generates "a or b" + } + "a and b" in { + ir.And(ir.Name("a"), ir.Name("b")) generates "a and b" + } + } + + "arithmetic" should { + "a + b" in { + ir.Add(ir.Name("a"), ir.Name("b")) generates "a + b" + } + "a - b" in { + ir.Subtract(ir.Name("a"), ir.Name("b")) generates "a - b" + } + "a * b" in { + ir.Multiply(ir.Name("a"), ir.Name("b")) generates "a * b" + } + "a / b" in { + ir.Divide(ir.Name("a"), ir.Name("b")) generates "a / b" + } + "a % b" in { + ir.Mod(ir.Name("a"), ir.Name("b")) generates "a % b" + } + "-a" in { + ir.UMinus(ir.Name("a")) generates "-a" + } + "+a" in { + ir.UPlus(ir.Name("a")) generates "+a" + } + } + + "python calls" should { + "f()" in { + Call(ir.Name("f"), Seq.empty, Seq.empty) generates "f()" + } + "f(a)" in { + Call(ir.Name("f"), 
Seq(ir.Name("a")), Seq.empty) generates "f(a)" + } + "f(a, b)" in { + Call(ir.Name("f"), Seq(ir.Name("a"), ir.Name("b")), Seq.empty) generates "f(a, b)" + } + "f(a, c=1)" in { + Call(ir.Name("f"), Seq(ir.Name("a")), Seq(Keyword(ir.Name("c"), ir.Literal(1)))) generates "f(a, c=1)" + } + "f(a, b, c=1, d=2)" in { + Call( + ir.Name("f"), + Seq(ir.Name("a")), + Seq(Keyword(ir.Name("c"), ir.Literal(1)), Keyword(ir.Name("d"), ir.Literal(2)))) generates "f(a, c=1, d=2)" + } + } + + "dicts" should { + "{}" in { + Dict(Seq.empty, Seq.empty) generates "{}" + } + "{a: b}" in { + Dict(Seq(ir.Name("a")), Seq(ir.Name("b"))) generates "{a: b}" + } + "{a: b, c: d}" in { + Dict(Seq(ir.Name("a"), ir.Name("c")), Seq(ir.Name("b"), ir.Name("d"))) generates "{a: b, c: d}" + } + } + + "slices" should { + ":" in { + Slice(None, None, None) generates ":" + } + "a:" in { + Slice(Some(ir.Name("a")), None, None) generates "a:" + } + ":b" in { + Slice(None, Some(ir.Name("b")), None) generates ":b" + } + "a:b" in { + Slice(Some(ir.Name("a")), Some(ir.Name("b")), None) generates "a:b" + } + "::c" in { + Slice(None, None, Some(ir.Name("c"))) generates "::c" + } + "a::c" in { + Slice(Some(ir.Name("a")), None, Some(ir.Name("c"))) generates "a::c" + } + ":b:c" in { + Slice(None, Some(ir.Name("b")), Some(ir.Name("c"))) generates ":b:c" + } + "a:b:c" in { + Slice(Some(ir.Name("a")), Some(ir.Name("b")), Some(ir.Name("c"))) generates "a:b:c" + } + } + + "if expr" should { + "a if b else c" in { + IfExp(test = ir.Name("a"), body = ir.Name("b"), orElse = ir.Name("c")) generates "b if a else c" + } + } + + "sets" should { + "set()" in { + Set(Seq.empty) generates "set()" + } + "set(a)" in { + Set(Seq(ir.Name("a"))) generates "{a}" + } + "set(a, b)" in { + Set(Seq(ir.Name("a"), ir.Name("b"))) generates "{a, b}" + } + } + + "lists" should { + "[]" in { + List(Seq.empty) generates "[]" + } + "[a]" in { + List(Seq(ir.Name("a"))) generates "[a]" + } + "[a, b]" in { + List(Seq(ir.Name("a"), ir.Name("b"))) generates "[a, b]" + } + } + + "subscripts" should { + "a[b]" in { + Subscript(ir.Name("a"), ir.Name("b")) generates "a[b]" + } + } + + "attributes" should { + "a.b" in { + Attribute(ir.Name("a"), ir.Name("b")) generates "a.b" + } + } + + "tuples" should { + "(a,)" in { + Tuple(Seq(ir.Name("a"))) generates "(a,)" + } + "(a, b)" in { + Tuple(Seq(ir.Name("a"), ir.Name("b"))) generates "(a, b,)" + } + } + + "lambdas" should { + "lambda a, b: c" in { + Lambda(Arguments(args = Seq(ir.Name("a"), ir.Name("b"))), ir.Name("c")) generates "lambda a, b: c" + } + "lambda *args: c" in { + Lambda(Arguments(vararg = Some(ir.Name("args"))), ir.Name("c")) generates "lambda *args: c" + } + "lambda **kw: c" in { + Lambda(Arguments(kwargs = Some(ir.Name("kw"))), ir.Name("c")) generates "lambda **kw: c" + } + "lambda pos1, pos2, *other, **keywords: c" in { + Lambda( + Arguments( + args = Seq(ir.Name("pos1"), ir.Name("pos2")), + vararg = Some(ir.Name("other")), + kwargs = Some(ir.Name("keywords"))), + ir.Name("c")) generates "lambda pos1, pos2, *other, **keywords: c" + } + } + + "comprehensions" should { + "[a*2 for a in b]" in { + ListComp( + ir.Multiply(ir.Name("a"), ir.Literal(2)), + Seq(Comprehension(ir.Name("a"), ir.Name("b"), Seq.empty))) generates "[a * 2 for a in b]" + } + "[a*2 for a in b if len(a) > 2]" in { + ListComp( + ir.Multiply(ir.Name("a"), ir.Literal(2)), + Seq( + Comprehension( + ir.Name("a"), + ir.Name("b"), + Seq( + ir.GreaterThan( + Call(ir.Name("len"), Seq(ir.Name("a"))), + ir.Literal(2)))))) generates "[a * 2 for a in b if len(a) > 
2]" + } + "{a for a in b}" in { + SetComp(ir.Name("a"), Seq(Comprehension(ir.Name("a"), ir.Name("b"), Seq.empty))) generates "{a for a in b}" + } + "{a:1 for a in b}" in { + DictComp( + ir.Name("a"), + ir.Literal(1), + Seq(Comprehension(ir.Name("a"), ir.Name("b"), Seq.empty))) generates "{a: 1 for a in b}" + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/py/StatementGeneratorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/py/StatementGeneratorTest.scala new file mode 100644 index 0000000000..2facd792a6 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/generators/py/StatementGeneratorTest.scala @@ -0,0 +1,238 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.generators.{GeneratorContext, GeneratorTestCommon} +import com.databricks.labs.remorph.{Generating, intermediate => ir} +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class StatementGeneratorTest + extends AnyWordSpec + with GeneratorTestCommon[Statement] + with MockitoSugar + with ir.IRHelpers { + + override protected val generator = new StatementGenerator(new ExpressionGenerator) + + override protected def initialState(statement: Statement) = + Generating( + optimizedPlan = ir.Batch(Seq.empty), + currentNode = statement, + ctx = GeneratorContext(new LogicalPlanGenerator)) + + "imports" should { + "import a, b as c" in { + Import(Seq(Alias(ir.Name("a")), Alias(ir.Name("b"), Some(ir.Name("c"))))) generates "import a, b as c" + } + "from foo.bar import a, b as c" in { + ImportFrom( + Some(ir.Name("foo.bar")), + Seq(Alias(ir.Name("a")), Alias(ir.Name("b"), Some(ir.Name("c"))))) generates "from foo.bar import a, b as c" + } + } + + "functions" should { + "@foo" in { + Decorator(ir.Name("foo")) generates "@foo" + } + "simple" in { + FunctionDef(ir.Name("foo"), Arguments(), Seq(Pass)) generates "def foo():\n pass\n" + } + "decorated" in { + FunctionDef( + ir.Name("foo"), + Arguments(), + Seq(Pass), + Seq(Decorator(ir.Name("bar")))) generates "@bar\ndef foo():\n pass\n" + } + } + + "classes" should { + "simple" in { + ClassDef(ir.Name("Foo"), Seq(), Seq(Pass)) generates "class Foo:\n pass\n" + } + "with bases" in { + ClassDef(ir.Name("Foo"), Seq(ir.Name("Bar")), Seq(Pass)) generates "class Foo(Bar):\n pass\n" + } + "with decorators" in { + ClassDef(ir.Name("Foo"), Seq(), Seq(Pass), Seq(Decorator(ir.Name("bar")))) generates "@bar\nclass Foo:\n pass\n" + } + "with decorated functions" in { + ClassDef( + ir.Name("Foo"), + Seq(), + Seq( + FunctionDef(ir.Name("foo"), Arguments(), Seq(Pass), Seq(Decorator(ir.Name("bar")))))) generates """class Foo: + | @bar + | def foo(): + | pass + | + |""".stripMargin + } + } + + "assign" should { + "a = b" in { + Assign(Seq(ir.Name("a")), ir.Name("b")) generates "a = b" + } + "a, b = c, d" in { + Assign(Seq(ir.Name("a"), ir.Name("b")), Tuple(Seq(ir.Name("c"), ir.Name("d")))) generates "a, b = (c, d,)" + } + } + + "for loop" should { + "for a in b: pass" in { + For(ir.Name("a"), ir.Name("b"), Seq(Pass)) generates "for a in b:\n pass\n" + } + "for a in b: ... else: ..." 
in { + For(ir.Name("a"), ir.Name("b"), Seq(Pass), Seq(Pass)) generates + """for a in b: + | pass + |else: + | pass + |""".stripMargin + } + "for-else in function propagates whitespace for else branch" in { + FunctionDef( + ir.Name("foo"), + Arguments(), + Seq(For(ir.Name("a"), ir.Name("b"), Seq(Pass), Seq(Pass)))) generates """def foo(): + | for a in b: + | pass + | else: + | pass + | + |""".stripMargin + } + } + + "while loop" should { + "while a: pass" in { + While(ir.Name("a"), Seq(Pass)) generates "while a:\n pass\n" + } + "while a: ... else: ..." in { + While(ir.Name("a"), Seq(Pass), Seq(Pass)) generates + """while a: + | pass + |else: + | pass + |""".stripMargin + } + } + + "if statement" should { + "if a: pass" in { + If(ir.Name("a"), Seq(Pass)) generates "if a:\n pass\n" + } + "if a: ... else: ..." in { + If(ir.Name("a"), Seq(Pass), Seq(Pass)) generates + """if a: + | pass + |else: + | pass + |""".stripMargin + } + } + + "with statement" should { + "with a(), b() as c: pass" in { + With( + Seq(Alias(Call(ir.Name("a"), Seq())), Alias(Call(ir.Name("b"), Seq()), Some(ir.Name("c")))), + Seq(Pass)) generates """with a(), b() as c: + | pass + |""".stripMargin + } + } + + "try-except" should { + "try: pass except: pass" in { + Try(Seq(Pass), Seq(Except(None, Seq(Pass)))) generates + """try: + | pass + |except: + | pass + |""".stripMargin + } + "try: pass except Foo: pass" in { + Try(Seq(Pass), Seq(Except(Some(Alias(ir.Name("Foo"))), Seq(Pass)))) generates + """try: + | pass + |except Foo: + | pass + |""".stripMargin + } + "try: pass except Foo as x: pass" in { + Try( + Seq(Pass), + Seq( + Except(Some(Alias(ir.Name("Foo"), Some(ir.Name("x")))), Seq(Pass)), + Except(Some(Alias(ir.Name("NotFound"))), Seq(Pass))), + Seq(Pass)) generates + """try: + | pass + |except Foo as x: + | pass + |except NotFound: + | pass + |else: + | pass + |""".stripMargin + } + "try: ... finally: ..." in { + Try(Seq(Pass), orFinally = Seq(Pass)) generates + """try: + | pass + |finally: + | pass + |""".stripMargin + } + "if True: try: ... finally: ..." 
in { + If(ir.Literal.True, Seq(Try(Seq(Pass), orFinally = Seq(Pass)))) generates + """if True: + | try: + | pass + | finally: + | pass + | + |""".stripMargin + } + } + + "raise" should { + "raise" in { + Raise() generates "raise" + } + "raise Foo" in { + Raise(Some(ir.Name("Foo"))) generates "raise Foo" + } + "raise Foo from Bar" in { + Raise(Some(ir.Name("Foo")), Some(ir.Name("Bar"))) generates "raise Foo from Bar" + } + } + + "assert" should { + "assert a" in { + Assert(ir.Name("a")) generates "assert a" + } + "assert a, b" in { + Assert(ir.Name("a"), Some(ir.Name("b"))) generates "assert a, b" + } + } + + "return" should { + "return" in { + Return(None) generates "return" + } + "return a" in { + Return(Some(ir.Name("a"))) generates "return a" + } + } + + "delete" should { + "delete a" in { + Delete(Seq(ir.Name("a"))) generates "del a" + } + "delete a, b" in { + Delete(Seq(ir.Name("a"), ir.Name("b"))) generates "del a, b" + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/sql/DataTypeGeneratorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/sql/DataTypeGeneratorTest.scala new file mode 100644 index 0000000000..2b7130b429 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/generators/sql/DataTypeGeneratorTest.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.{TranspilerState, OkResult, intermediate => ir} +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.{TableDrivenPropertyChecks, TableFor2} +import org.scalatest.wordspec.AnyWordSpec + +class DataTypeGeneratorTest extends AnyWordSpec with Matchers with TableDrivenPropertyChecks { + + val translations: TableFor2[ir.DataType, String] = Table( + ("datatype", "expected translation"), + (ir.NullType, "VOID"), + (ir.BooleanType, "BOOLEAN"), + (ir.BinaryType, "BINARY"), + (ir.ShortType, "SMALLINT"), + (ir.IntegerType, "INT"), + (ir.LongType, "BIGINT"), + (ir.FloatType, "FLOAT"), + (ir.DoubleType, "DOUBLE"), + (ir.StringType, "STRING"), + (ir.DateType, "DATE"), + (ir.TimestampType, "TIMESTAMP"), + (ir.TimestampNTZType, "TIMESTAMP_NTZ"), + (ir.DecimalType(None, None), "DECIMAL"), + (ir.DecimalType(Some(10), None), "DECIMAL(10)"), + (ir.DecimalType(Some(38), Some(6)), "DECIMAL(38, 6)"), + (ir.ArrayType(ir.StringType), "ARRAY"), + (ir.ArrayType(ir.ArrayType(ir.IntegerType)), "ARRAY>"), + (ir.MapType(ir.StringType, ir.DoubleType), "MAP"), + (ir.MapType(ir.StringType, ir.ArrayType(ir.DateType)), "MAP>"), + (ir.VarcharType(Some(10)), "VARCHAR(10)"), + ( + ir.StructExpr( + Seq( + ir.Alias(ir.Literal(1), ir.Id("a")), + ir.Alias(ir.Literal("two"), ir.Id("b")), + ir.Alias(ir.Literal(Seq(1, 2, 3)), ir.Id("c")))) + .dataType, + "STRUCT>")) + + "DataTypeGenerator" should { + "generate proper SQL data types" in { + forAll(translations) { (dt, expected) => + DataTypeGenerator.generateDataType(dt).runAndDiscardState(TranspilerState()) shouldBe OkResult(expected) + } + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/sql/ExpressionGeneratorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/sql/ExpressionGeneratorTest.scala new file mode 100644 index 0000000000..d965cb186b --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/generators/sql/ExpressionGeneratorTest.scala @@ -0,0 +1,1757 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.generators.{GeneratorContext, GeneratorTestCommon} +import 
com.databricks.labs.remorph.{Generating, intermediate => ir} +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +import java.sql.{Date, Timestamp} + +class ExpressionGeneratorTest + extends AnyWordSpec + with GeneratorTestCommon[ir.Expression] + with MockitoSugar + with ir.IRHelpers { + + override protected val generator = new ExpressionGenerator + + private[this] val optionGenerator = new OptionGenerator(generator) + + private[this] val logical = new LogicalPlanGenerator(generator, optionGenerator) + + override protected def initialState(expr: ir.Expression) = + Generating(optimizedPlan = ir.Batch(Seq.empty), currentNode = expr, ctx = GeneratorContext(logical)) + + "options" in { + ir.Options( + Map( + "KEEPFIXED" -> ir.Column(None, ir.Id("PLAN")), + "FAST" -> ir.Literal(666), + "MAX_GRANT_PERCENT" -> ir.Literal(30)), + Map(), + Map("FLAME" -> false, "QUICKLY" -> true), + List()) generates + """/* + | The following statement was originally given the following OPTIONS: + | + | Expression options: + | + | KEEPFIXED = PLAN + | FAST = 666 + | MAX_GRANT_PERCENT = 30 + | + | Boolean options: + | + | FLAME OFF + | QUICKLY ON + | + | + | */ + |""".stripMargin + } + + "struct" in { + ir.StructExpr( + Seq( + ir.Alias(ir.Literal(1), ir.Id("a")), + ir.Alias(ir.Literal("two"), ir.Id("b")), + ir.Alias(ir.Literal(Seq(1, 2, 3)), ir.Id("c")))) generates "STRUCT(1 AS a, 'two' AS b, ARRAY(1, 2, 3) AS c)" + } + + "columns" should { + "unresolved" in { + ir.UnresolvedAttribute("a") generates "a" + } + "a" in { + ir.Column(None, ir.Id("a")) generates "a" + } + "t.a" in { + ir.Column(Some(ir.ObjectReference(ir.Id("t"))), ir.Id("a")) generates "t.a" + } + "s.t.a" in { + ir.Column(Some(ir.ObjectReference(ir.Id("s.t"))), ir.Id("a")) generates "s.t.a" + } + + "$1" in { + ir.Column(None, ir.Position(1)).doesNotTranspile + } + } + + "arithmetic" should { + "-a" in { + ir.UMinus(ir.UnresolvedAttribute("a")) generates "-a" + } + "+a" in { + ir.UPlus(ir.UnresolvedAttribute("a")) generates "+a" + } + "a * b" in { + ir.Multiply(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a * b" + } + "a / b" in { + ir.Divide(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a / b" + } + "a % b" in { + ir.Mod(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a % b" + } + "a + b" in { + ir.Add(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a + b" + } + "a - b" in { + ir.Subtract(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a - b" + } + } + + "bitwise" should { + "a | b" in { + ir.BitwiseOr(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a | b" + } + "a & b" in { + ir.BitwiseAnd(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a & b" + } + "a ^ b" in { + ir.BitwiseXor(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a ^ b" + } + "~a" in { + ir.BitwiseNot(ir.UnresolvedAttribute("a")) generates "~a" + } + } + + "like" should { + "a LIKE 'b%'" in { + ir.Like(ir.UnresolvedAttribute("a"), ir.Literal("b%"), None) generates "a LIKE 'b%'" + } + "a LIKE b ESCAPE '/'" in { + ir.Like(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"), Some(ir.Literal('/'))) generates + "a LIKE b ESCAPE '/'" + } + "a ILIKE 'b%'" in { + ir.ILike(ir.UnresolvedAttribute("a"), ir.Literal("b%"), None) generates "a ILIKE 'b%'" + } + "a ILIKE b ESCAPE '/'" in { + ir.ILike(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"), Some(ir.Literal('/'))) generates + 
"a ILIKE b ESCAPE '/'" + } + "a LIKE ANY ('b%', 'c%')" in { + ir.LikeAny( + ir.UnresolvedAttribute("a"), + Seq(ir.Literal("b%"), ir.Literal("c%"))) generates "a LIKE ANY ('b%', 'c%')" + } + "a LIKE ALL ('b%', '%c')" in { + ir.LikeAll( + ir.UnresolvedAttribute("a"), + Seq(ir.Literal("b%"), ir.Literal("%c"))) generates "a LIKE ALL ('b%', '%c')" + } + "other ilike" in { + ir.ILikeAny( + ir.UnresolvedAttribute("a"), + Seq(ir.Literal("b%"), ir.Literal("c%"))) generates "a ILIKE ANY ('b%', 'c%')" + ir.ILikeAll( + ir.UnresolvedAttribute("a"), + Seq(ir.Literal("b%"), ir.Literal("%c"))) generates "a ILIKE ALL ('b%', '%c')" + } + } + + "predicates" should { + "a AND b" in { + ir.And(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a AND b" + } + "a OR b" in { + ir.Or(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a OR b" + } + "NOT (a)" in { + ir.Not(ir.UnresolvedAttribute("a")) generates "NOT (a)" + } + "a = b" in { + ir.Equals(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a = b" + } + "a != b" in { + ir.NotEquals(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a != b" + } + "a < b" in { + ir.LessThan(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a < b" + } + "a <= b" in { + ir.LessThanOrEqual(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a <= b" + } + "a > b" in { + ir.GreaterThan(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a > b" + } + "a >= b" in { + ir.GreaterThanOrEqual(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a >= b" + } + } + + "functions" should { + "ABS(a)" in { + ir.CallFunction("ABS", Seq(ir.UnresolvedAttribute("a"))) generates "ABS(a)" + } + + "ACOS(a)" in { + ir.CallFunction("ACOS", Seq(ir.UnresolvedAttribute("a"))) generates "ACOS(a)" + } + + "ACOSH(a)" in { + ir.CallFunction("ACOSH", Seq(ir.UnresolvedAttribute("a"))) generates "ACOSH(a)" + } + + "ADD_MONTHS(a, b)" in { + ir.CallFunction( + "ADD_MONTHS", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ADD_MONTHS(a, b)" + } + + "AGGREGATE(a, b, c, d)" in { + ir.CallFunction( + "AGGREGATE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"), + ir.UnresolvedAttribute("d"))) generates "AGGREGATE(a, b, c, d)" + } + + "ANY(a)" in { + ir.CallFunction("ANY", Seq(ir.UnresolvedAttribute("a"))) generates "ANY(a)" + } + + "APPROX_COUNT_DISTINCT(a, b)" in { + ir.CallFunction( + "APPROX_COUNT_DISTINCT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "APPROX_COUNT_DISTINCT(a, b)" + } + + "ARRAY(a, b, c, d)" in { + ir.CallFunction( + "ARRAY", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"), + ir.UnresolvedAttribute("d"))) generates "ARRAY(a, b, c, d)" + } + + "ARRAYS_OVERLAP(a, b)" in { + ir.CallFunction( + "ARRAYS_OVERLAP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAYS_OVERLAP(a, b)" + } + + "ARRAYS_ZIP(a, b)" in { + ir.CallFunction( + "ARRAYS_ZIP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAYS_ZIP(a, b)" + } + + "ARRAY_CONTAINS(a, b)" in { + ir.CallFunction( + "ARRAY_CONTAINS", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_CONTAINS(a, b)" + } + + "ARRAY_DISTINCT(a)" in { + ir.CallFunction("ARRAY_DISTINCT", Seq(ir.UnresolvedAttribute("a"))) generates "ARRAY_DISTINCT(a)" + } + + "ARRAY_EXCEPT(a, b)" in { + 
ir.CallFunction( + "ARRAY_EXCEPT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_EXCEPT(a, b)" + } + + "ARRAY_INTERSECT(a, b)" in { + ir.CallFunction( + "ARRAY_INTERSECT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_INTERSECT(a, b)" + } + + "ARRAY_JOIN(a, b, c)" in { + ir.CallFunction( + "ARRAY_JOIN", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "ARRAY_JOIN(a, b, c)" + } + + "ARRAY_MAX(a)" in { + ir.CallFunction("ARRAY_MAX", Seq(ir.UnresolvedAttribute("a"))) generates "ARRAY_MAX(a)" + } + + "ARRAY_MIN(a)" in { + ir.CallFunction("ARRAY_MIN", Seq(ir.UnresolvedAttribute("a"))) generates "ARRAY_MIN(a)" + } + + "ARRAY_POSITION(a, b)" in { + ir.CallFunction( + "ARRAY_POSITION", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_POSITION(a, b)" + } + + "ARRAY_REMOVE(a, b)" in { + ir.CallFunction( + "ARRAY_REMOVE", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_REMOVE(a, b)" + } + + "ARRAY_REMOVE([2, 3, 4::DOUBLE, 4, NULL], 4)" in { + ir.CallFunction( + "ARRAY_REMOVE", + Seq( + ir.ArrayExpr( + Seq(ir.Literal(2), ir.Literal(3), ir.Cast(ir.Literal(4), ir.DoubleType), ir.Literal(4), ir.Literal(null)), + ir.IntegerType), + ir.Literal(4))) generates "ARRAY_REMOVE(ARRAY(2, 3, CAST(4 AS DOUBLE), 4, NULL), 4)" + } + + "ARRAY_REPEAT(a, b)" in { + ir.CallFunction( + "ARRAY_REPEAT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_REPEAT(a, b)" + } + + "ARRAY_SORT(a, b)" in { + ir.CallFunction( + "ARRAY_SORT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_SORT(a, b)" + } + + "ARRAY_UNION(a, b)" in { + ir.CallFunction( + "ARRAY_UNION", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ARRAY_UNION(a, b)" + } + + "ASCII(a)" in { + ir.CallFunction("ASCII", Seq(ir.UnresolvedAttribute("a"))) generates "ASCII(a)" + } + + "ASIN(a)" in { + ir.CallFunction("ASIN", Seq(ir.UnresolvedAttribute("a"))) generates "ASIN(a)" + } + + "ASINH(a)" in { + ir.CallFunction("ASINH", Seq(ir.UnresolvedAttribute("a"))) generates "ASINH(a)" + } + + "ASSERT_TRUE(a, b)" in { + ir.CallFunction( + "ASSERT_TRUE", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ASSERT_TRUE(a, b)" + } + + "ATAN(a)" in { + ir.CallFunction("ATAN", Seq(ir.UnresolvedAttribute("a"))) generates "ATAN(a)" + } + + "ATAN2(a, b)" in { + ir.CallFunction("ATAN2", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ATAN2(a, b)" + } + + "ATANH(a)" in { + ir.CallFunction("ATANH", Seq(ir.UnresolvedAttribute("a"))) generates "ATANH(a)" + } + + "AVG(a)" in { + ir.CallFunction("AVG", Seq(ir.UnresolvedAttribute("a"))) generates "AVG(a)" + } + + "BASE64(a)" in { + ir.CallFunction("BASE64", Seq(ir.UnresolvedAttribute("a"))) generates "BASE64(a)" + } + + "BIN(a)" in { + ir.CallFunction("BIN", Seq(ir.UnresolvedAttribute("a"))) generates "BIN(a)" + } + + "BIT_AND(a)" in { + ir.CallFunction("BIT_AND", Seq(ir.UnresolvedAttribute("a"))) generates "BIT_AND(a)" + } + + "BIT_COUNT(a)" in { + ir.CallFunction("BIT_COUNT", Seq(ir.UnresolvedAttribute("a"))) generates "BIT_COUNT(a)" + } + + "BIT_LENGTH(a)" in { + ir.CallFunction("BIT_LENGTH", Seq(ir.UnresolvedAttribute("a"))) generates "BIT_LENGTH(a)" + } + + "BIT_OR(a)" in { + ir.CallFunction("BIT_OR", Seq(ir.UnresolvedAttribute("a"))) generates "BIT_OR(a)" + } + + "BIT_XOR(a)" in { + 
ir.CallFunction("BIT_XOR", Seq(ir.UnresolvedAttribute("a"))) generates "BIT_XOR(a)" + } + + "BOOL_AND(a)" in { + ir.CallFunction("BOOL_AND", Seq(ir.UnresolvedAttribute("a"))) generates "BOOL_AND(a)" + } + + "BROUND(a, b)" in { + ir.CallFunction("BROUND", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "BROUND(a, b)" + } + + "CBRT(a)" in { + ir.CallFunction("CBRT", Seq(ir.UnresolvedAttribute("a"))) generates "CBRT(a)" + } + + "CEIL(a)" in { + ir.CallFunction("CEIL", Seq(ir.UnresolvedAttribute("a"))) generates "CEIL(a)" + } + + "CHAR(a)" in { + ir.CallFunction("CHAR", Seq(ir.UnresolvedAttribute("a"))) generates "CHAR(a)" + } + + "COALESCE(a, b)" in { + ir.CallFunction( + "COALESCE", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "COALESCE(a, b)" + } + + "ARRAY_AGG(a)" in { + ir.CollectList(ir.UnresolvedAttribute("a")) generates "ARRAY_AGG(a)" + } + + "COLLECT_SET(a)" in { + ir.CallFunction("COLLECT_SET", Seq(ir.UnresolvedAttribute("a"))) generates "COLLECT_SET(a)" + } + + "CONCAT(a, b)" in { + ir.CallFunction("CONCAT", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "CONCAT(a, b)" + } + + "CONCAT_WS(a, b)" in { + ir.CallFunction( + "CONCAT_WS", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "CONCAT_WS(a, b)" + } + + "CONV(a, b, c)" in { + ir.CallFunction( + "CONV", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "CONV(a, b, c)" + } + + "CORR(a, b)" in { + ir.CallFunction("CORR", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "CORR(a, b)" + } + + "COS(a)" in { + ir.CallFunction("COS", Seq(ir.UnresolvedAttribute("a"))) generates "COS(a)" + } + + "COSH(a)" in { + ir.CallFunction("COSH", Seq(ir.UnresolvedAttribute("a"))) generates "COSH(a)" + } + + "COT(a)" in { + ir.CallFunction("COT", Seq(ir.UnresolvedAttribute("a"))) generates "COT(a)" + } + + "COUNT(a, b)" in { + ir.CallFunction("COUNT", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "COUNT(a, b)" + } + + "COUNT_IF(a)" in { + ir.CallFunction("COUNT_IF", Seq(ir.UnresolvedAttribute("a"))) generates "COUNT_IF(a)" + } + + "COUNT_MIN_SKETCH(a, b, c, d)" in { + ir.CallFunction( + "COUNT_MIN_SKETCH", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"), + ir.UnresolvedAttribute("d"))) generates "COUNT_MIN_SKETCH(a, b, c, d)" + } + + "COVAR_POP(a, b)" in { + ir.CallFunction( + "COVAR_POP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "COVAR_POP(a, b)" + } + + "COVAR_SAMP(a, b)" in { + ir.CallFunction( + "COVAR_SAMP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "COVAR_SAMP(a, b)" + } + + "CRC32(a)" in { + ir.CallFunction("CRC32", Seq(ir.UnresolvedAttribute("a"))) generates "CRC32(a)" + } + + "CUBE(a, b)" in { + ir.CallFunction("CUBE", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "CUBE(a, b)" + } + + "CUME_DIST()" in { + ir.CallFunction("CUME_DIST", Seq()) generates "CUME_DIST()" + } + + "CURRENT_CATALOG()" in { + ir.CallFunction("CURRENT_CATALOG", Seq()) generates "CURRENT_CATALOG()" + } + + "CURRENT_DATABASE()" in { + ir.CallFunction("CURRENT_DATABASE", Seq()) generates "CURRENT_DATABASE()" + } + + "CURRENT_DATE()" in { + ir.CallFunction("CURRENT_DATE", Seq()) generates "CURRENT_DATE()" + } + + "CURRENT_TIMESTAMP()" in { + ir.CallFunction("CURRENT_TIMESTAMP", Seq()) generates "CURRENT_TIMESTAMP()" + 
} + + "CURRENT_TIMEZONE()" in { + ir.CallFunction("CURRENT_TIMEZONE", Seq()) generates "CURRENT_TIMEZONE()" + } + + "DATEDIFF(a, b)" in { + ir.CallFunction( + "DATEDIFF", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DATEDIFF(a, b)" + } + + "DATE_ADD(a, b)" in { + ir.CallFunction( + "DATE_ADD", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DATE_ADD(a, b)" + } + + "DATE_FORMAT(a, b)" in { + ir.CallFunction( + "DATE_FORMAT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DATE_FORMAT(a, b)" + } + + "DATE_FROM_UNIX_DATE(a)" in { + ir.CallFunction("DATE_FROM_UNIX_DATE", Seq(ir.UnresolvedAttribute("a"))) generates "DATE_FROM_UNIX_DATE(a)" + } + + "DATE_PART(a, b)" in { + ir.CallFunction( + "DATE_PART", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DATE_PART(a, b)" + } + + "DATE_SUB(a, b)" in { + ir.CallFunction( + "DATE_SUB", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DATE_SUB(a, b)" + } + + "DATE_TRUNC(a, b)" in { + ir.CallFunction( + "DATE_TRUNC", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DATE_TRUNC(a, b)" + } + + "DAYOFMONTH(a)" in { + ir.CallFunction("DAYOFMONTH", Seq(ir.UnresolvedAttribute("a"))) generates "DAYOFMONTH(a)" + } + + "DAYOFWEEK(a)" in { + ir.CallFunction("DAYOFWEEK", Seq(ir.UnresolvedAttribute("a"))) generates "DAYOFWEEK(a)" + } + + "DAYOFYEAR(a)" in { + ir.CallFunction("DAYOFYEAR", Seq(ir.UnresolvedAttribute("a"))) generates "DAYOFYEAR(a)" + } + + "DECODE(a, b)" in { + ir.CallFunction("DECODE", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DECODE(a, b)" + } + + "DEGREES(a)" in { + ir.CallFunction("DEGREES", Seq(ir.UnresolvedAttribute("a"))) generates "DEGREES(a)" + } + + "DENSE_RANK(a, b)" in { + ir.CallFunction( + "DENSE_RANK", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DENSE_RANK(a, b)" + } + + "DIV(a, b)" in { + ir.CallFunction("DIV", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "DIV(a, b)" + } + + "E()" in { + ir.CallFunction("E", Seq()) generates "E()" + } + + "ELEMENT_AT(a, b)" in { + ir.CallFunction( + "ELEMENT_AT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ELEMENT_AT(a, b)" + } + + "ELT(a, b)" in { + ir.CallFunction("ELT", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ELT(a, b)" + } + + "ENCODE(a, b)" in { + ir.CallFunction("ENCODE", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ENCODE(a, b)" + } + + "EXISTS(a, b)" in { + ir.CallFunction("EXISTS", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "EXISTS(a, b)" + } + + "EXP(a)" in { + ir.CallFunction("EXP", Seq(ir.UnresolvedAttribute("a"))) generates "EXP(a)" + } + + "EXPLODE(a)" in { + ir.CallFunction("EXPLODE", Seq(ir.UnresolvedAttribute("a"))) generates "EXPLODE(a)" + } + + "EXPM1(a)" in { + ir.CallFunction("EXPM1", Seq(ir.UnresolvedAttribute("a"))) generates "EXPM1(a)" + } + + "EXTRACT(a FROM b)" in { + ir.Extract(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "EXTRACT(a FROM b)" + } + "FACTORIAL(a)" in { + ir.CallFunction("FACTORIAL", Seq(ir.UnresolvedAttribute("a"))) generates "FACTORIAL(a)" + } + + "FILTER(a, b)" in { + ir.CallFunction("FILTER", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FILTER(a, b)" + } + + "FIND_IN_SET(a, b)" in { + ir.CallFunction( + "FIND_IN_SET", + 
Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FIND_IN_SET(a, b)" + } + + "FIRST(a, b)" in { + ir.CallFunction("FIRST", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FIRST(a, b)" + } + + "FLATTEN(a)" in { + ir.CallFunction("FLATTEN", Seq(ir.UnresolvedAttribute("a"))) generates "FLATTEN(a)" + } + + "FLOOR(a)" in { + ir.CallFunction("FLOOR", Seq(ir.UnresolvedAttribute("a"))) generates "FLOOR(a)" + } + + "FORALL(a, b)" in { + ir.CallFunction("FORALL", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FORALL(a, b)" + } + + "FORMAT_NUMBER(a, b)" in { + ir.CallFunction( + "FORMAT_NUMBER", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FORMAT_NUMBER(a, b)" + } + + "FORMAT_STRING(a, b)" in { + ir.CallFunction( + "FORMAT_STRING", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FORMAT_STRING(a, b)" + } + + "FROM_CSV(a, b, c)" in { + ir.CallFunction( + "FROM_CSV", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "FROM_CSV(a, b, c)" + } + + "FROM_JSON(a, b, c)" in { + ir.CallFunction( + "FROM_JSON", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "FROM_JSON(a, b, c)" + } + + "FROM_UNIXTIME(a, b)" in { + ir.CallFunction( + "FROM_UNIXTIME", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FROM_UNIXTIME(a, b)" + } + + "FROM_UTC_TIMESTAMP(a, b)" in { + ir.CallFunction( + "FROM_UTC_TIMESTAMP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "FROM_UTC_TIMESTAMP(a, b)" + } + + "GET_JSON_OBJECT(a, b)" in { + ir.CallFunction( + "GET_JSON_OBJECT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "GET_JSON_OBJECT(a, b)" + } + + "GREATEST(a, b)" in { + ir.CallFunction( + "GREATEST", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "GREATEST(a, b)" + } + + "GROUPING(a)" in { + ir.CallFunction("GROUPING", Seq(ir.UnresolvedAttribute("a"))) generates "GROUPING(a)" + } + + "GROUPING_ID(a, b)" in { + ir.CallFunction( + "GROUPING_ID", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "GROUPING_ID(a, b)" + } + + "HASH(a, b)" in { + ir.CallFunction("HASH", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "HASH(a, b)" + } + + "HEX(a)" in { + ir.CallFunction("HEX", Seq(ir.UnresolvedAttribute("a"))) generates "HEX(a)" + } + + "HOUR(a)" in { + ir.CallFunction("HOUR", Seq(ir.UnresolvedAttribute("a"))) generates "HOUR(a)" + } + + "HYPOT(a, b)" in { + ir.CallFunction("HYPOT", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "HYPOT(a, b)" + } + + "IF(a, b, c)" in { + ir.CallFunction( + "IF", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "IF(a, b, c)" + } + + "IFNULL(a, b)" in { + ir.CallFunction("IFNULL", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "IFNULL(a, b)" + } + + "INITCAP(a)" in { + ir.CallFunction("INITCAP", Seq(ir.UnresolvedAttribute("a"))) generates "INITCAP(a)" + } + + "INLINE(a)" in { + ir.CallFunction("INLINE", Seq(ir.UnresolvedAttribute("a"))) generates "INLINE(a)" + } + + "INPUT_FILE_BLOCK_LENGTH()" in { + ir.CallFunction("INPUT_FILE_BLOCK_LENGTH", Seq()) generates "INPUT_FILE_BLOCK_LENGTH()" + } + + "INPUT_FILE_BLOCK_START()" in { + ir.CallFunction("INPUT_FILE_BLOCK_START", Seq()) 
generates "INPUT_FILE_BLOCK_START()" + } + + "INPUT_FILE_NAME()" in { + ir.CallFunction("INPUT_FILE_NAME", Seq()) generates "INPUT_FILE_NAME()" + } + + "INSTR(a, b)" in { + ir.CallFunction("INSTR", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "INSTR(a, b)" + } + + "ISNAN(a)" in { + ir.CallFunction("ISNAN", Seq(ir.UnresolvedAttribute("a"))) generates "ISNAN(a)" + } + + "ISNOTNULL(a)" in { + ir.CallFunction("ISNOTNULL", Seq(ir.UnresolvedAttribute("a"))) generates "ISNOTNULL(a)" + } + + "ISNULL(a)" in { + ir.CallFunction("ISNULL", Seq(ir.UnresolvedAttribute("a"))) generates "ISNULL(a)" + } + + "JAVA_METHOD(a, b)" in { + ir.CallFunction( + "JAVA_METHOD", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "JAVA_METHOD(a, b)" + } + + "JSON_ARRAY_LENGTH(a)" in { + ir.CallFunction("JSON_ARRAY_LENGTH", Seq(ir.UnresolvedAttribute("a"))) generates "JSON_ARRAY_LENGTH(a)" + } + + "JSON_OBJECT_KEYS(a)" in { + ir.CallFunction("JSON_OBJECT_KEYS", Seq(ir.UnresolvedAttribute("a"))) generates "JSON_OBJECT_KEYS(a)" + } + + "JSON_TUPLE(a, b)" in { + ir.CallFunction( + "JSON_TUPLE", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "JSON_TUPLE(a, b)" + } + + "KURTOSIS(a)" in { + ir.CallFunction("KURTOSIS", Seq(ir.UnresolvedAttribute("a"))) generates "KURTOSIS(a)" + } + + "LAG(a, b, c)" in { + ir.CallFunction( + "LAG", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "LAG(a, b, c)" + } + + "LAST(a, b)" in { + ir.CallFunction("LAST", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "LAST(a, b)" + } + + "LAST_DAY(a)" in { + ir.CallFunction("LAST_DAY", Seq(ir.UnresolvedAttribute("a"))) generates "LAST_DAY(a)" + } + + "LEAD(a, b, c)" in { + ir.CallFunction( + "LEAD", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "LEAD(a, b, c)" + } + + "LEAST(a, b)" in { + ir.CallFunction("LEAST", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "LEAST(a, b)" + } + + "LEFT(a, b)" in { + ir.CallFunction("LEFT", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "LEFT(a, b)" + } + + "LENGTH(a)" in { + ir.CallFunction("LENGTH", Seq(ir.UnresolvedAttribute("a"))) generates "LENGTH(a)" + } + + "LEVENSHTEIN(a, b)" in { + ir.CallFunction( + "LEVENSHTEIN", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "LEVENSHTEIN(a, b)" + } + + "LN(a)" in { + ir.CallFunction("LN", Seq(ir.UnresolvedAttribute("a"))) generates "LN(a)" + } + + "LOG(a, b)" in { + ir.CallFunction("LOG", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "LOG(a, b)" + } + + "LOG10(a)" in { + ir.CallFunction("LOG10", Seq(ir.UnresolvedAttribute("a"))) generates "LOG10(a)" + } + + "LOG1P(a)" in { + ir.CallFunction("LOG1P", Seq(ir.UnresolvedAttribute("a"))) generates "LOG1P(a)" + } + + "LOG2(a)" in { + ir.CallFunction("LOG2", Seq(ir.UnresolvedAttribute("a"))) generates "LOG2(a)" + } + + "LOWER(a)" in { + ir.CallFunction("LOWER", Seq(ir.UnresolvedAttribute("a"))) generates "LOWER(a)" + } + + "LTRIM(a, b)" in { + ir.CallFunction("LTRIM", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "LTRIM(a, b)" + } + + "MAKE_DATE(a, b, c)" in { + ir.CallFunction( + "MAKE_DATE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "MAKE_DATE(a, b, c)" + } + + "MAP_CONCAT(a, b)" in { + 
ir.CallFunction( + "MAP_CONCAT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "MAP_CONCAT(a, b)" + } + + "MAP_ENTRIES(a)" in { + ir.CallFunction("MAP_ENTRIES", Seq(ir.UnresolvedAttribute("a"))) generates "MAP_ENTRIES(a)" + } + + "MAP_FILTER(a, b)" in { + ir.CallFunction( + "MAP_FILTER", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "MAP_FILTER(a, b)" + } + + "MAP_FROM_ARRAYS(a, b)" in { + ir.CallFunction( + "MAP_FROM_ARRAYS", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "MAP_FROM_ARRAYS(a, b)" + } + + "MAP_FROM_ENTRIES(a)" in { + ir.CallFunction("MAP_FROM_ENTRIES", Seq(ir.UnresolvedAttribute("a"))) generates "MAP_FROM_ENTRIES(a)" + } + + "MAP_KEYS(a)" in { + ir.CallFunction("MAP_KEYS", Seq(ir.UnresolvedAttribute("a"))) generates "MAP_KEYS(a)" + } + + "MAP_VALUES(a)" in { + ir.CallFunction("MAP_VALUES", Seq(ir.UnresolvedAttribute("a"))) generates "MAP_VALUES(a)" + } + + "MAP_ZIP_WITH(a, b, c)" in { + ir.CallFunction( + "MAP_ZIP_WITH", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "MAP_ZIP_WITH(a, b, c)" + } + + "MAX(a)" in { + ir.CallFunction("MAX", Seq(ir.UnresolvedAttribute("a"))) generates "MAX(a)" + } + + "MAX_BY(a, b)" in { + ir.CallFunction("MAX_BY", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "MAX_BY(a, b)" + } + + "MD5(a)" in { + ir.CallFunction("MD5", Seq(ir.UnresolvedAttribute("a"))) generates "MD5(a)" + } + + "MIN(a)" in { + ir.CallFunction("MIN", Seq(ir.UnresolvedAttribute("a"))) generates "MIN(a)" + } + + "MINUTE(a)" in { + ir.CallFunction("MINUTE", Seq(ir.UnresolvedAttribute("a"))) generates "MINUTE(a)" + } + + "MIN_BY(a, b)" in { + ir.CallFunction("MIN_BY", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "MIN_BY(a, b)" + } + + "MOD(a, b)" in { + ir.CallFunction("MOD", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "MOD(a, b)" + } + + "MONOTONICALLY_INCREASING_ID()" in { + ir.CallFunction("MONOTONICALLY_INCREASING_ID", Seq()) generates "MONOTONICALLY_INCREASING_ID()" + } + + "MONTH(a)" in { + ir.CallFunction("MONTH", Seq(ir.UnresolvedAttribute("a"))) generates "MONTH(a)" + } + + "MONTHS_BETWEEN(a, b, c)" in { + ir.CallFunction( + "MONTHS_BETWEEN", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "MONTHS_BETWEEN(a, b, c)" + } + + "NAMED_STRUCT(a, b)" in { + ir.CallFunction( + "NAMED_STRUCT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "NAMED_STRUCT(a, b)" + } + + "NANVL(a, b)" in { + ir.CallFunction("NANVL", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "NANVL(a, b)" + } + + "NEGATIVE(a)" in { + ir.CallFunction("NEGATIVE", Seq(ir.UnresolvedAttribute("a"))) generates "NEGATIVE(a)" + } + + "NEXT_DAY(a, b)" in { + ir.CallFunction( + "NEXT_DAY", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "NEXT_DAY(a, b)" + } + + "NOW()" in { + ir.CallFunction("NOW", Seq()) generates "NOW()" + } + + "NTH_VALUE(a, b)" in { + ir.CallFunction( + "NTH_VALUE", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "NTH_VALUE(a, b)" + } + + "NTILE(a)" in { + ir.CallFunction("NTILE", Seq(ir.UnresolvedAttribute("a"))) generates "NTILE(a)" + } + + "NULLIF(a, b)" in { + ir.CallFunction("NULLIF", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "NULLIF(a, b)" + } + + "NVL(a, b)" in { + 
ir.CallFunction("NVL", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "NVL(a, b)" + } + + "NVL2(a, b, c)" in { + ir.CallFunction( + "NVL2", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "NVL2(a, b, c)" + } + "OCTET_LENGTH(a)" in { + ir.CallFunction("OCTET_LENGTH", Seq(ir.UnresolvedAttribute("a"))) generates "OCTET_LENGTH(a)" + } + + "OR(a, b)" in { + ir.CallFunction("OR", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "OR(a, b)" + } + + "OVERLAY(a, b, c, d)" in { + ir.CallFunction( + "OVERLAY", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"), + ir.UnresolvedAttribute("d"))) generates "OVERLAY(a, b, c, d)" + } + + "PARSE_URL(a, b)" in { + ir.CallFunction( + "PARSE_URL", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "PARSE_URL(a, b)" + } + + "PERCENTILE(a, b, c)" in { + ir.CallFunction( + "PERCENTILE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "PERCENTILE(a, b, c)" + } + + "PERCENTILE_APPROX(a, b, c)" in { + ir.CallFunction( + "PERCENTILE_APPROX", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "PERCENTILE_APPROX(a, b, c)" + } + + "PERCENT_RANK(a, b)" in { + ir.CallFunction( + "PERCENT_RANK", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "PERCENT_RANK(a, b)" + } + + "PI()" in { + ir.CallFunction("PI", Seq()) generates "PI()" + } + + "PMOD(a, b)" in { + ir.CallFunction("PMOD", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "PMOD(a, b)" + } + + "POSEXPLODE(a)" in { + ir.CallFunction("POSEXPLODE", Seq(ir.UnresolvedAttribute("a"))) generates "POSEXPLODE(a)" + } + + "POSITIVE(a)" in { + ir.CallFunction("POSITIVE", Seq(ir.UnresolvedAttribute("a"))) generates "POSITIVE(a)" + } + + "POWER(a, b)" in { + ir.Pow(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "POWER(a, b)" + } + + "QUARTER(a)" in { + ir.CallFunction("QUARTER", Seq(ir.UnresolvedAttribute("a"))) generates "QUARTER(a)" + } + + "RADIANS(a)" in { + ir.CallFunction("RADIANS", Seq(ir.UnresolvedAttribute("a"))) generates "RADIANS(a)" + } + + "RAISE_ERROR(a)" in { + ir.CallFunction("RAISE_ERROR", Seq(ir.UnresolvedAttribute("a"))) generates "RAISE_ERROR(a)" + } + + "RAND(a)" in { + ir.CallFunction("RAND", Seq(ir.UnresolvedAttribute("a"))) generates "RAND(a)" + } + + "RANDN(a)" in { + ir.CallFunction("RANDN", Seq(ir.UnresolvedAttribute("a"))) generates "RANDN(a)" + } + + "RANK(a, b)" in { + ir.CallFunction("RANK", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "RANK(a, b)" + } + + "REGEXP_EXTRACT(a, b, c)" in { + ir.CallFunction( + "REGEXP_EXTRACT", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "REGEXP_EXTRACT(a, b, c)" + } + + "REGEXP_EXTRACT_ALL(a, b, c)" in { + ir.CallFunction( + "REGEXP_EXTRACT_ALL", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "REGEXP_EXTRACT_ALL(a, b, c)" + } + + "REGEXP_REPLACE(a, b, c, d)" in { + ir.CallFunction( + "REGEXP_REPLACE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"), + ir.UnresolvedAttribute("d"))) generates "REGEXP_REPLACE(a, b, c, d)" + } + + "REPEAT(a, b)" in { + ir.CallFunction("REPEAT", 
Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "REPEAT(a, b)" + } + + "REPLACE(a, b, c)" in { + ir.CallFunction( + "REPLACE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "REPLACE(a, b, c)" + } + "REVERSE(a)" in { + ir.CallFunction("REVERSE", Seq(ir.UnresolvedAttribute("a"))) generates "REVERSE(a)" + } + + "RIGHT(a, b)" in { + ir.CallFunction("RIGHT", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "RIGHT(a, b)" + } + "RINT(a)" in { + ir.CallFunction("RINT", Seq(ir.UnresolvedAttribute("a"))) generates "RINT(a)" + } + + "a RLIKE b" in { + ir.RLike(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b")) generates "a RLIKE b" + } + + "ROLLUP(a, b)" in { + ir.CallFunction("ROLLUP", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ROLLUP(a, b)" + } + + "ROUND(a, b)" in { + ir.CallFunction("ROUND", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "ROUND(a, b)" + } + "ROW_NUMBER()" in { + ir.CallFunction("ROW_NUMBER", Seq()) generates "ROW_NUMBER()" + } + + "RPAD(a, b, c)" in { + ir.CallFunction( + "RPAD", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "RPAD(a, b, c)" + } + + "RTRIM(a, b)" in { + ir.CallFunction("RTRIM", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "RTRIM(a, b)" + } + + "SCHEMA_OF_CSV(a, b)" in { + ir.CallFunction( + "SCHEMA_OF_CSV", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "SCHEMA_OF_CSV(a, b)" + } + + "SCHEMA_OF_JSON(a, b)" in { + ir.CallFunction( + "SCHEMA_OF_JSON", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "SCHEMA_OF_JSON(a, b)" + } + "SECOND(a)" in { + ir.CallFunction("SECOND", Seq(ir.UnresolvedAttribute("a"))) generates "SECOND(a)" + } + + "SENTENCES(a, b, c)" in { + ir.CallFunction( + "SENTENCES", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "SENTENCES(a, b, c)" + } + + "SEQUENCE(a, b, c)" in { + ir.CallFunction( + "SEQUENCE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "SEQUENCE(a, b, c)" + } + "SHA(a)" in { + ir.CallFunction("SHA", Seq(ir.UnresolvedAttribute("a"))) generates "SHA(a)" + } + + "SHA2(a, b)" in { + ir.CallFunction("SHA2", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "SHA2(a, b)" + } + + "SHIFTLEFT(a, b)" in { + ir.CallFunction( + "SHIFTLEFT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "SHIFTLEFT(a, b)" + } + + "SHIFTRIGHT(a, b)" in { + ir.CallFunction( + "SHIFTRIGHT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "SHIFTRIGHT(a, b)" + } + + "SHIFTRIGHTUNSIGNED(a, b)" in { + ir.CallFunction( + "SHIFTRIGHTUNSIGNED", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "SHIFTRIGHTUNSIGNED(a, b)" + } + "SHUFFLE(a)" in { + ir.CallFunction("SHUFFLE", Seq(ir.UnresolvedAttribute("a"))) generates "SHUFFLE(a)" + } + "SIGN(a)" in { + ir.CallFunction("SIGN", Seq(ir.UnresolvedAttribute("a"))) generates "SIGN(a)" + } + "SIN(a)" in { + ir.CallFunction("SIN", Seq(ir.UnresolvedAttribute("a"))) generates "SIN(a)" + } + "SINH(a)" in { + ir.CallFunction("SINH", Seq(ir.UnresolvedAttribute("a"))) generates "SINH(a)" + } + "SIZE(a)" in { + ir.CallFunction("SIZE", Seq(ir.UnresolvedAttribute("a"))) generates "SIZE(a)" + } + 
"SKEWNESS(a)" in { + ir.CallFunction("SKEWNESS", Seq(ir.UnresolvedAttribute("a"))) generates "SKEWNESS(a)" + } + + "SLICE(a, b, c)" in { + ir.CallFunction( + "SLICE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "SLICE(a, b, c)" + } + + "SORT_ARRAY(a, b)" in { + ir.CallFunction( + "SORT_ARRAY", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "SORT_ARRAY(a, b)" + } + "SOUNDEX(a)" in { + ir.CallFunction("SOUNDEX", Seq(ir.UnresolvedAttribute("a"))) generates "SOUNDEX(a)" + } + "SPACE(a)" in { + ir.CallFunction("SPACE", Seq(ir.UnresolvedAttribute("a"))) generates "SPACE(a)" + } + "SPARK_PARTITION_ID()" in { + ir.CallFunction("SPARK_PARTITION_ID", Seq()) generates "SPARK_PARTITION_ID()" + } + + "SPLIT(a, b, c)" in { + ir.CallFunction( + "SPLIT", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "SPLIT(a, b, c)" + } + + "SPLIT_PART(a, b, c)" in { + ir.CallFunction( + "SPLIT_PART", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "SPLIT_PART(a, b, c)" + } + + "SQRT(a)" in { + ir.CallFunction("SQRT", Seq(ir.UnresolvedAttribute("a"))) generates "SQRT(a)" + } + + "STACK(a, b)" in { + ir.CallFunction("STACK", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "STACK(a, b)" + } + "STD(a)" in { + ir.CallFunction("STD", Seq(ir.UnresolvedAttribute("a"))) generates "STD(a)" + } + "STDDEV(a)" in { + ir.CallFunction("STDDEV", Seq(ir.UnresolvedAttribute("a"))) generates "STDDEV(a)" + } + "STDDEV_POP(a)" in { + ir.CallFunction("STDDEV_POP", Seq(ir.UnresolvedAttribute("a"))) generates "STDDEV_POP(a)" + } + + "STR_TO_MAP(a, b, c)" in { + ir.CallFunction( + "STR_TO_MAP", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "STR_TO_MAP(a, b, c)" + } + + "SUBSTR(a, b, c)" in { + ir.CallFunction( + "SUBSTR", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "SUBSTR(a, b, c)" + } + + "SUBSTRING_INDEX(a, b, c)" in { + ir.CallFunction( + "SUBSTRING_INDEX", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "SUBSTRING_INDEX(a, b, c)" + } + "SUM(a)" in { + ir.CallFunction("SUM", Seq(ir.UnresolvedAttribute("a"))) generates "SUM(a)" + } + "TAN(a)" in { + ir.CallFunction("TAN", Seq(ir.UnresolvedAttribute("a"))) generates "TAN(a)" + } + "TANH(a)" in { + ir.CallFunction("TANH", Seq(ir.UnresolvedAttribute("a"))) generates "TANH(a)" + } + "TIMESTAMP_MICROS(a)" in { + ir.CallFunction("TIMESTAMP_MICROS", Seq(ir.UnresolvedAttribute("a"))) generates "TIMESTAMP_MICROS(a)" + } + "TIMESTAMP_MILLIS(a)" in { + ir.CallFunction("TIMESTAMP_MILLIS", Seq(ir.UnresolvedAttribute("a"))) generates "TIMESTAMP_MILLIS(a)" + } + "TIMESTAMP_SECONDS(a)" in { + ir.CallFunction("TIMESTAMP_SECONDS", Seq(ir.UnresolvedAttribute("a"))) generates "TIMESTAMP_SECONDS(a)" + } + + "TO_CSV(a, b)" in { + ir.CallFunction("TO_CSV", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TO_CSV(a, b)" + } + + "TO_DATE(a, b)" in { + ir.CallFunction( + "TO_DATE", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TO_DATE(a, b)" + } + + "TO_JSON(a, b)" in { + ir.CallFunction( + "TO_JSON", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TO_JSON(a, b)" + } + + "TO_NUMBER(a, b)" in { + 
ir.CallFunction( + "TO_NUMBER", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TO_NUMBER(a, b)" + } + + "TO_TIMESTAMP(a, b)" in { + ir.CallFunction( + "TO_TIMESTAMP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TO_TIMESTAMP(a, b)" + } + + "TO_UNIX_TIMESTAMP(a, b)" in { + ir.CallFunction( + "TO_UNIX_TIMESTAMP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TO_UNIX_TIMESTAMP(a, b)" + } + + "TO_UTC_TIMESTAMP(a, b)" in { + ir.CallFunction( + "TO_UTC_TIMESTAMP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TO_UTC_TIMESTAMP(a, b)" + } + + "TRANSFORM(a, b)" in { + ir.CallFunction( + "TRANSFORM", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TRANSFORM(a, b)" + } + + "TRANSFORM_KEYS(a, b)" in { + ir.CallFunction( + "TRANSFORM_KEYS", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TRANSFORM_KEYS(a, b)" + } + + "TRANSFORM_VALUES(a, b)" in { + ir.CallFunction( + "TRANSFORM_VALUES", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TRANSFORM_VALUES(a, b)" + } + + "TRANSLATE(a, b, c)" in { + ir.CallFunction( + "TRANSLATE", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "TRANSLATE(a, b, c)" + } + + "TRIM(a, b)" in { + ir.CallFunction("TRIM", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TRIM(a, b)" + } + + "TRUNC(a, b)" in { + ir.CallFunction("TRUNC", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "TRUNC(a, b)" + } + "TYPEOF(a)" in { + ir.CallFunction("TYPEOF", Seq(ir.UnresolvedAttribute("a"))) generates "TYPEOF(a)" + } + "UCASE(a)" in { + ir.CallFunction("UCASE", Seq(ir.UnresolvedAttribute("a"))) generates "UCASE(a)" + } + "UNBASE64(a)" in { + ir.CallFunction("UNBASE64", Seq(ir.UnresolvedAttribute("a"))) generates "UNBASE64(a)" + } + "UNHEX(a)" in { + ir.CallFunction("UNHEX", Seq(ir.UnresolvedAttribute("a"))) generates "UNHEX(a)" + } + "UNIX_DATE(a)" in { + ir.CallFunction("UNIX_DATE", Seq(ir.UnresolvedAttribute("a"))) generates "UNIX_DATE(a)" + } + "UNIX_MICROS(a)" in { + ir.CallFunction("UNIX_MICROS", Seq(ir.UnresolvedAttribute("a"))) generates "UNIX_MICROS(a)" + } + "UNIX_MILLIS(a)" in { + ir.CallFunction("UNIX_MILLIS", Seq(ir.UnresolvedAttribute("a"))) generates "UNIX_MILLIS(a)" + } + "UNIX_SECONDS(a)" in { + ir.CallFunction("UNIX_SECONDS", Seq(ir.UnresolvedAttribute("a"))) generates "UNIX_SECONDS(a)" + } + + "UNIX_TIMESTAMP(a, b)" in { + ir.CallFunction( + "UNIX_TIMESTAMP", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "UNIX_TIMESTAMP(a, b)" + } + "UUID()" in { + ir.CallFunction("UUID", Seq()) generates "UUID()" + } + "VAR_POP(a)" in { + ir.CallFunction("VAR_POP", Seq(ir.UnresolvedAttribute("a"))) generates "VAR_POP(a)" + } + "VAR_SAMP(a)" in { + ir.CallFunction("VAR_SAMP", Seq(ir.UnresolvedAttribute("a"))) generates "VAR_SAMP(a)" + } + "VERSION()" in { + ir.CallFunction("VERSION", Seq()) generates "VERSION()" + } + "WEEKDAY(a)" in { + ir.CallFunction("WEEKDAY", Seq(ir.UnresolvedAttribute("a"))) generates "WEEKDAY(a)" + } + "WEEKOFYEAR(a)" in { + ir.CallFunction("WEEKOFYEAR", Seq(ir.UnresolvedAttribute("a"))) generates "WEEKOFYEAR(a)" + } + + "WIDTH_BUCKET(a, b, c, d)" in { + ir.CallFunction( + "WIDTH_BUCKET", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"), + ir.UnresolvedAttribute("d"))) generates 
"WIDTH_BUCKET(a, b, c, d)" + } + + "XPATH(a, b)" in { + ir.CallFunction("XPATH", Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH(a, b)" + } + + "XPATH_BOOLEAN(a, b)" in { + ir.CallFunction( + "XPATH_BOOLEAN", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH_BOOLEAN(a, b)" + } + + "XPATH_DOUBLE(a, b)" in { + ir.CallFunction( + "XPATH_DOUBLE", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH_DOUBLE(a, b)" + } + + "XPATH_FLOAT(a, b)" in { + ir.CallFunction( + "XPATH_FLOAT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH_FLOAT(a, b)" + } + + "XPATH_INT(a, b)" in { + ir.CallFunction( + "XPATH_INT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH_INT(a, b)" + } + + "XPATH_LONG(a, b)" in { + ir.CallFunction( + "XPATH_LONG", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH_LONG(a, b)" + } + + "XPATH_SHORT(a, b)" in { + ir.CallFunction( + "XPATH_SHORT", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH_SHORT(a, b)" + } + + "XPATH_STRING(a, b)" in { + ir.CallFunction( + "XPATH_STRING", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XPATH_STRING(a, b)" + } + + "XXHASH64(a, b)" in { + ir.CallFunction( + "XXHASH64", + Seq(ir.UnresolvedAttribute("a"), ir.UnresolvedAttribute("b"))) generates "XXHASH64(a, b)" + } + "YEAR(a)" in { + ir.CallFunction("YEAR", Seq(ir.UnresolvedAttribute("a"))) generates "YEAR(a)" + } + + "ZIP_WITH(a, b, c)" in { + ir.CallFunction( + "ZIP_WITH", + Seq( + ir.UnresolvedAttribute("a"), + ir.UnresolvedAttribute("b"), + ir.UnresolvedAttribute("c"))) generates "ZIP_WITH(a, b, c)" + } + } + + "literal" should { + "NULL" in { + ir.Literal(null) generates "NULL" + } + + "binary array" in { + ir.Literal(Array[Byte](0x01, 0x02, 0x03)) generates "010203" + } + + "booleans" in { + ir.Literal(true) generates "true" + ir.Literal(false) generates "false" + } + + "short" in { + ir.Literal(123, ir.ShortType) generates "123" + } + + "123" in { + ir.Literal(123) generates "123" + } + + "long" in { + ir.Literal(123, ir.LongType) generates "123" + } + + "float" in { + ir.Literal(123.4f, ir.FloatType) generates "123.4" + } + + "double" in { + ir.Literal(123.4, ir.DoubleType) generates "123.4" + } + + "decimal" in { + ir.Literal(BigDecimal("123.4")) generates "123.4" + } + + "string" in { + ir.Literal("abc") generates "'abc'" + } + + "string containing single quotes" in { + ir.Literal("a'b'c") generates "'a\\'b\\'c'" + } + + "CAST('2024-07-23 18:03:21' AS TIMESTAMP)" in { + ir.Literal(new Timestamp(1721757801L)) generates "CAST('2024-07-23 18:03:21' AS TIMESTAMP)" + } + + "CAST('2024-07-23' AS DATE)" in { + ir.Literal(new Date(1721757801000L)) generates "CAST('2024-07-23' AS DATE)" + } + + "ARRAY('abc', 'def')" in { + ir.Literal(Seq("abc", "def")) generates "ARRAY('abc', 'def')" + } + + "MAP('foo', 'bar', 'baz', 'qux')" in { + ir.Literal(Map("foo" -> "bar", "baz" -> "qux")) generates "MAP('foo', 'bar', 'baz', 'qux')" + } + } + + "distinct" should { + "be generated" in { + ir.Distinct(ir.Id("c1")) generates "DISTINCT c1" + } + } + + "star" should { + "be generated" in { + ir.Star(None) generates "*" + ir.Star(Some(ir.ObjectReference(ir.Id("t1")))) generates "t1.*" + ir.Star( + Some( + ir.ObjectReference(ir.Id("schema1"), ir.Id("table 1", caseSensitive = true)))) generates "schema1.`table 1`.*" + } + } + + "case...when...else" should { + "be 
generated" in { + ir.Case(None, Seq(ir.WhenBranch(ir.Literal(true), ir.Literal(42))), None) generates "CASE WHEN true THEN 42 END" + + ir.Case( + Some(ir.Id("c1")), + Seq(ir.WhenBranch(ir.Literal(true), ir.Literal(42))), + None) generates "CASE c1 WHEN true THEN 42 END" + + ir.Case( + Some(ir.Id("c1")), + Seq(ir.WhenBranch(ir.Literal(true), ir.Literal(42))), + Some(ir.Literal(0))) generates "CASE c1 WHEN true THEN 42 ELSE 0 END" + + ir.Case( + Some(ir.Id("c1")), + Seq(ir.WhenBranch(ir.Literal("Answer"), ir.Literal(42)), ir.WhenBranch(ir.Literal("Year"), ir.Literal(2024))), + Some(ir.Literal(0))) generates "CASE c1 WHEN 'Answer' THEN 42 WHEN 'Year' THEN 2024 ELSE 0 END" + + } + } + + "IN" should { + "be generated" in { + ir.In( + ir.Id("c1"), + Seq( + ir.ScalarSubquery( + ir.Project(namedTable("table1"), Seq(ir.Id("column1")))))) generates "c1 IN (SELECT column1 FROM table1)" + + ir.In(ir.Id("c1"), Seq(ir.Literal(1), ir.Literal(2), ir.Literal(3))) generates "c1 IN (1, 2, 3)" + } + } + + "window functions" should { + "be generated" in { + ir.Window( + ir.RowNumber(), + Seq(ir.Id("a")), + Seq(ir.SortOrder(ir.Id("b"), ir.Ascending, ir.NullsFirst)), + Some( + ir.WindowFrame( + ir.RowsFrame, + ir.CurrentRow, + ir.NoBoundary))) generates "ROW_NUMBER() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST ROWS CURRENT ROW)" + + ir.Window( + ir.RowNumber(), + Seq(ir.Id("a")), + Seq(ir.SortOrder(ir.Id("b"), ir.Ascending, ir.NullsFirst)), + Some( + ir.WindowFrame( + ir.RangeFrame, + ir.CurrentRow, + ir.FollowingN( + ir.Literal(42))))) generates "ROW_NUMBER() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 42 FOLLOWING)" + + } + } + + "JSON_ACCESS" should { + + "handle valid identifier" in { + ir.JsonAccess(ir.Id("c1"), ir.Literal("a")) generates "c1['a']" + } + + "handle invalid identifier" in { + ir.JsonAccess(ir.Id("c1"), ir.Id("1", caseSensitive = true)) generates "c1[\"1\"]" + } + + "handle integer literal" in { + ir.JsonAccess(ir.Id("c1"), ir.Literal(123)) generates "c1[123]" + } + + "handle string literal" in { + ir.JsonAccess(ir.Id("c1"), ir.Literal("abc")) generates "c1['abc']" + } + + "handle dot expression" in { + ir.JsonAccess(ir.Dot(ir.Id("c1"), ir.Id("c2")), ir.Literal("a")) generates "c1.c2['a']" + } + + "be generated" in { + ir.JsonAccess(ir.JsonAccess(ir.Id("c1"), ir.Literal("a")), ir.Literal("b")) generates "c1['a']['b']" + + ir.JsonAccess( + ir.JsonAccess( + ir.JsonAccess(ir.Dot(ir.Id("demo"), ir.Id("level_key")), ir.Literal("level_1_key")), + ir.Literal("level_2_key")), + ir.Id("1")) generates "demo.level_key['level_1_key']['level_2_key'][\"1\"]" + } + } + + "error node in expression tree" should { + "generate inline error message comment for a and bad text" in { + ir.And(ir.Id("a"), ir.UnresolvedExpression(ruleText = "bad text", message = "some error message")) generates + """a AND /* The following issues were detected: + | + | some error message + | bad text + | */""".stripMargin + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGeneratorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGeneratorTest.scala new file mode 100644 index 0000000000..c5dfeb4bd5 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGeneratorTest.scala @@ -0,0 +1,533 @@ +package com.databricks.labs.remorph.generators.sql + +import com.databricks.labs.remorph.generators.{GeneratorContext, GeneratorTestCommon} +import com.databricks.labs.remorph.{Generating, 
intermediate => ir} +import org.scalatest.wordspec.AnyWordSpec + +class LogicalPlanGeneratorTest extends AnyWordSpec with GeneratorTestCommon[ir.LogicalPlan] with ir.IRHelpers { + + protected val expressionGenerator = new ExpressionGenerator() + protected val optionGenerator = new OptionGenerator(expressionGenerator) + override protected val generator = new LogicalPlanGenerator(expressionGenerator, optionGenerator) + + override protected def initialState(plan: ir.LogicalPlan) = + Generating(optimizedPlan = plan, currentNode = plan, ctx = GeneratorContext(generator)) + "Project" should { + "transpile to SELECT" in { + ir.Project(namedTable("t1"), Seq(ir.Id("c1"))) generates "SELECT c1 FROM t1" + ir.Project(namedTable("t1"), Seq(ir.Star(None))) generates "SELECT * FROM t1" + ir.Project(ir.NoTable, Seq(ir.Literal(1))) generates "SELECT 1" + } + + "transpile to SELECT with COMMENTS" in { + ir.WithOptions( + ir.Project(namedTable("t"), Seq(ir.Star(None))), + ir.Options( + Map("MAXRECURSION" -> ir.Literal(10), "OPTIMIZE" -> ir.Column(None, ir.Id("FOR", caseSensitive = true))), + Map("SOMESTROPT" -> "STRINGOPTION"), + Map("SOMETHING" -> true, "SOMETHINGELSE" -> false), + List("SOMEOTHER"))) generates + """/* + | The following statement was originally given the following OPTIONS: + | + | Expression options: + | + | MAXRECURSION = 10 + | OPTIMIZE = `FOR` + | + | String options: + | + | SOMESTROPT = 'STRINGOPTION' + | + | Boolean options: + | + | SOMETHING ON + | SOMETHINGELSE OFF + | + | Auto options: + | + | SOMEOTHER AUTO + | + | + | */ + |SELECT * FROM t""".stripMargin + } + } + + "Filter" should { + "transpile to WHERE" in { + ir.Filter( + ir.Project(namedTable("t1"), Seq(ir.Id("c1"))), + ir.CallFunction("IS_DATE", Seq(ir.Id("c2")))) generates "(SELECT c1 FROM t1) WHERE IS_DATE(c2)" + } + } + + "MergeIntoTable" should { + "transpile to MERGE" in { + ir.MergeIntoTable( + namedTable("t"), + namedTable("s"), + ir.Equals( + ir.Column(Some(ir.ObjectReference(ir.Id("t"))), ir.Id("a")), + ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("a"))), + Seq( + ir.UpdateAction( + None, + Seq( + ir.Assign( + ir.Column(Some(ir.ObjectReference(ir.Id("t"))), ir.Id("b")), + ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("b"))))), + ir.DeleteAction( + Some(ir.LessThan(ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("b")), ir.Literal(10))))), + Seq( + ir.InsertAction( + None, + Seq( + ir.Assign(ir.Column(None, ir.Id("a")), ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("a"))), + ir.Assign(ir.Column(None, ir.Id("b")), ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("b")))))), + List.empty) generates + s"""MERGE INTO t + |USING s + |ON t.a = s.a + | WHEN MATCHED THEN UPDATE SET t.b = s.b WHEN MATCHED AND s.b < 10 THEN DELETE + | WHEN NOT MATCHED THEN INSERT (a, b) VALUES (s.a, s.b) + | + |""".stripMargin + } + "transpile to MERGE with comments" in { + ir.WithOptions( + ir.MergeIntoTable( + ir.NamedTable("t", Map(), is_streaming = false), + ir.NamedTable("s", Map(), is_streaming = false), + ir.Equals( + ir.Column(Some(ir.ObjectReference(ir.Id("t"))), ir.Id("a")), + ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("a"))), + Seq( + ir.UpdateAction( + None, + Seq( + ir.Assign( + ir.Column(Some(ir.ObjectReference(ir.Id("t"))), ir.Id("b")), + ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("b")))))), + Seq( + ir.InsertAction( + None, + Seq( + ir.Assign(ir.Column(None, ir.Id("a")), ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("a"))), + ir.Assign(ir.Column(None, ir.Id("b")), 
ir.Column(Some(ir.ObjectReference(ir.Id("s"))), ir.Id("b")))))), + List.empty), + ir.Options( + Map( + "KEEPFIXED" -> ir.Column(None, ir.Id("PLAN")), + "FAST" -> ir.Literal(666), + "MAX_GRANT_PERCENT" -> ir.Literal(30)), + Map(), + Map("FLAME" -> false, "QUICKLY" -> true), + List())) generates + """/* + | The following statement was originally given the following OPTIONS: + | + | Expression options: + | + | KEEPFIXED = PLAN + | FAST = 666 + | MAX_GRANT_PERCENT = 30 + | + | Boolean options: + | + | FLAME OFF + | QUICKLY ON + | + | + | */ + |MERGE INTO t + |USING s + |ON t.a = s.a + | WHEN MATCHED THEN UPDATE SET t.b = s.b + | WHEN NOT MATCHED THEN INSERT (a, b) VALUES (s.a, s.b) + | + |""".stripMargin + } + } + + "UpdateTable" should { + "transpile to UPDATE" in { + ir.UpdateTable( + namedTable("t"), + None, + Seq( + ir.Assign(ir.Column(None, ir.Id("a")), ir.Literal(1)), + ir.Assign(ir.Column(None, ir.Id("b")), ir.Literal(2))), + Some(ir.Equals(ir.Column(None, ir.Id("c")), ir.Literal(3)))) generates + "UPDATE t SET a = 1, b = 2 WHERE c = 3" + } + } + + "InsertIntoTable" should { + "transpile to INSERT" in { + ir.InsertIntoTable( + namedTable("t"), + Some(Seq(ir.Id("a"), ir.Id("b"))), + ir.Values(Seq(Seq(ir.Literal(1), ir.Literal(2))))) generates "INSERT INTO t (a, b) VALUES (1,2)" + } + } + + "DeleteFromTable" should { + "transpile to DELETE" in { + ir.DeleteFromTable( + target = namedTable("t"), + where = Some(ir.Equals(ir.Column(None, ir.Id("c")), ir.Literal(3)))) generates + "DELETE FROM t WHERE c = 3" + } + } + + "Join" should { + "transpile to JOIN" in { + crossJoin(namedTable("t1"), namedTable("t2")) generates "t1 CROSS JOIN t2" + + ir.Join( + namedTable("t1"), + namedTable("t2"), + None, + ir.InnerJoin, + Seq(), + ir.JoinDataType(is_left_struct = false, is_right_struct = false)) generates "t1, t2" + + ir.Join( + namedTable("t1"), + namedTable("t2"), + Some(ir.CallFunction("IS_DATE", Seq(ir.Id("c1")))), + ir.InnerJoin, + Seq(), + ir.JoinDataType(is_left_struct = false, is_right_struct = false)) generates "t1 INNER JOIN t2 ON IS_DATE(c1)" + + ir.Join( + namedTable("t1"), + namedTable("t2"), + Some(ir.CallFunction("IS_DATE", Seq(ir.Id("c1")))), + ir.RightOuterJoin, + Seq("c1", "c2"), + ir.JoinDataType( + is_left_struct = false, + is_right_struct = false)) generates "t1 RIGHT JOIN t2 ON IS_DATE(c1) USING (c1, c2)" + + ir.Join( + namedTable("t1"), + namedTable("t2"), + None, + ir.NaturalJoin(ir.LeftOuterJoin), + Seq(), + ir.JoinDataType(is_left_struct = false, is_right_struct = false)) generates "t1 NATURAL LEFT JOIN t2" + } + } + + "SetOperation" should { + "transpile to UNION" in { + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false) generates "(a) UNION (b)" + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.UnionSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false) generates "(a) UNION ALL (b)" + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.UnionSetOp, + is_all = true, + by_name = true, + allow_missing_columns = false) + .doesNotTranspile + } + + "transpile to INTERSECT" in { + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.IntersectSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false) generates "(a) INTERSECT (b)" + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.IntersectSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false) generates "(a) INTERSECT ALL (b)" + } + + "transpile 
to EXCEPT" in { + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false) generates "(a) EXCEPT (b)" + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.ExceptSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false) generates "(a) EXCEPT ALL (b)" + } + + "unspecified" in { + ir.SetOperation( + namedTable("a"), + namedTable("b"), + ir.UnspecifiedSetOp, + is_all = true, + by_name = true, + allow_missing_columns = false) + .doesNotTranspile + } + } + + "transpile to LIMIT" in { + ir.Limit(namedTable("a"), ir.Literal(10)) generates "a LIMIT 10" + } + + "transpile to OFFSET" in { + ir.Offset(namedTable("a"), ir.Literal(10)) generates "a OFFSET 10" + } + + "transpile to ORDER BY" in { + ir.Sort( + namedTable("a"), + Seq(ir.SortOrder(ir.Id("c1"), ir.Ascending, ir.NullsFirst))) generates "a ORDER BY c1 NULLS FIRST" + } + + "transpile to VALUES" in { + ir.Values(Seq(Seq(ir.Literal(1), ir.Literal(2)), Seq(ir.Literal(3), ir.Literal(4)))) generates "VALUES (1,2), (3,4)" + } + + "Aggregate" should { + "transpile to GROUP BY" in { + ir.Aggregate(namedTable("t1"), ir.GroupBy, Seq(ir.Id("c1")), None) generates "t1 GROUP BY c1" + } + + "transpile to PIVOT" in { + ir.Aggregate( + namedTable("t1"), + ir.Pivot, + Seq(ir.Id("c1")), + Some(ir.Pivot(ir.Id("c2"), Seq(ir.Literal(1), ir.Literal(2))))) generates + "t1 PIVOT(c1 FOR c2 IN(1, 2))" + } + + "transpile to SELECT DISTINCT" in { + ir.Deduplicate( + namedTable("t1"), + Seq(ir.Id("c1"), ir.Id("c2")), + all_columns_as_keys = false, + within_watermark = false) generates "SELECT DISTINCT c1, c2 FROM t1" + + ir.Deduplicate( + namedTable("t1"), + Seq(), + all_columns_as_keys = true, + within_watermark = false) generates "SELECT DISTINCT * FROM t1" + } + + } + + "transpile to AS" in { + ir.TableAlias( + namedTable("table1"), + "t1", + Seq(ir.Id("c1"), ir.Id("c2"), ir.Id("c3"))) generates "table1 AS t1(c1, c2, c3)" + } + + "CreateTableCommand" should { + "transpile to CREATE TABLE" in { + ir.CreateTableCommand( + "t1", + Seq( + ir.ColumnDeclaration( + "c1", + ir.IntegerType, + constraints = Seq(ir.Nullability(nullable = false), ir.PrimaryKey())), + ir.ColumnDeclaration( + "c2", + ir.StringType))) generates "CREATE TABLE t1 (c1 INT NOT NULL PRIMARY KEY, c2 STRING )" + } + } + + "TableSample" should { + "transpile to TABLESAMPLE" in { + ir.TableSample( + namedTable("t1"), + ir.RowSamplingFixedAmount(BigDecimal(10)), + Some(BigDecimal(10))) generates "(t1) TABLESAMPLE (10 ROWS) REPEATABLE (10)" + } + } + + "CreateTableParameters" should { + "transpile to CREATE TABLE" in { + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq.empty, + Seq.empty, + None, + Some(Seq.empty)) generates "CREATE TABLE some_table (a INT, b VARCHAR(10))" + } + + "transpile to CREATE TABLE with commented options" in { + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Map("a" -> Seq(ir.OptionUnresolved("Unsupported Option: SPARSE")), "b" -> Seq.empty), + Seq( + ir.NamedConstraint( + "c1", + ir.CheckConstraint(ir.GreaterThan(ir.Column(None, ir.Id("a")), ir.Literal(0, 
ir.IntegerType))))), + Seq.empty, + None, + Some(Seq.empty)) generates + "CREATE TABLE some_table (a INT /* Unsupported Option: SPARSE */, b VARCHAR(10), CONSTRAINT c1 CHECK (a > 0))" + } + + "transpile to CREATE TABLE with default values" in { + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map( + "a" -> Seq(ir.DefaultValueConstraint(ir.Literal(0, ir.IntegerType))), + "b" -> Seq(ir.DefaultValueConstraint(ir.Literal("foo", ir.StringType)))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq.empty, + Seq.empty, + None, + Some(Seq.empty)) generates + "CREATE TABLE some_table (a INT DEFAULT 0, b VARCHAR(10) DEFAULT 'foo')" + } + + "transpile to CREATE TABLE with foreign key table constraint" in { + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq(ir.NamedConstraint("c1", ir.ForeignKey("a, b", "other_table", "c, d", Seq.empty))), + Seq.empty, + None, + Some(Seq.empty)) generates + "CREATE TABLE some_table (a INT, b VARCHAR(10), CONSTRAINT c1 FOREIGN KEY (a, b) REFERENCES other_table(c, d))" + } + + "transpile to CREATE TABLE with a primary key, foreign key and a Unique column" in { + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq(ir.PrimaryKey()), "b" -> Seq(ir.Unique())), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq(ir.ForeignKey("b", "other_table", "b", Seq.empty)), + Seq.empty, + None, + Some(Seq.empty)) generates + "CREATE TABLE some_table (a INT PRIMARY KEY, b VARCHAR(10) UNIQUE, FOREIGN KEY (b) REFERENCES other_table(b))" + } + + "transpile to CREATE TABLE with various complex column constraints" in { + ir.CreateTableParams( + ir.CreateTable( + "example_table", + None, + None, + None, + ir.StructType(Seq( + ir.StructField("id", ir.IntegerType), + ir.StructField("name", ir.VarcharType(Some(50)), nullable = false), + ir.StructField("age", ir.IntegerType), + ir.StructField("email", ir.VarcharType(Some(100))), + ir.StructField("department_id", ir.IntegerType)))), + Map( + "name" -> Seq.empty, + "email" -> Seq(ir.Unique()), + "department_id" -> Seq.empty, + "age" -> Seq.empty, + "id" -> Seq(ir.PrimaryKey())), + Map( + "name" -> Seq.empty, + "email" -> Seq.empty, + "department_id" -> Seq.empty, + "age" -> Seq.empty, + "id" -> Seq.empty), + Seq( + ir.CheckConstraint(ir.GreaterThanOrEqual(ir.Column(None, ir.Id("age")), ir.Literal(18, ir.IntegerType))), + ir.ForeignKey("department_id", "departments", "id", Seq.empty)), + Seq.empty, + None, + Some(Seq.empty)) generates + "CREATE TABLE example_table (id INT PRIMARY KEY, name VARCHAR(50) NOT NULL, age INT," + + " email VARCHAR(100) UNIQUE, department_id INT, CHECK (age >= 18)," + + " FOREIGN KEY (department_id) REFERENCES departments(id))" + } + "transpile to CREATE TABLE with a named NULL constraint" in { + ir.CreateTableParams( + ir.CreateTable( + "example_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("id", ir.VarcharType(Some(10)))))), + Map("id" -> Seq.empty), + Map("id" -> Seq.empty), + Seq(ir.NamedConstraint("c1", ir.CheckConstraint(ir.IsNotNull(ir.Column(None, ir.Id("id")))))), + Seq.empty, + None, 
+ Some(Seq.empty)) generates "CREATE TABLE example_table (id VARCHAR(10), CONSTRAINT c1 CHECK (id IS NOT NULL))" + } + + "transpile unsupported table level options to comments" in { + ir.CreateTableParams( + ir.CreateTable("example_table", None, None, None, ir.StructType(Seq(ir.StructField("id", ir.IntegerType)))), + Map("id" -> Seq.empty), + Map("id" -> Seq.empty), + Seq.empty, + Seq.empty, + None, + Some(Seq(ir.OptionUnresolved("LEDGER=ON")))) generates + """/* + | The following options are unsupported: + | + | LEDGER=ON + |*/ + |CREATE TABLE example_table (id INT)""".stripMargin + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/graph/TableGraphTest.scala b/core/src/test/scala/com/databricks/labs/remorph/graph/TableGraphTest.scala new file mode 100644 index 0000000000..87538200b1 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/graph/TableGraphTest.scala @@ -0,0 +1,142 @@ +package com.databricks.labs.remorph.graph + +import com.databricks.labs.remorph.discovery.TableDefinition +import com.databricks.labs.remorph.parsers.snowflake.SnowflakePlanParser +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.sql.Timestamp +import java.time.Duration +import com.databricks.labs.remorph.discovery.{ExecutedQuery, QueryHistory} +import com.databricks.labs.remorph.intermediate.{IntegerType, StringType, StructField} + +class TableGraphTest extends AnyFlatSpec with Matchers { + private[this] val parser = new SnowflakePlanParser() + private[this] val queryHistory = QueryHistory( + Seq( + ExecutedQuery( + "query1", + "INSERT INTO table1 SELECT col1, col2 FROM table2 INNER JOIN table3 on table2.id = table3.id", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(30), + Some("user1")), + ExecutedQuery( + "query2", + "INSERT INTO table2 (col1, col2) VALUES (1, 'value1')", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(45), + Some("user2")), + ExecutedQuery( + "query3", + "SELECT * FROM table3 JOIN table4 ON table3.id = table4.id", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(60), + Some("user3")), + ExecutedQuery( + "query4", + "SELECT col1, (SELECT MAX(col2) FROM table5) AS max_col2 FROM table5", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(25), + Some("user4")), + ExecutedQuery( + "query5", + "WITH cte AS (SELECT col1 FROM table5) SELECT * FROM cte", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(35), + Some("user5")), + ExecutedQuery( + "query6", + "INSERT INTO table1 (col1, col2) VALUES (2, 'value2')", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(40), + Some("user1")), + ExecutedQuery( + "query7", + """MERGE INTO table2 USING table3 source_table ON table2.id = source_table.id + |WHEN MATCHED THEN UPDATE SET table2.col1 = source_table.col1""".stripMargin, + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(50), + Some("user2")), + ExecutedQuery( + "query8", + "UPDATE table3 SET col1 = 'new_value' WHERE col2 = 'condition'", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(55), + Some("user3")), + ExecutedQuery( + "query9", + "DELETE FROM table4 WHERE col1 = 'value_to_delete'", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(20), + Some("user4")), + ExecutedQuery( + "query10", + "INSERT INTO table2 SELECT * FROM table5 WHERE col1 = 'some_value'", + new Timestamp(System.currentTimeMillis()), + Duration.ofSeconds(65), + Some("user5")))) + + private[this] 
val tableDefinitions = Set( + TableDefinition( + catalog = "catalog1", + schema = "schema1", + table = "table1", + columns = Seq(StructField("col1", IntegerType, true), StructField("col2", StringType, false)), + sizeGb = 10), + TableDefinition( + catalog = "catalog2", + schema = "schema2", + table = "table2", + columns = Seq(StructField("col1", IntegerType, true), StructField("col2", StringType, false)), + sizeGb = 20), + TableDefinition( + catalog = "catalog3", + schema = "schema3", + table = "table3", + columns = Seq(StructField("col1", IntegerType, true), StructField("col2", StringType, false)), + sizeGb = 30), + TableDefinition( + catalog = "catalog4", + schema = "schema4", + table = "table4", + columns = Seq(StructField("col1", IntegerType, true), StructField("col2", StringType, false)), + sizeGb = 40), + TableDefinition( + catalog = "catalog5", + schema = "schema5", + table = "table5", + columns = Seq(StructField("col1", IntegerType, true), StructField("col2", StringType, false)), + sizeGb = 50)) + val graph = new TableGraph(parser) + graph.buildDependency(queryHistory, tableDefinitions) + + "TableDependencyGraph" should "add nodes correctly" in { + val roots = graph.getRootTables() + assert(roots.size == 3) + assert(roots.map(x => x.table).toList.sorted.toSet == Set("table3", "table4", "table5")) + + } + + "TableDependencyGraph" should "return correct upstream tables" in { + val upstreamTables = graph.getUpstreamTables( + TableDefinition( + catalog = "catalog1", + schema = "schema1", + table = "table1", + columns = Seq(StructField("col1", IntegerType, true), StructField("col2", StringType, false)), + sizeGb = 10)) + assert(upstreamTables.map(_.table).toList.sorted == List("table2", "table3", "table5")) + } + + "TableDependencyGraph" should "return correct downstream tables" in { + val downstreamTables = graph.getDownstreamTables( + TableDefinition( + catalog = "catalog5", + schema = "schema5", + table = "table5", + columns = Seq(StructField("col1", IntegerType, true), StructField("col2", StringType, false)), + sizeGb = 50)) + assert(downstreamTables.map(_.table).toList.sorted == List("table1", "table2")) + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeQueryHistoryTest.scala b/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeQueryHistoryTest.scala new file mode 100644 index 0000000000..fece735c4d --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeQueryHistoryTest.scala @@ -0,0 +1,19 @@ +package com.databricks.labs.remorph.integration + +import com.databricks.labs.remorph.connections.{EnvGetter, SnowflakeConnectionFactory} +import com.databricks.labs.remorph.discovery.SnowflakeQueryHistory +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeQueryHistoryTest extends AnyWordSpec with Matchers { + "integration" should { + "work in happy path" ignore { + val env = new EnvGetter + val connFactory = new SnowflakeConnectionFactory(env) + val conn = connFactory.newConnection() // TODO: wrap with closing logic + val snow = new SnowflakeQueryHistory(conn) + val history = snow.history() + assert(history.queries.nonEmpty) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeTableDefinitionTest.scala b/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeTableDefinitionTest.scala new file mode 100644 index 0000000000..d02ec0c90d --- /dev/null +++ 
b/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeTableDefinitionTest.scala
@@ -0,0 +1,23 @@
+package com.databricks.labs.remorph.integration
+
+import com.databricks.labs.remorph.connections.{EnvGetter, SnowflakeConnectionFactory}
+import com.databricks.labs.remorph.discovery.SnowflakeTableDefinitions
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.wordspec.AnyWordSpec
+
+class SnowflakeTableDefinitionTest extends AnyWordSpec with Matchers {
+  "integration" should {
+    "get table definitions for snowflake" in {
+      val env = new EnvGetter
+      val connFactory = new SnowflakeConnectionFactory(env)
+      val conn = connFactory.newConnection()
+      try {
+        val snow = new SnowflakeTableDefinitions(conn)
+        val rs = snow.getAllTableDefinitions
+        rs should not be empty
+      } finally {
+        conn.close()
+      }
+    }
+  }
+}
diff --git a/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeTableGraphTest.scala b/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeTableGraphTest.scala
new file mode 100644
index 0000000000..591ca7f366
--- /dev/null
+++ b/core/src/test/scala/com/databricks/labs/remorph/integration/SnowflakeTableGraphTest.scala
@@ -0,0 +1,27 @@
+package com.databricks.labs.remorph.integration
+
+import com.databricks.labs.remorph.connections.{EnvGetter, SnowflakeConnectionFactory}
+import com.databricks.labs.remorph.discovery.{SnowflakeQueryHistory, SnowflakeTableDefinitions}
+import com.databricks.labs.remorph.parsers.snowflake.SnowflakePlanParser
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.wordspec.AnyWordSpec
+import com.databricks.labs.remorph.graph.TableGraph
+
+class SnowflakeTableGraphTest extends AnyWordSpec with Matchers {
+  "TableGraph" should {
+    "retrieve downstream tables correctly" ignore {
+      val env = new EnvGetter
+      val connFactory = new SnowflakeConnectionFactory(env)
+      val conn = connFactory.newConnection()
+      val snow = new SnowflakeQueryHistory(conn)
+      val history = snow.history()
+      val parser = new SnowflakePlanParser()
+      val tableDefinition = new SnowflakeTableDefinitions(conn).getAllTableDefinitions
+
+      val tableGraph = new TableGraph(parser)
+      tableGraph.buildDependency(history, tableDefinition.toSet)
+      // TODO: Currently there are not enough examples in the query history for the integration test to work.
+      assert(tableGraph.getRootTables().isEmpty)
+    }
+  }
+}
diff --git a/core/src/test/scala/com/databricks/labs/remorph/integration/TSqlTableDefinationTest.scala b/core/src/test/scala/com/databricks/labs/remorph/integration/TSqlTableDefinationTest.scala
new file mode 100644
index 0000000000..afe17ec87c
--- /dev/null
+++ b/core/src/test/scala/com/databricks/labs/remorph/integration/TSqlTableDefinationTest.scala
@@ -0,0 +1,33 @@
+package com.databricks.labs.remorph.integration
+
+import com.databricks.labs.remorph.connections.{EnvGetter, TSqlConnectionFactory}
+import com.databricks.labs.remorph.discovery.TSqlTableDefinitions
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.wordspec.AnyWordSpec
+
+class TSqlTableDefinationTest extends AnyWordSpec with Matchers {
+  "integration" should {
+    "get table definitions for TSql" in {
+      val conn = new TSqlConnectionFactory(new EnvGetter).newConnection()
+      try {
+        val sqlTD = new TSqlTableDefinitions(conn)
+        val rs = sqlTD.getAllTableDefinitions
+        rs should not be empty
+      } finally {
+        conn.close()
+      }
+    }
+
+    "get all catalogs for TSql" in {
+      val conn = new TSqlConnectionFactory(new EnvGetter).newConnection()
+      try {
+        val sqlTD = new TSqlTableDefinitions(conn)
+        val rs = sqlTD.getAllCatalogs
+        rs should not be empty
+      } finally {
+        conn.close()
+      }
+    }
+
+  }
+}
diff --git a/core/src/test/scala/com/databricks/labs/remorph/intermediate/JoinTest.scala b/core/src/test/scala/com/databricks/labs/remorph/intermediate/JoinTest.scala
new file mode 100644
index 0000000000..070fc23261
--- /dev/null
+++ b/core/src/test/scala/com/databricks/labs/remorph/intermediate/JoinTest.scala
@@ -0,0 +1,19 @@
+package com.databricks.labs.remorph.intermediate
+
+import org.scalatest.matchers.must.Matchers
+import org.scalatest.wordspec.AnyWordSpec
+
+class JoinTest extends AnyWordSpec with Matchers {
+  "Join" should {
+    "propagate expressions" in {
+      val join = Join(
+        NoopNode,
+        NoopNode,
+        Some(Name("foo")),
+        InnerJoin,
+        Seq.empty,
+        JoinDataType(is_left_struct = true, is_right_struct = true))
+      join.expressions mustBe Seq(Name("foo"))
+    }
+  }
+}
diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/ErrorCollectorSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/ErrorCollectorSpec.scala
new file mode 100644
index 0000000000..98c3ca00b5
--- /dev/null
+++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/ErrorCollectorSpec.scala
@@ -0,0 +1,38 @@
+package com.databricks.labs.remorph.parsers
+
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.wordspec.AnyWordSpec
+import org.scalatestplus.mockito.MockitoSugar
+
+class ErrorCollectorSpec extends AnyWordSpec with Matchers with MockitoSugar {
+
+  "ProductionErrorCollector.formatError" should {
+    val line = "AAA BBB CCC DDD EEE FFF GGG HHH III JJJ"
+    val prodCollector = new ProductionErrorCollector(line, "example.sql")
+
+    "not clip lines shorter than window width" in {
+
+      prodCollector.formatError(line, 4, 3) shouldBe
+        """AAA BBB CCC DDD EEE FFF GGG HHH III JJJ
+          | ^^^""".stripMargin
+    }
+
+    "clip longer lines on the right when the offending token is close to the start" in {
+      prodCollector.formatError(line, 4, 3, 20) shouldBe
+        """AAA BBB CCC DDD E...
+          | ^^^""".stripMargin
+    }
+
+    "clip longer lines on the left when the offending token is close to the end" in {
+      prodCollector.formatError(line, 32, 3, 20) shouldBe
+        """...F GGG HHH III JJJ
+          | ^^^""".stripMargin
+    }
+
+    "clip longer lines on both sides when the offending token is too far from both ends" in {
+      prodCollector.formatError(line, 16, 3, 20) shouldBe
+        """... DDD EEE FFF G...
+          | ^^^""".stripMargin
+    }
+  }
+}
diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/ParserTestCommon.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/ParserTestCommon.scala
new file mode 100644
index 0000000000..6b0557e659
--- /dev/null
+++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/ParserTestCommon.scala
@@ -0,0 +1,80 @@
+package com.databricks.labs.remorph.parsers
+
+import com.databricks.labs.remorph.{intermediate => ir}
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.tree.ParseTreeVisitor
+import org.scalatest.{Assertion, Assertions}
+
+trait ParserTestCommon[P <: Parser] extends PlanComparison { self: Assertions =>
+
+  protected def makeLexer(chars: CharStream): TokenSource
+  protected def makeParser(tokens: TokenStream): P
+  protected def makeErrStrategy(): SqlErrorStrategy
+  protected def makeErrListener(chars: String): ErrorCollector = new DefaultErrorCollector
+  protected def astBuilder: ParseTreeVisitor[_]
+  protected var errListener: ErrorCollector = _
+
+  protected def parseString[R <: RuleContext](input: String, rule: P => R): R = {
+    val inputString = CharStreams.fromString(input)
+    val lexer = makeLexer(inputString)
+    val tokenStream = new CommonTokenStream(lexer)
+    val parser = makeParser(tokenStream)
+    errListener = makeErrListener(input)
+    parser.removeErrorListeners()
+    parser.addErrorListener(errListener)
+    parser.setErrorHandler(makeErrStrategy())
+    val tree = rule(parser)
+
+    // uncomment the following line if you need a peek in the Snowflake/TSQL AST
+    // println(tree.toStringTree(parser))
+    tree
+  }
+
+  protected def example[R <: RuleContext](
+      query: String,
+      rule: P => R,
+      expectedAst: ir.LogicalPlan,
+      failOnErrors: Boolean = true): Unit = {
+    val sfTree = parseString(query, rule)
+    if (errListener != null && errListener.errorCount != 0) {
+      errListener.logErrors()
+      if (failOnErrors) {
+        fail(s"${errListener.errorCount} errors found in the child string")
+      }
+    }
+
+    val result = astBuilder.visit(sfTree)
+    comparePlans(expectedAst, result.asInstanceOf[ir.LogicalPlan])
+  }
+
+  protected def exampleExpr[R <: RuleContext](query: String, rule: P => R, expectedAst: ir.Expression): Unit = {
+    val sfTree = parseString(query, rule)
+    if (errListener != null && errListener.errorCount != 0) {
+      errListener.logErrors()
+      fail(s"${errListener.errorCount} errors found in the child string")
+    }
+    val result = astBuilder.visit(sfTree)
+    val wrapExpr = (expr: ir.Expression) => ir.Filter(ir.NoopNode, expr)
+    comparePlans(wrapExpr(expectedAst), wrapExpr(result.asInstanceOf[ir.Expression]))
+  }
+
+  /**
+   * Used to pass intentionally bad syntax to the parser and check that it fails with an expected error
+   * @param query the deliberately invalid input to parse
+   * @param rule the parser entry rule to invoke
+   * @tparam R the parse tree context type produced by the rule
+   * @return the assertion that the expected error was reported
+   */
+  protected def checkError[R <: RuleContext](query: String, rule: P => R, errContains: String): Assertion = {
+    parseString(query, rule)
+    if (errListener != null && errListener.errorCount == 0) {
+      fail(s"Expected an error in the child string\n$query\nbut no errors were found")
+    }
+
+    val errors = errListener.formatErrors
+    assert(
errors.exists(_.contains(errContains)), + s"Expected error containing '$errContains' but got:\n${errors.mkString("\n")}") + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/PlanComparison.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/PlanComparison.scala new file mode 100644 index 0000000000..362575bec4 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/PlanComparison.scala @@ -0,0 +1,39 @@ +package com.databricks.labs.remorph.parsers + +import com.databricks.labs.remorph.utils.Strings +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.Assertions + +trait PlanComparison { + self: Assertions => + + protected def comparePlans(a: ir.LogicalPlan, b: ir.LogicalPlan): Unit = { + val expected = reorderComparisons(a) + val actual = reorderComparisons(b) + if (expected != actual) { + fail(s""" + |== FAIL: Plans do not match (expected vs actual) === + |${Strings.sideBySide(pretty(expected), pretty(actual)).mkString("\n")} + """.stripMargin) + } + } + + private def pretty(x: Any): String = pprint.apply(x, width = 40).plainText + + protected def eraseExprIds(plan: ir.LogicalPlan): ir.LogicalPlan = { + val exprId = ir.ExprId(0) + plan transformAllExpressions { case ir.AttributeReference(name, dt, nullable, _, qualifier) => + ir.AttributeReference(name, dt, nullable, exprId, qualifier) + } + } + + protected def reorderComparisons(plan: ir.LogicalPlan): ir.LogicalPlan = { + eraseExprIds(plan) transformAllExpressions { + case ir.Equals(l, r) if l.hashCode() > r.hashCode() => ir.Equals(r, l) + case ir.GreaterThan(l, r) if l.hashCode() > r.hashCode() => ir.LessThan(r, l) + case ir.GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => ir.LessThanOrEqual(r, l) + case ir.LessThan(l, r) if l.hashCode() > r.hashCode() => ir.GreaterThan(r, l) + case ir.LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => ir.GreaterThanOrEqual(r, l) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/SetOperationBehaviors.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/SetOperationBehaviors.scala new file mode 100644 index 0000000000..07d575fb00 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/SetOperationBehaviors.scala @@ -0,0 +1,190 @@ +package com.databricks.labs.remorph.parsers + +import com.databricks.labs.remorph.intermediate.LogicalPlan +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime.{Parser, RuleContext} +import org.scalatest.wordspec.AnyWordSpec + +trait SetOperationBehaviors[P <: Parser] extends ir.IRHelpers { this: ParserTestCommon[P] with AnyWordSpec => + + def setOperationsAreTranslated[R <: RuleContext](rule: P => R): Unit = { + def testSimpleExample(query: String, expectedAst: LogicalPlan): Unit = { + query in { + example(query, rule, expectedAst) + } + } + + "translate set operations" should { + /* These are of the form: + * SELECT [...] ((UNION [ALL] | EXCEPT | INTERSECT) SELECT [...])* + * + * Note that precedence is: + * 1. Brackets. + * 2. INTERSECT + * 3. UNION and EXCEPT, from left to right. 
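+       *
+       * For example, "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 INTERSECT SELECT 4" is parsed as
+       * (1 UNION 2) EXCEPT (3 INTERSECT 4), which the test cases below verify.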
+ */ + testSimpleExample( + "SELECT 1 UNION SELECT 2", + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + testSimpleExample( + "SELECT 1 UNION ALL SELECT 2", + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.UnionSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false)) + testSimpleExample( + "SELECT 1 EXCEPT SELECT 2", + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + testSimpleExample( + "SELECT 1 INTERSECT SELECT 2", + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.IntersectSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + testSimpleExample( + "SELECT 1 UNION SELECT 2 UNION ALL SELECT 3 EXCEPT SELECT 4 INTERSECT SELECT 5", + ir.SetOperation( + ir.SetOperation( + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType))), + ir.UnionSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false), + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(4, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(5, ir.IntegerType))), + ir.IntersectSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + // Part of checking that UNION and EXCEPT are processed with the same precedence: left-to-right + testSimpleExample( + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3", + ir.SetOperation( + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType))), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + // Part of checking that UNION and EXCEPT are processed with the same precedence: left-to-right + testSimpleExample( + "SELECT 1 EXCEPT SELECT 2 UNION SELECT 3", + ir.SetOperation( + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType))), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + // INTERSECT has higher precedence than both UNION and EXCEPT + testSimpleExample( + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 INTERSECT SELECT 4", + ir.SetOperation( + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.SetOperation( + 
ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(4, ir.IntegerType))), + ir.IntersectSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + testSimpleExample( + "SELECT 1 UNION (SELECT 2 UNION ALL SELECT 3) INTERSECT (SELECT 4 EXCEPT SELECT 5)", + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.SetOperation( + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType))), + ir.UnionSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false), + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(4, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(5, ir.IntegerType))), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.IntersectSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + testSimpleExample( + "(SELECT 1, 2) UNION SELECT 3, 4", + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType), ir.Literal(2, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType), ir.Literal(4, ir.IntegerType))), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + testSimpleExample( + "(SELECT 1, 2) UNION ALL SELECT 3, 4", + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType), ir.Literal(2, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType), ir.Literal(4, ir.IntegerType))), + ir.UnionSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false)) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/VisitorsSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/VisitorsSpec.scala new file mode 100644 index 0000000000..ac0c017631 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/VisitorsSpec.scala @@ -0,0 +1,23 @@ +package com.databricks.labs.remorph.parsers + +import com.databricks.labs.remorph.parsers.tsql.{TSqlLexer, TSqlParser, TSqlParserBaseVisitor, TSqlVisitorCoordinator} +import org.antlr.v4.runtime.{CharStreams, CommonTokenStream} +import org.scalatest.wordspec.AnyWordSpec + +class FakeVisitor(override val vc: TSqlVisitorCoordinator) + extends TSqlParserBaseVisitor[String] + with ParserCommon[String] { + override protected def unresolved(ruleText: String, message: String): String = ruleText +} + +class VistorsSpec extends AnyWordSpec { + + "Visitors" should { + "correctly collect text from contexts" in { + val stream = CharStreams.fromString("SELECT * FROM table;") + val result = new TSqlParser(new CommonTokenStream(new TSqlLexer(stream))).tSqlFile() + val text = new FakeVisitor(null).contextText(result) + assert(text == "SELECT * FROM table;") + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/common/FunctionBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/common/FunctionBuilderSpec.scala new file mode 100644 index 0000000000..35193585ca --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/common/FunctionBuilderSpec.scala @@ -0,0 +1,631 @@ +package com.databricks.labs.remorph.parsers.common + +import 
com.databricks.labs.remorph.parsers.snowflake.{NamedArgumentExpression, SnowflakeFunctionBuilder} +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeFunctionConverters.SynonymOf +import com.databricks.labs.remorph.parsers.tsql.TSqlFunctionBuilder +import com.databricks.labs.remorph.parsers._ +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks + +class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenPropertyChecks with ir.IRHelpers { + + "TSqlFunctionBuilder" should "return correct arity for each function" in { + val functions = Table( + ("functionName", "expectedArity"), // Header + + // TSQL specific + ("@@CURSOR_ROWS", Some(FunctionDefinition.notConvertible(0))), + ("@@FETCH_STATUS", Some(FunctionDefinition.notConvertible(0))), + ("CUBE", Some(FunctionDefinition.standard(1, Int.MaxValue))), + ("MODIFY", Some(FunctionDefinition.xml(1))), + ("ROLLUP", Some(FunctionDefinition.standard(1, Int.MaxValue)))) + + val functionBuilder = new TSqlFunctionBuilder + forAll(functions) { (functionName: String, expectedArity: Option[FunctionDefinition]) => + functionBuilder.functionDefinition(functionName) shouldEqual expectedArity + } + } + + "SnowFlakeFunctionBuilder" should "return correct arity for each function" in { + val functions = Table( + ("functionName", "expectedArity"), // Header + + // Snowflake specific + ("ADD_MONTHS", Some(FunctionDefinition.standard(2))), + ("ANY_VALUE", Some(FunctionDefinition.standard(1))), + ("APPROX_TOP_K", Some(FunctionDefinition.standard(1, 3))), + ("ARRAYS_OVERLAP", Some(FunctionDefinition.standard(2))), + ("ARRAY_AGG", Some(FunctionDefinition.standard(1))), + ("ARRAY_APPEND", Some(FunctionDefinition.standard(2))), + ("ARRAY_CAT", Some(FunctionDefinition.standard(2))), + ("ARRAY_COMPACT", Some(FunctionDefinition.standard(1))), + ("ARRAY_CONSTRUCT", Some(FunctionDefinition.standard(0, Int.MaxValue))), + ("ARRAY_CONSTRUCT_COMPACT", Some(FunctionDefinition.standard(0, Int.MaxValue))), + ("ARRAY_CONTAINS", Some(FunctionDefinition.standard(2))), + ("ARRAY_DISTINCT", Some(FunctionDefinition.standard(1))), + ("ARRAY_EXCEPT", Some(FunctionDefinition.standard(2))), + ("ARRAY_INSERT", Some(FunctionDefinition.standard(3))), + ("ARRAY_INTERSECTION", Some(FunctionDefinition.standard(2))), + ("ARRAY_POSITION", Some(FunctionDefinition.standard(2))), + ("ARRAY_PREPEND", Some(FunctionDefinition.standard(2))), + ("ARRAY_REMOVE", Some(FunctionDefinition.standard(2))), + ("ARRAY_SIZE", Some(FunctionDefinition.standard(1))), + ("ARRAY_SLICE", Some(FunctionDefinition.standard(3))), + ("ARRAY_TO_STRING", Some(FunctionDefinition.standard(2))), + ("ATAN2", Some(FunctionDefinition.standard(2))), + ("BASE64_DECODE_STRING", Some(FunctionDefinition.standard(1, 2))), + ("BASE64_ENCODE", Some(FunctionDefinition.standard(1, 3))), + ("BITOR_AGG", Some(FunctionDefinition.standard(1))), + ("BOOLAND_AGG", Some(FunctionDefinition.standard(1))), + ("CEIL", Some(FunctionDefinition.standard(1, 2))), + ("COLLATE", Some(FunctionDefinition.standard(2))), + ("COLLATION", Some(FunctionDefinition.standard(1))), + ("CONTAINS", Some(FunctionDefinition.standard(2))), + ("CONVERT_TIMEZONE", Some(FunctionDefinition.standard(2, 3))), + ("CORR", Some(FunctionDefinition.standard(2))), + ("COUNT_IF", Some(FunctionDefinition.standard(1))), + ("CURRENT_DATABASE", Some(FunctionDefinition.standard(0))), + ("CURRENT_TIMESTAMP", 
Some(FunctionDefinition.standard(0, 1))), + ("DATEDIFF", Some(FunctionDefinition.standard(3))), + ("DATE", Some(FunctionDefinition.standard(1, 2).withConversionStrategy(SynonymOf("TO_DATE")))), + ("DATEFROMPARTS", Some(FunctionDefinition.standard(3).withConversionStrategy(SynonymOf("DATE_FROM_PARTS")))), + ("DATE_FROM_PARTS", Some(FunctionDefinition.standard(3))), + ("DATE_PART", Some(FunctionDefinition.standard(2))), + ("DATE_TRUNC", Some(FunctionDefinition.standard(2))), + ("DAYNAME", Some(FunctionDefinition.standard(1))), + ("DECODE", Some(FunctionDefinition.standard(3, Int.MaxValue))), + ("DENSE_RANK", Some(FunctionDefinition.standard(0))), + ("DIV0", Some(FunctionDefinition.standard(2))), + ("DIV0NULL", Some(FunctionDefinition.standard(2))), + ("EDITDISTANCE", Some(FunctionDefinition.standard(2, 3))), + ("ENDSWITH", Some(FunctionDefinition.standard(2))), + ("EQUAL_NULL", Some(FunctionDefinition.standard(2))), + ("EXTRACT", Some(FunctionDefinition.standard(2))), + ("FLATTEN", Some(FunctionDefinition.symbolic(Set("INPUT"), Set("PATH", "OUTER", "RECURSIVE", "MODE")))), + ("GET", Some(FunctionDefinition.standard(2))), + ("HASH", Some(FunctionDefinition.standard(1, Int.MaxValue))), + ("HOUR", Some(FunctionDefinition.standard(1))), + ("IFNULL", Some(FunctionDefinition.standard(1, 2))), + ("INITCAP", Some(FunctionDefinition.standard(1, 2))), + ("ISNULL", Some(FunctionDefinition.standard(1))), + ("IS_INTEGER", Some(FunctionDefinition.standard(1))), + ("JSON_EXTRACT_PATH_TEXT", Some(FunctionDefinition.standard(2))), + ("LAST_DAY", Some(FunctionDefinition.standard(1, 2))), + ("LEFT", Some(FunctionDefinition.standard(2))), + ("LPAD", Some(FunctionDefinition.standard(2, 3))), + ("LTRIM", Some(FunctionDefinition.standard(1, 2))), + ("MEDIAN", Some(FunctionDefinition.standard(1))), + ("MINUTE", Some(FunctionDefinition.standard(1))), + ("MOD", Some(FunctionDefinition.standard(2))), + ("MODE", Some(FunctionDefinition.standard(1))), + ("MONTHNAME", Some(FunctionDefinition.standard(1))), + ("NEXT_DAY", Some(FunctionDefinition.standard(2))), + ("NULLIFZERO", Some(FunctionDefinition.standard(1))), + ("NTH_VAlUE", Some(FunctionDefinition.standard(2))), + ("NVL", Some(FunctionDefinition.standard(2).withConversionStrategy(SynonymOf("IFNULL")))), + ("NVL2", Some(FunctionDefinition.standard(3))), + ("OBJECT_CONSTRUCT", Some(FunctionDefinition.standard(0, Int.MaxValue))), + ("OBJECT_KEYS", Some(FunctionDefinition.standard(1))), + ("PARSE_JSON", Some(FunctionDefinition.standard(1))), + ("PARSE_URL", Some(FunctionDefinition.standard(1, 2))), + ("POSITION", Some(FunctionDefinition.standard(2, 3))), + ("RANDOM", Some(FunctionDefinition.standard(0, 1))), + ("RANK", Some(FunctionDefinition.standard(0))), + ("REGEXP_COUNT", Some(FunctionDefinition.standard(2, 4))), + ("REGEXP_INSTR", Some(FunctionDefinition.standard(2, 7))), + ("REGEXP_LIKE", Some(FunctionDefinition.standard(2, 3))), + ("REGEXP_REPLACE", Some(FunctionDefinition.standard(2, 6))), + ("REGEXP_SUBSTR", Some(FunctionDefinition.standard(2, 6))), + ("REGR_INTERCEPT", Some(FunctionDefinition.standard(2))), + ("REGR_R2", Some(FunctionDefinition.standard(2))), + ("REGR_SLOPE", Some(FunctionDefinition.standard(2))), + ("REPEAT", Some(FunctionDefinition.standard(2))), + ("RIGHT", Some(FunctionDefinition.standard(2))), + ("RLIKE", Some(FunctionDefinition.standard(2, 3))), + ("ROUND", Some(FunctionDefinition.standard(1, 3))), + ("RPAD", Some(FunctionDefinition.standard(2, 3))), + ("RTRIM", Some(FunctionDefinition.standard(1, 2))), + ("SECOND", 
Some(FunctionDefinition.standard(1))), + ("SPLIT_PART", Some(FunctionDefinition.standard(3))), + ("SQUARE", Some(FunctionDefinition.standard(1))), + ("STARTSWITH", Some(FunctionDefinition.standard(2))), + ("STDDEV", Some(FunctionDefinition.standard(1))), + ("STDDEV_POP", Some(FunctionDefinition.standard(1))), + ("STDDEV_SAMP", Some(FunctionDefinition.standard(1))), + ("STRIP_NULL_VALUE", Some(FunctionDefinition.standard(1))), + ("STRTOK", Some(FunctionDefinition.standard(1, 3))), + ("STRTOK_TO_ARRAY", Some(FunctionDefinition.standard(1, 2))), + ("SYSDATE", Some(FunctionDefinition.standard(0))), + ("TIME", Some(FunctionDefinition.standard(1, 2).withConversionStrategy(SynonymOf("TO_TIME")))), + ("TIMEADD", Some(FunctionDefinition.standard(3).withConversionStrategy(SynonymOf("DATEADD")))), + ("TIMESTAMPADD", Some(FunctionDefinition.standard(3))), + ("TIMESTAMPDIFF", Some(FunctionDefinition.standard(3).withConversionStrategy(SynonymOf("DATEDIFF")))), + ("TIMESTAMP_FROM_PARTS", Some(FunctionDefinition.standard(2, 8))), + ("TO_ARRAY", Some(FunctionDefinition.standard(1, 2))), + ("TO_BOOLEAN", Some(FunctionDefinition.standard(1))), + ("TO_CHAR", Some(FunctionDefinition.standard(1, 2).withConversionStrategy(SynonymOf("TO_VARCHAR")))), + ("TO_DATE", Some(FunctionDefinition.standard(1, 2))), + ("TO_DECIMAL", Some(FunctionDefinition.standard(1, 4).withConversionStrategy(SynonymOf("TO_NUMBER")))), + ("TO_DOUBLE", Some(FunctionDefinition.standard(1, 2))), + ("TO_JSON", Some(FunctionDefinition.standard(1))), + ("TO_NUMBER", Some(FunctionDefinition.standard(1, 4))), + ("TO_NUMERIC", Some(FunctionDefinition.standard(1, 4).withConversionStrategy(SynonymOf("TO_NUMBER")))), + ("TO_OBJECT", Some(FunctionDefinition.standard(1))), + ("TO_TIME", Some(FunctionDefinition.standard(1, 2))), + ("TO_TIMESTAMP", Some(FunctionDefinition.standard(1, 2))), + ("TO_TIMESTAMP_LTZ", Some(FunctionDefinition.standard(1, 2))), + ("TO_TIMESTAMP_NTZ", Some(FunctionDefinition.standard(1, 2))), + ("TO_TIMESTAMP_TZ", Some(FunctionDefinition.standard(1, 2))), + ("TO_VARCHAR", Some(FunctionDefinition.standard(1, 2))), + ("TO_VARIANT", Some(FunctionDefinition.standard(1))), + ("TRIM", Some(FunctionDefinition.standard(1, 2))), + ("TRUNC", Some(FunctionDefinition.standard(2))), + ("TRY_BASE64_DECODE_STRING", Some(FunctionDefinition.standard(1, 2))), + ("TRY_PARSE_JSON", Some(FunctionDefinition.standard(1))), + ("TRY_TO_BINARY", Some(FunctionDefinition.standard(1, 2))), + ("TRY_TO_BOOLEAN", Some(FunctionDefinition.standard(1))), + ("TRY_TO_DATE", Some(FunctionDefinition.standard(1, 2))), + ("TRY_TO_DECIMAL", Some(FunctionDefinition.standard(1, 4).withConversionStrategy(SynonymOf("TRY_TO_NUMBER")))), + ("TRY_TO_DOUBLE", Some(FunctionDefinition.standard(1, 2))), + ("TRY_TO_NUMBER", Some(FunctionDefinition.standard(1, 4))), + ("TRY_TO_NUMERIC", Some(FunctionDefinition.standard(1, 4).withConversionStrategy(SynonymOf("TRY_TO_NUMBER")))), + ("TRY_TO_TIME", Some(FunctionDefinition.standard(1, 2))), + ("TRY_TO_TIMESTAMP", Some(FunctionDefinition.standard(1, 2))), + ("TRY_TO_TIMESTAMP_LTZ", Some(FunctionDefinition.standard(1, 2))), + ("TRY_TO_TIMESTAMP_NTZ", Some(FunctionDefinition.standard(1, 2))), + ("TRY_TO_TIMESTAMP_TZ", Some(FunctionDefinition.standard(1, 2))), + ("TYPEOF", Some(FunctionDefinition.standard(1))), + ("UUID_STRING", Some(FunctionDefinition.standard(0, 2))), + ("ZEROIFNULL", Some(FunctionDefinition.standard(1)))) + + val functionBuilder = new SnowflakeFunctionBuilder + forAll(functions) { (functionName: String, expectedArity: 
Option[FunctionDefinition]) => + functionBuilder.functionDefinition(functionName) shouldEqual expectedArity + } + } + + "FunctionBuilder" should "return correct arity for each function" in { + + val functions = Table( + ("functionName", "expectedArity"), // Header + + ("ABS", FunctionDefinition.standard(1)), + ("ACOS", FunctionDefinition.standard(1)), + ("APPROX_COUNT_DISTINCT", FunctionDefinition.standard(1)), + ("APPROX_PERCENTILE", FunctionDefinition.standard(2)), + ("APPROX_PERCENTILE_CONT", FunctionDefinition.standard(1)), + ("APPROX_PERCENTILE_DISC", FunctionDefinition.standard(1)), + ("APP_NAME", FunctionDefinition.standard(0)), + ("APPLOCK_MODE", FunctionDefinition.standard(3)), + ("APPLOCK_TEST", FunctionDefinition.standard(4)), + ("ASCII", FunctionDefinition.standard(1)), + ("ASIN", FunctionDefinition.standard(1)), + ("ASSEMBLYPROPERTY", FunctionDefinition.standard(2)), + ("ATAN", FunctionDefinition.standard(1)), + ("ATN2", FunctionDefinition.standard(2)), + ("AVG", FunctionDefinition.standard(1)), + ("BINARY_CHECKSUM", FunctionDefinition.standard(1, Int.MaxValue)), + ("BIT_COUNT", FunctionDefinition.standard(1)), + ("CEILING", FunctionDefinition.standard(1)), + ("CERT_ID", FunctionDefinition.standard(1)), + ("CERTENCODED", FunctionDefinition.standard(1)), + ("CERTPRIVATEKEY", FunctionDefinition.standard(2, 3)), + ("CHAR", FunctionDefinition.standard(1)), + ("CHARINDEX", FunctionDefinition.standard(2, 3)), + ("CHECKSUM", FunctionDefinition.standard(2, Int.MaxValue)), + ("CHECKSUM_AGG", FunctionDefinition.standard(1)), + ("COALESCE", FunctionDefinition.standard(1, Int.MaxValue)), + ("COL_LENGTH", FunctionDefinition.standard(2)), + ("COL_NAME", FunctionDefinition.standard(2)), + ("COLUMNPROPERTY", FunctionDefinition.standard(3)), + ("COMPRESS", FunctionDefinition.standard(1)), + ("CONCAT", FunctionDefinition.standard(2, Int.MaxValue)), + ("CONCAT_WS", FunctionDefinition.standard(3, Int.MaxValue)), + ("CONNECTIONPROPERTY", FunctionDefinition.notConvertible(1)), + ("CONTEXT_INFO", FunctionDefinition.standard(0)), + ("CONVERT", FunctionDefinition.standard(2, 3)), + ("COS", FunctionDefinition.standard(1)), + ("COT", FunctionDefinition.standard(1)), + ("COUNT", FunctionDefinition.standard(1)), + ("COUNT_BIG", FunctionDefinition.standard(1)), + ("CUME_DIST", FunctionDefinition.standard(0)), + ("CURRENT_DATE", FunctionDefinition.standard(0)), + ("CURRENT_REQUEST_ID", FunctionDefinition.standard(0)), + ("CURRENT_TIMESTAMP", FunctionDefinition.standard(0)), + ("CURRENT_TIMEZONE", FunctionDefinition.standard(0)), + ("CURRENT_TIMEZONE_ID", FunctionDefinition.standard(0)), + ("CURRENT_TRANSACTION_ID", FunctionDefinition.standard(0)), + ("CURRENT_USER", FunctionDefinition.standard(0)), + ("CURSOR_ROWS", FunctionDefinition.standard(0)), + ("CURSOR_STATUS", FunctionDefinition.standard(2)), + ("DATABASE_PRINCIPAL_ID", FunctionDefinition.standard(0, 1)), + ("DATABASEPROPERTY", FunctionDefinition.standard(2)), + ("DATABASEPROPERTYEX", FunctionDefinition.standard(2)), + ("DATALENGTH", FunctionDefinition.standard(1)), + ("DATE_BUCKET", FunctionDefinition.standard(3, 4)), + ("DATE_DIFF_BIG", FunctionDefinition.standard(3)), + ("DATEADD", FunctionDefinition.standard(3)), + ("DATEDIFF", FunctionDefinition.standard(3)), + ("DATEFROMPARTS", FunctionDefinition.standard(3)), + ("DATENAME", FunctionDefinition.standard(2)), + ("DATEPART", FunctionDefinition.standard(2)), + ("DATETIME2FROMPARTS", FunctionDefinition.standard(8)), + ("DATETIMEFROMPARTS", FunctionDefinition.standard(7)), + 
("DATETIMEOFFSETFROMPARTS", FunctionDefinition.standard(10)), + ("DATETRUNC", FunctionDefinition.standard(2)), + ("DAY", FunctionDefinition.standard(1)), + ("DB_ID", FunctionDefinition.standard(0, 1)), + ("DB_NAME", FunctionDefinition.standard(0, 1)), + ("DECOMPRESS", FunctionDefinition.standard(1)), + ("DEGREES", FunctionDefinition.standard(1)), + ("DENSE_RANK", FunctionDefinition.standard(0)), + ("DIFFERENCE", FunctionDefinition.standard(2)), + ("EOMONTH", FunctionDefinition.standard(1, 2)), + ("ERROR_LINE", FunctionDefinition.standard(0)), + ("ERROR_MESSAGE", FunctionDefinition.standard(0)), + ("ERROR_NUMBER", FunctionDefinition.standard(0)), + ("ERROR_PROCEDURE", FunctionDefinition.standard(0)), + ("ERROR_SEVERITY", FunctionDefinition.standard(0)), + ("ERROR_STATE", FunctionDefinition.standard(0)), + ("EXIST", FunctionDefinition.xml(1)), + ("EXP", FunctionDefinition.standard(1)), + ("FILE_ID", FunctionDefinition.standard(1)), + ("FILE_IDEX", FunctionDefinition.standard(1)), + ("FILE_NAME", FunctionDefinition.standard(1)), + ("FILEGROUP_ID", FunctionDefinition.standard(1)), + ("FILEGROUP_NAME", FunctionDefinition.standard(1)), + ("FILEGROUPPROPERTY", FunctionDefinition.standard(2)), + ("FILEPROPERTY", FunctionDefinition.standard(2)), + ("FILEPROPERTYEX", FunctionDefinition.standard(2)), + ("FIRST_VALUE", FunctionDefinition.standard(1)), + ("FLOOR", FunctionDefinition.standard(1)), + ("FORMAT", FunctionDefinition.standard(2, 3)), + ("FORMATMESSAGE", FunctionDefinition.standard(2, Int.MaxValue)), + ("FULLTEXTCATALOGPROPERTY", FunctionDefinition.standard(2)), + ("FULLTEXTSERVICEPROPERTY", FunctionDefinition.standard(1)), + ("GET_FILESTREAM_TRANSACTION_CONTEXT", FunctionDefinition.standard(0)), + ("GETANCESTGOR", FunctionDefinition.standard(1)), + ("GETANSINULL", FunctionDefinition.standard(0, 1)), + ("GETDATE", FunctionDefinition.standard(0)), + ("GETDESCENDANT", FunctionDefinition.standard(2)), + ("GETLEVEL", FunctionDefinition.standard(0)), + ("GETREPARENTEDVALUE", FunctionDefinition.standard(2)), + ("GETUTCDATE", FunctionDefinition.standard(0)), + ("GREATEST", FunctionDefinition.standard(1, Int.MaxValue)), + ("GROUPING", FunctionDefinition.standard(1)), + ("GROUPING_ID", FunctionDefinition.standard(0, Int.MaxValue)), + ("HAS_DBACCESS", FunctionDefinition.standard(1)), + ("HAS_PERMS_BY_NAME", FunctionDefinition.standard(4, 5)), + ("HOST_ID", FunctionDefinition.standard(0)), + ("HOST_NAME", FunctionDefinition.standard(0)), + ("IDENT_CURRENT", FunctionDefinition.standard(1)), + ("IDENT_INCR", FunctionDefinition.standard(1)), + ("IDENT_SEED", FunctionDefinition.standard(1)), + ("IDENTITY", FunctionDefinition.standard(1, 3)), + ("IFF", FunctionDefinition.standard(3)), + ("INDEX_COL", FunctionDefinition.standard(3)), + ("INDEXKEY_PROPERTY", FunctionDefinition.standard(3)), + ("INDEXPROPERTY", FunctionDefinition.standard(3)), + ("IS_MEMBER", FunctionDefinition.standard(1)), + ("IS_ROLEMEMBER", FunctionDefinition.standard(1, 2)), + ("IS_SRVROLEMEMBER", FunctionDefinition.standard(1, 2)), + ("ISDATE", FunctionDefinition.standard(1)), + ("ISDESCENDANTOF", FunctionDefinition.standard(1)), + ("ISJSON", FunctionDefinition.standard(1, 2)), + ("ISNUMERIC", FunctionDefinition.standard(1)), + ("JSON_MODIFY", FunctionDefinition.standard(3)), + ("JSON_PATH_EXISTS", FunctionDefinition.standard(2)), + ("JSON_QUERY", FunctionDefinition.standard(2)), + ("JSON_VALUE", FunctionDefinition.standard(2)), + ("LAG", FunctionDefinition.standard(1, 3)), + ("LAST_VALUE", FunctionDefinition.standard(1)), + ("LEAD", 
FunctionDefinition.standard(1, 3)), + ("LEAST", FunctionDefinition.standard(1, Int.MaxValue)), + ("LEFT", FunctionDefinition.standard(2)), + ("LEN", FunctionDefinition.standard(1)), + ("LOG", FunctionDefinition.standard(1, 2)), + ("LOG10", FunctionDefinition.standard(1)), + ("LOGINPROPERTY", FunctionDefinition.standard(2)), + ("LOWER", FunctionDefinition.standard(1)), + ("LTRIM", FunctionDefinition.standard(1)), + ("MAX", FunctionDefinition.standard(1)), + ("MIN", FunctionDefinition.standard(1)), + ("MIN_ACTIVE_ROWVERSION", FunctionDefinition.standard(0)), + ("MONTH", FunctionDefinition.standard(1)), + ("NCHAR", FunctionDefinition.standard(1)), + ("NEWID", FunctionDefinition.standard(0)), + ("NEWSEQUENTIALID", FunctionDefinition.standard(0)), + ("NODES", FunctionDefinition.xml(1)), + ("NTILE", FunctionDefinition.standard(1)), + ("NULLIF", FunctionDefinition.standard(2)), + ("OBJECT_DEFINITION", FunctionDefinition.standard(1)), + ("OBJECT_ID", FunctionDefinition.standard(1, 2)), + ("OBJECT_NAME", FunctionDefinition.standard(1, 2)), + ("OBJECT_SCHEMA_NAME", FunctionDefinition.standard(1, 2)), + ("OBJECTPROPERTY", FunctionDefinition.standard(2)), + ("OBJECTPROPERTYEX", FunctionDefinition.standard(2)), + ("ORIGINAL_DB_NAME", FunctionDefinition.standard(0)), + ("ORIGINAL_LOGIN", FunctionDefinition.standard(0)), + ("PARSE", FunctionDefinition.notConvertible(2, 3)), + ("PARSENAME", FunctionDefinition.standard(2)), + ("PATINDEX", FunctionDefinition.standard(2)), + ("PERCENTILE_CONT", FunctionDefinition.standard(1)), + ("PERCENTILE_DISC", FunctionDefinition.standard(1)), + ("PERMISSIONS", FunctionDefinition.notConvertible(0, 2)), + ("PI", FunctionDefinition.standard(0)), + ("POWER", FunctionDefinition.standard(2)), + ("PWDCOMPARE", FunctionDefinition.standard(2, 3)), + ("PWDENCRYPT", FunctionDefinition.standard(1)), + ("QUERY", FunctionDefinition.xml(1)), + ("QUOTENAME", FunctionDefinition.standard(1, 2)), + ("RADIANS", FunctionDefinition.standard(1)), + ("RAND", FunctionDefinition.standard(0, 1)), + ("RANK", FunctionDefinition.standard(0)), + ("REPLACE", FunctionDefinition.standard(3)), + ("REPLICATE", FunctionDefinition.standard(2)), + ("REVERSE", FunctionDefinition.standard(1)), + ("RIGHT", FunctionDefinition.standard(2)), + ("ROUND", FunctionDefinition.standard(1, 3)), + ("ROW_NUMBER", FunctionDefinition.standard(0)), + ("ROWCOUNT_BIG", FunctionDefinition.standard(0)), + ("RTRIM", FunctionDefinition.standard(1)), + ("SCHEMA_ID", FunctionDefinition.standard(0, 1)), + ("SCHEMA_NAME", FunctionDefinition.standard(0, 1)), + ("SCOPE_IDENTITY", FunctionDefinition.standard(0)), + ("SERVERPROPERTY", FunctionDefinition.standard(1)), + ("SESSION_CONTEXT", FunctionDefinition.standard(1, 2)), + ("SESSION_USER", FunctionDefinition.standard(0)), + ("SESSIONPROPERTY", FunctionDefinition.standard(1)), + ("SIGN", FunctionDefinition.standard(1)), + ("SIN", FunctionDefinition.standard(1)), + ("SMALLDATETIMEFROMPARTS", FunctionDefinition.standard(5)), + ("SOUNDEX", FunctionDefinition.standard(1)), + ("SPACE", FunctionDefinition.standard(1)), + ("SQL_VARIANT_PROPERTY", FunctionDefinition.standard(2)), + ("SQRT", FunctionDefinition.standard(1)), + ("SQUARE", FunctionDefinition.standard(1)), + ("STATS_DATE", FunctionDefinition.standard(2)), + ("STDEV", FunctionDefinition.standard(1)), + ("STDEVP", FunctionDefinition.standard(1)), + ("STR", FunctionDefinition.standard(1, 3)), + ("STRING_AGG", FunctionDefinition.standard(2, 3)), + ("STRING_ESCAPE", FunctionDefinition.standard(2)), + ("STUFF", 
FunctionDefinition.standard(4)), + ("SUBSTR", FunctionDefinition.standard(2, 3)), + ("SUBSTRING", FunctionDefinition.standard(2, 3)), + ("SUM", FunctionDefinition.standard(1)), + ("SUSER_ID", FunctionDefinition.standard(0, 1)), + ("SUSER_NAME", FunctionDefinition.standard(0, 1)), + ("SUSER_SID", FunctionDefinition.standard(0, 2)), + ("SUSER_SNAME", FunctionDefinition.standard(0, 1)), + ("SWITCHOFFSET", FunctionDefinition.standard(2)), + ("SYSDATETIME", FunctionDefinition.standard(0)), + ("SYSDATETIMEOFFSET", FunctionDefinition.standard(0)), + ("SYSTEM_USER", FunctionDefinition.standard(0)), + ("SYSUTCDATETIME", FunctionDefinition.standard(0)), + ("TAN", FunctionDefinition.standard(1)), + ("TIMEFROMPARTS", FunctionDefinition.standard(5)), + ("TODATETIMEOFFSET", FunctionDefinition.standard(2)), + ("TOSTRING", FunctionDefinition.standard(0)), + ("TRANSLATE", FunctionDefinition.standard(3)), + ("TRIM", FunctionDefinition.standard(1, 2)), + ("TYPE_ID", FunctionDefinition.standard(1)), + ("TYPE_NAME", FunctionDefinition.standard(1)), + ("TYPEPROPERTY", FunctionDefinition.standard(2)), + ("UNICODE", FunctionDefinition.standard(1)), + ("UPPER", FunctionDefinition.standard(1)), + ("USER", FunctionDefinition.standard(0)), + ("USER_ID", FunctionDefinition.standard(0, 1)), + ("USER_NAME", FunctionDefinition.standard(0, 1)), + ("VALUE", FunctionDefinition.xml(2)), + ("VAR", FunctionDefinition.standard(1)), + ("VARP", FunctionDefinition.standard(1)), + ("XACT_STATE", FunctionDefinition.standard(0)), + ("YEAR", FunctionDefinition.standard(1))) + + val functionBuilder = new TSqlFunctionBuilder + + forAll(functions) { (functionName: String, expectedArity: FunctionDefinition) => + { + val actual = functionBuilder.functionDefinition(functionName) + actual.nonEmpty shouldEqual true + actual.map(_.arity) shouldEqual Some(expectedArity.arity) + } + } + } + + "functionType" should "return correct function type for each TSQL function" in { + val functions = Table( + ("functionName", "expectedFunctionType"), // Header + + // This table needs to maintain an example of all types of functions in Arity and FunctionType + // However, it is not necessary to test all functions, just one of each type. + // However, note that there are no XML functions with VariableArity at the moment. 
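+      // For instance, MODIFY and VALUE exercise XmlFunction, while ABS, CONCAT and TRIM exercise StandardFunction;
+      // UnknownFunction is exercised separately by the "does not exist" test below.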
+ ("MODIFY", XmlFunction), + ("ABS", StandardFunction), + ("VALUE", XmlFunction), + ("CONCAT", StandardFunction), + ("TRIM", StandardFunction)) + + // Test all FunctionType combinations that currently exist + val functionBuilder = new TSqlFunctionBuilder + + forAll(functions) { (functionName: String, expectedFunctionType: FunctionType) => + { + functionBuilder.functionType(functionName) shouldEqual expectedFunctionType + } + } + } + + "functionType" should "return UnknownFunction for tsql functions that do not exist" in { + val functionBuilder = new TSqlFunctionBuilder + val result = functionBuilder.functionType("DOES_NOT_EXIST") + result shouldEqual UnknownFunction + } + + "isConvertible method in FunctionArity" should "return true by default" in { + val fixedArity = FunctionDefinition.standard(1) + fixedArity.functionType should not be NotConvertibleFunction + + val variableArity = FunctionDefinition.standard(1, 2) + variableArity should not be NotConvertibleFunction + + } + + "buildFunction" should "remove quotes and brackets from function names" in { + val functionBuilder = new TSqlFunctionBuilder + + val quoteTable = Table( + ("functionName", "expectedFunctionName"), // Header + + ("a", "a"), // Test function name with less than 2 characters + ("'quoted'", "quoted"), // Test function name with matching quotes + ("[bracketed]", "bracketed"), // Test function name with matching brackets + ("\\backslashed\\", "backslashed"), // Test function name with matching backslashes + ("\"doublequoted\"", "doublequoted") // Test function name with non-matching quotes + ) + forAll(quoteTable) { (functionName: String, expectedFunctionName: String) => + { + val r = functionBuilder.buildFunction(functionName, List.empty) + r match { + case f: ir.UnresolvedFunction => f.function_name shouldBe expectedFunctionName + case _ => fail("Unexpected function type") + } + } + } + } + + "buildFunction" should "Apply known TSql conversion strategies" in { + val functionBuilder = new TSqlFunctionBuilder + + val renameTable = Table( + ("functionName", "params", "expectedFunctionName"), // Header + + ("ISNULL", Seq(simplyNamedColumn("x"), ir.Literal(1)), "IFNULL"), + ("GET_BIT", Seq(simplyNamedColumn("x"), ir.Literal(1)), "GETBIT"), + ("left_SHIFT", Seq(simplyNamedColumn("x"), ir.Literal(1)), "LEFTSHIFT"), + ("RIGHT_SHIFT", Seq(simplyNamedColumn("x"), ir.Literal(1)), "RIGHTSHIFT")) + + forAll(renameTable) { (functionName: String, params: Seq[ir.Expression], expectedFunctionName: String) => + { + val r = functionBuilder.buildFunction(functionName, params) + r match { + case f: ir.CallFunction => f.function_name shouldBe expectedFunctionName + case _ => fail("Unexpected function type") + } + } + } + } + + "buildFunction" should "not resolve IFNULL when child dialect isn't TSql" in { + val functionBuilder = new SnowflakeFunctionBuilder + + val result1 = functionBuilder.buildFunction("ISNULL", Seq(simplyNamedColumn("x"), ir.Literal(0))) + result1 shouldBe a[ir.UnresolvedFunction] + } + + "buildFunction" should "not resolve IFNULL when child dialect isn't Snowflake" in { + val functionBuilder = new TSqlFunctionBuilder + + val result1 = functionBuilder.buildFunction("IFNULL", Seq(simplyNamedColumn("x"), ir.Literal(0))) + result1 shouldBe a[ir.UnresolvedFunction] + } + + "buildFunction" should "Should preserve case if it can" in { + val functionBuilder = new TSqlFunctionBuilder + val result1 = functionBuilder.buildFunction("isnull", Seq(simplyNamedColumn("x"), ir.Literal(0))) + result1 match { + case f: ir.CallFunction => 
f.function_name shouldBe "ifnull" + case _ => fail("ifnull conversion failed") + } + } + + "FunctionRename strategy" should "preserve original function if no match is found" in { + val functionBuilder = new TSqlFunctionBuilder + val result1 = functionBuilder.applyConversionStrategy(FunctionDefinition.standard(1), Seq(ir.Literal(66)), "Abs") + result1 match { + case f: ir.CallFunction => f.function_name shouldBe "Abs" + case _ => fail("UNKNOWN_FUNCTION conversion failed") + } + } + + "FunctionArity.verifyArguments" should "return true when arity is fixed and provided number of arguments matches" in { + FunctionArity.verifyArguments(FixedArity(0), Seq()) shouldBe true + FunctionArity.verifyArguments(FixedArity(1), Seq(ir.Noop)) shouldBe true + FunctionArity.verifyArguments(FixedArity(2), Seq(ir.Noop, ir.Noop)) shouldBe true + } + + "FunctionArity.verifyArguments" should + "return true when arity is varying and provided number of arguments matches" in { + FunctionArity.verifyArguments(VariableArity(0, 2), Seq()) shouldBe true + FunctionArity.verifyArguments(VariableArity(0, 2), Seq(ir.Noop)) shouldBe true + FunctionArity.verifyArguments(VariableArity(0, 2), Seq(ir.Noop, ir.Noop)) shouldBe true + } + + "FunctionArity.verifyArguments" should "return true when arity is symbolic and arguments are provided named" in { + val arity = SymbolicArity(Set("req1", "REQ2"), Set("opt1", "opt2", "opt3")) + FunctionArity.verifyArguments( + arity, + Seq(NamedArgumentExpression("Req2", ir.Noop), snowflake.NamedArgumentExpression("REQ1", ir.Noop))) shouldBe true + + FunctionArity.verifyArguments( + arity, + Seq( + snowflake.NamedArgumentExpression("Req2", ir.Noop), + snowflake.NamedArgumentExpression("OPT1", ir.Noop), + snowflake.NamedArgumentExpression("REQ1", ir.Noop))) shouldBe true + + FunctionArity.verifyArguments( + arity, + Seq( + snowflake.NamedArgumentExpression("Req2", ir.Noop), + snowflake.NamedArgumentExpression("OPT1", ir.Noop), + snowflake.NamedArgumentExpression("OPT3", ir.Noop), + snowflake.NamedArgumentExpression("OPT2", ir.Noop), + snowflake.NamedArgumentExpression("REQ1", ir.Noop))) shouldBe true + } + + "FunctionArity.verifyArguments" should "return true when arity is symbolic and arguments are provided unnamed" in { + val arity = SymbolicArity(Set("req1", "REQ2"), Set("opt1", "opt2", "opt3")) + + FunctionArity.verifyArguments(arity, Seq( /*REQ1*/ ir.Noop, /*REQ2*/ ir.Noop)) shouldBe true + FunctionArity.verifyArguments(arity, Seq( /*REQ1*/ ir.Noop, /*REQ2*/ ir.Noop, /*OPT1*/ ir.Noop)) shouldBe true + FunctionArity.verifyArguments( + arity, + Seq( /*REQ1*/ ir.Noop, /*REQ2*/ ir.Noop, /*OPT1*/ ir.Noop, /*OPT2*/ ir.Noop)) shouldBe true + FunctionArity.verifyArguments( + arity, + Seq( /*REQ1*/ ir.Noop, /*REQ2*/ ir.Noop, /*OPT1*/ ir.Noop, /*OPT2*/ ir.Noop, /*OPT3*/ ir.Noop)) shouldBe true + } + + "FunctionArity.verifyArguments" should "return false otherwise" in { + // not enough arguments + FunctionArity.verifyArguments(FixedArity(1), Seq.empty) shouldBe false + FunctionArity.verifyArguments(VariableArity(2, 3), Seq(ir.Noop)) shouldBe false + FunctionArity.verifyArguments( + SymbolicArity(Set("req1", "req2"), Set.empty), + Seq(snowflake.NamedArgumentExpression("REQ2", ir.Noop))) shouldBe false + FunctionArity.verifyArguments(SymbolicArity(Set("req1", "req2"), Set.empty), Seq(ir.Noop)) shouldBe false + + // too many arguments + FunctionArity.verifyArguments(FixedArity(0), Seq(ir.Noop)) shouldBe false + FunctionArity.verifyArguments(VariableArity(0, 1), Seq(ir.Noop, ir.Noop)) shouldBe false + 
FunctionArity.verifyArguments( + SymbolicArity(Set("req1", "req2"), Set.empty), + Seq(ir.Noop, ir.Noop, ir.Noop)) shouldBe false + + // wrongly named arguments + FunctionArity.verifyArguments( + SymbolicArity(Set("req1"), Set("opt1")), + Seq( + snowflake.NamedArgumentExpression("REQ2", ir.Noop), + snowflake.NamedArgumentExpression("REQ1", ir.Noop))) shouldBe false + + // mix of named and unnamed arguments + FunctionArity.verifyArguments( + SymbolicArity(Set("REQ"), Set("OPT")), + Seq(ir.Noop, snowflake.NamedArgumentExpression("OPT", ir.Noop))) shouldBe false + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowLexerSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowLexerSpec.scala new file mode 100644 index 0000000000..7e9f66b400 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowLexerSpec.scala @@ -0,0 +1,344 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import org.antlr.v4.runtime.{CharStreams, Token} +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.wordspec.AnyWordSpec + +class SnowLexerSpec extends AnyWordSpec with Matchers with TableDrivenPropertyChecks { + + private[this] val lexer = new SnowflakeLexer(null) + + private def fillTokens(input: String): List[Token] = { + val inputString = CharStreams.fromString(input) + lexer.setInputStream(inputString) + Iterator.continually(lexer.nextToken()).takeWhile(_.getType != Token.EOF).toList + } + + private def dumpTokens(tokens: List[Token]): Unit = { + tokens.foreach { t => + // scalastyle:off println + val name = lexer.getVocabulary.getDisplayName(t.getType).padTo(32, ' ') + println(s"${name}(${t.getType}) -->${t.getText}<--") + // scalastyle:on println + } + } + + // TODO: Expand this test to cover all token types, and maybe all tokens + "Snowflake Lexer" should { + "scan string literals and ids" in { + + val testInput = Table( + ("child", "expected"), // Headers + + (""""quoted""id"""", SnowflakeLexer.DOUBLE_QUOTE_ID), + ("\"quote\"\"andunquote\"\"\"", SnowflakeLexer.DOUBLE_QUOTE_ID), + ("identifier_1", SnowflakeLexer.ID), + ("SELECT", SnowflakeLexer.SELECT), + ("FROM", SnowflakeLexer.FROM), + ("WHERE", SnowflakeLexer.WHERE), + ("GROUP", SnowflakeLexer.GROUP), + ("BY", SnowflakeLexer.BY), + ("HAVING", SnowflakeLexer.HAVING), + ("ORDER", SnowflakeLexer.ORDER), + ("LIMIT", SnowflakeLexer.LIMIT), + ("UNION", SnowflakeLexer.UNION), + ("ALL", SnowflakeLexer.ALL), + ("INTERSECT", SnowflakeLexer.INTERSECT), + ("INSERT", SnowflakeLexer.INSERT), + ("EXCEPT", SnowflakeLexer.EXCEPT), + ("-", SnowflakeLexer.MINUS), + ("+", SnowflakeLexer.PLUS), + ("42", SnowflakeLexer.INT), + ("42.42", SnowflakeLexer.FLOAT), + ("42E4", SnowflakeLexer.REAL), + ("!=", SnowflakeLexer.NE), + ("*", SnowflakeLexer.STAR), + ("/", SnowflakeLexer.DIVIDE), + ("!ABORT", SnowflakeLexer.SQLCOMMAND), + ("$parameter", SnowflakeLexer.LOCAL_ID), + ("$ids", SnowflakeLexer.LOCAL_ID), + ("=", SnowflakeLexer.EQ), + ("!=", SnowflakeLexer.NE), + (">", SnowflakeLexer.GT), + ("<", SnowflakeLexer.LT), + (">=", SnowflakeLexer.GE), + ("<=", SnowflakeLexer.LE), + ("(", SnowflakeLexer.LPAREN), + (")", SnowflakeLexer.RPAREN), + (",", SnowflakeLexer.COMMA), + (";", SnowflakeLexer.SEMI), + ("[", SnowflakeLexer.LSB), + ("]", SnowflakeLexer.RSB), + (":", SnowflakeLexer.COLON), + ("::", SnowflakeLexer.COLON_COLON), + ("_!Jinja0001", SnowflakeLexer.JINJA_REF)) + + forAll(testInput) { (input: String, 
expectedType: Int) => + val inputString = CharStreams.fromString(input) + + lexer.setInputStream(inputString) + val tok: Token = lexer.nextToken() + tok.getType shouldBe expectedType + tok.getText shouldBe input + } + } + + "scan string literals with escaped solidus" in { + val tok = fillTokens("""'\\'""") + dumpTokens(tok) + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_ESCAPE + tok(1).getText shouldBe """\\""" + + tok(2).getType shouldBe SnowflakeLexer.STRING_END + tok(2).getText shouldBe "'" + + } + + "scan string literals with double single quotes" in { + val tok = fillTokens("'And it''s raining'") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(1).getText shouldBe "And it" + + tok(2).getType shouldBe SnowflakeLexer.STRING_SQUOTE + tok(2).getText shouldBe "''" + + tok(3).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(3).getText shouldBe "s raining" + + tok(4).getType shouldBe SnowflakeLexer.STRING_END + tok(4).getText shouldBe "'" + } + + "scan string literals with an escaped character" in { + val tok = fillTokens("""'Tab\oir'""") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(1).getText shouldBe "Tab" + + tok(2).getType shouldBe SnowflakeLexer.STRING_ESCAPE + tok(2).getText shouldBe "\\o" + + tok(3).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(3).getText shouldBe "ir" + + tok(4).getType shouldBe SnowflakeLexer.STRING_END + tok(4).getText shouldBe "'" + } + + "scan string literals with an escaped single quote" in { + val tok = fillTokens("""'Tab\'oir'""") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(1).getText shouldBe "Tab" + + tok(2).getType shouldBe SnowflakeLexer.STRING_ESCAPE + tok(2).getText shouldBe "\\'" + + tok(3).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(3).getText shouldBe "oir" + + tok(4).getType shouldBe SnowflakeLexer.STRING_END + tok(4).getText shouldBe "'" + } + + "scan string literals with an embedded Unicode escape" in { + val tok = fillTokens("'Tab\\" + "uAcDcbaT'") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(1).getText shouldBe "Tab" + + tok(2).getType shouldBe SnowflakeLexer.STRING_UNICODE + tok(2).getText shouldBe "\\uAcDc" + + tok(3).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(3).getText shouldBe "baT" + + tok(4).getType shouldBe SnowflakeLexer.STRING_END + tok(4).getText shouldBe "'" + } + + "scan simple &variables" in { + val tok = fillTokens("&leeds") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.AMP + tok.head.getText shouldBe "&" + + tok(1).getType shouldBe SnowflakeLexer.ID + tok(1).getText shouldBe "leeds" + } + + "scan simple consecutive &variables" in { + val tok = fillTokens("&leeds&manchester") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.AMP + tok.head.getText shouldBe "&" + + tok(1).getType shouldBe SnowflakeLexer.ID + tok(1).getText shouldBe "leeds" + + tok(2).getType shouldBe SnowflakeLexer.AMP + tok(2).getText shouldBe "&" + + tok(3).getType shouldBe SnowflakeLexer.ID + tok(3).getText shouldBe 
"manchester" + } + + "scan simple &variables within composite variables" in { + lexer.setInputStream(CharStreams.fromString("&leeds.&manchester")) + val tok = fillTokens("&leeds.&manchester") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.AMP + tok.head.getText shouldBe "&" + + tok(1).getType shouldBe SnowflakeLexer.ID + tok(1).getText shouldBe "leeds" + + tok(2).getType shouldBe SnowflakeLexer.DOT + tok(2).getText shouldBe "." + + tok(3).getType shouldBe SnowflakeLexer.AMP + tok(3).getText shouldBe "&" + + tok(4).getType shouldBe SnowflakeLexer.ID + tok(4).getText shouldBe "manchester" + } + + "scan && in a string" in { + val tok = fillTokens("'&¬AVar'") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_AMPAMP + tok(1).getText shouldBe "&&" + + tok(2).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(2).getText shouldBe "notAVar" + + tok(3).getType shouldBe SnowflakeLexer.STRING_END + tok(3).getText shouldBe "'" + } + + "scan && in a string with {}" in { + val tok = fillTokens("'&&{notAVar}'") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_AMPAMP + tok(1).getText shouldBe "&&" + + tok(2).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(2).getText shouldBe "{notAVar}" + + tok(3).getType shouldBe SnowflakeLexer.STRING_END + tok(3).getText shouldBe "'" + } + "scan &variables in a string" in { + val tok = fillTokens("'&leeds'") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.VAR_SIMPLE + tok(1).getText shouldBe "&leeds" + + tok(2).getType shouldBe SnowflakeLexer.STRING_END + tok(2).getText shouldBe "'" + } + + "scan consecutive &variables in a string" in { + val tok = fillTokens("'&leeds&{united}'") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.VAR_SIMPLE + tok(1).getText shouldBe "&leeds" + + tok(2).getType shouldBe SnowflakeLexer.VAR_COMPLEX + tok(2).getText shouldBe "&{united}" + + tok(3).getType shouldBe SnowflakeLexer.STRING_END + tok(3).getText shouldBe "'" + } + + "scan &variables separated by && in a string" in { + val tok = fillTokens("'&leeds&&&united'") + dumpTokens(tok) + + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.VAR_SIMPLE + tok(1).getText shouldBe "&leeds" + + tok(2).getType shouldBe SnowflakeLexer.STRING_AMPAMP + tok(2).getText shouldBe "&&" + + tok(3).getType shouldBe SnowflakeLexer.VAR_SIMPLE + tok(3).getText shouldBe "&united" + + tok(4).getType shouldBe SnowflakeLexer.STRING_END + tok(4).getText shouldBe "'" + } + + "scan a single ampersand in a string" in { + val tok = fillTokens("'&'") + dumpTokens(tok) + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(1).getText shouldBe "&" + + tok(2).getType shouldBe SnowflakeLexer.STRING_END + tok(2).getText shouldBe "'" + + } + + "scan a trailing ampersand in a string" in { + val tok = fillTokens("'&score&'") + dumpTokens(tok) + tok.head.getType shouldBe SnowflakeLexer.STRING_START + tok.head.getText shouldBe "'" + + tok(1).getType shouldBe SnowflakeLexer.VAR_SIMPLE + tok(1).getText shouldBe "&score" + + 
tok(2).getType shouldBe SnowflakeLexer.STRING_CONTENT + tok(2).getText shouldBe "&" + + tok(3).getType shouldBe SnowflakeLexer.STRING_END + tok(3).getText shouldBe "'" + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilderSpec.scala new file mode 100644 index 0000000000..f4f6e31b3c --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilderSpec.scala @@ -0,0 +1,668 @@ +package com.databricks.labs.remorph.parsers +package snowflake + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeAstBuilderSpec + extends AnyWordSpec + with SnowflakeParserTestCommon + with SetOperationBehaviors[SnowflakeParser] + with Matchers + with IRHelpers { + + override protected def astBuilder: SnowflakeAstBuilder = vc.astBuilder + + private def singleQueryExample(query: String, expectedAst: LogicalPlan): Unit = + example(query, _.snowflakeFile(), Batch(Seq(expectedAst))) + + "SnowflakeAstBuilder" should { + "translate a simple SELECT query" in { + singleQueryExample( + query = "SELECT a FROM TABLE", + expectedAst = Project(NamedTable("TABLE", Map.empty), Seq(Id("a")))) + } + + "translate a simple SELECT query with an aliased column" in { + singleQueryExample( + query = "SELECT a AS aa FROM b", + expectedAst = Project(NamedTable("b", Map.empty), Seq(Alias(Id("a"), Id("aa"))))) + } + + "translate a simple SELECT query involving multiple columns" in { + singleQueryExample( + query = "SELECT a, b, c FROM table_x", + expectedAst = Project(NamedTable("table_x", Map.empty), Seq(Id("a"), Id("b"), Id("c")))) + } + + "translate a SELECT query involving multiple columns and aliases" in { + singleQueryExample( + query = "SELECT a, b AS bb, c FROM table_x", + expectedAst = Project(NamedTable("table_x", Map.empty), Seq(Id("a"), Alias(Id("b"), Id("bb")), Id("c")))) + } + + "translate a SELECT query involving a table alias" in { + singleQueryExample( + query = "SELECT t.a FROM table_x t", + expectedAst = Project(TableAlias(NamedTable("table_x"), "t"), Seq(Dot(Id("t"), Id("a"))))) + } + + "translate a SELECT query involving a column alias and a table alias" in { + singleQueryExample( + query = "SELECT t.a, t.b as b FROM table_x t", + expectedAst = Project( + TableAlias(NamedTable("table_x"), "t"), + Seq(Dot(Id("t"), Id("a")), Alias(Dot(Id("t"), Id("b")), Id("b"))))) + } + + val simpleJoinAst = + Join( + NamedTable("table_x", Map.empty), + NamedTable("table_y", Map.empty), + join_condition = None, + UnspecifiedJoin, + using_columns = Seq(), + JoinDataType(is_left_struct = false, is_right_struct = false)) + + "translate a query with a JOIN" in { + singleQueryExample( + query = "SELECT a FROM table_x JOIN table_y", + expectedAst = Project(simpleJoinAst, Seq(Id("a")))) + } + + "translate a query with a INNER JOIN" in { + singleQueryExample( + query = "SELECT a FROM table_x INNER JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = InnerJoin), Seq(Id("a")))) + } + + "translate a query with a CROSS JOIN" in { + singleQueryExample( + query = "SELECT a FROM table_x CROSS JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = CrossJoin), Seq(Id("a")))) + } + + "translate a query with a LEFT JOIN" in { + singleQueryExample( + query = "SELECT a FROM 
table_x LEFT JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = LeftOuterJoin), Seq(Id("a")))) + } + + "translate a query with a LEFT OUTER JOIN" in { + singleQueryExample( + query = "SELECT a FROM table_x LEFT OUTER JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = LeftOuterJoin), Seq(Id("a")))) + } + + "translate a query with a RIGHT JOIN" in { + singleQueryExample( + query = "SELECT a FROM table_x RIGHT JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = RightOuterJoin), Seq(Id("a")))) + } + + "translate a query with a RIGHT OUTER JOIN" in { + singleQueryExample( + query = "SELECT a FROM table_x RIGHT OUTER JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = RightOuterJoin), Seq(Id("a")))) + } + + "translate a query with a FULL JOIN" in { + singleQueryExample( + query = "SELECT a FROM table_x FULL JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = FullOuterJoin), Seq(Id("a")))) + } + + "translate a query with a NATURAL JOIN" should { + "SELECT a FROM table_x NATURAL JOIN table_y" in { + singleQueryExample( + query = "SELECT a FROM table_x NATURAL JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = NaturalJoin(UnspecifiedJoin)), Seq(Id("a")))) + } + "SELECT a FROM table_x NATURAL LEFT JOIN table_y" in { + singleQueryExample( + query = "SELECT a FROM table_x NATURAL LEFT JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = NaturalJoin(LeftOuterJoin)), Seq(Id("a")))) + } + "SELECT a FROM table_x NATURAL RIGHT JOIN table_y" in { + singleQueryExample( + query = "SELECT a FROM table_x NATURAL RIGHT JOIN table_y", + expectedAst = Project(simpleJoinAst.copy(join_type = NaturalJoin(RightOuterJoin)), Seq(Id("a")))) + } + } + + "translate a query with a simple WHERE clause" in { + val expectedOperatorTranslations = List( + "=" -> Equals(Id("a"), Id("b")), + "!=" -> NotEquals(Id("a"), Id("b")), + "<>" -> NotEquals(Id("a"), Id("b")), + ">" -> GreaterThan(Id("a"), Id("b")), + "<" -> LessThan(Id("a"), Id("b")), + ">=" -> GreaterThanOrEqual(Id("a"), Id("b")), + "<=" -> LessThanOrEqual(Id("a"), Id("b"))) + + expectedOperatorTranslations.foreach { case (op, expectedPredicate) => + singleQueryExample( + query = s"SELECT a, b FROM c WHERE a $op b", + expectedAst = Project(Filter(NamedTable("c", Map.empty), expectedPredicate), Seq(Id("a"), Id("b")))) + } + } + + "translate a query with a WHERE clause involving composite predicates" should { + "SELECT a, b FROM c WHERE a = b AND b = a" in { + singleQueryExample( + query = "SELECT a, b FROM c WHERE a = b AND b = a", + expectedAst = Project( + Filter(NamedTable("c", Map.empty), And(Equals(Id("a"), Id("b")), Equals(Id("b"), Id("a")))), + Seq(Id("a"), Id("b")))) + } + "SELECT a, b FROM c WHERE a = b OR b = a" in { + singleQueryExample( + query = "SELECT a, b FROM c WHERE a = b OR b = a", + expectedAst = Project( + Filter(NamedTable("c", Map.empty), Or(Equals(Id("a"), Id("b")), Equals(Id("b"), Id("a")))), + Seq(Id("a"), Id("b")))) + } + "SELECT a, b FROM c WHERE NOT a = b" in { + singleQueryExample( + query = "SELECT a, b FROM c WHERE NOT a = b", + expectedAst = + Project(Filter(NamedTable("c", Map.empty), Not(Equals(Id("a"), Id("b")))), Seq(Id("a"), Id("b")))) + } + } + + "translate a query with a GROUP BY clause" in { + singleQueryExample( + query = "SELECT a, COUNT(b) FROM c GROUP BY a", + expectedAst = Project( + Aggregate( + child = NamedTable("c", Map.empty), + group_type = GroupBy, + grouping_expressions = 
Seq(simplyNamedColumn("a")), + pivot = None), + Seq(Id("a"), CallFunction("COUNT", Seq(Id("b")))))) + } + + "translate a query with a GROUP BY and ORDER BY clauses" in { + singleQueryExample( + query = "SELECT a, COUNT(b) FROM c GROUP BY a ORDER BY a", + expectedAst = Project( + Sort( + Aggregate( + child = NamedTable("c", Map.empty), + group_type = GroupBy, + grouping_expressions = Seq(simplyNamedColumn("a")), + pivot = None), + Seq(SortOrder(Id("a"), Ascending, NullsLast))), + Seq(Id("a"), CallFunction("COUNT", Seq(Id("b")))))) + } + + "translate a query with GROUP BY HAVING clause" in { + singleQueryExample( + query = "SELECT a, COUNT(b) FROM c GROUP BY a HAVING COUNT(b) > 1", + expectedAst = Project( + Filter( + Aggregate( + child = NamedTable("c", Map.empty), + group_type = GroupBy, + grouping_expressions = Seq(simplyNamedColumn("a")), + pivot = None), + GreaterThan(CallFunction("COUNT", Seq(Id("b"))), Literal(1))), + Seq(Id("a"), CallFunction("COUNT", Seq(Id("b")))))) + } + + "translate a query with ORDER BY" should { + "SELECT a FROM b ORDER BY a" in { + singleQueryExample( + query = "SELECT a FROM b ORDER BY a", + expectedAst = + Project(Sort(NamedTable("b", Map.empty), Seq(SortOrder(Id("a"), Ascending, NullsLast))), Seq(Id("a")))) + } + "SELECT a FROM b ORDER BY a DESC" in { + singleQueryExample( + "SELECT a FROM b ORDER BY a DESC", + Project(Sort(NamedTable("b", Map.empty), Seq(SortOrder(Id("a"), Descending, NullsFirst))), Seq(Id("a")))) + } + "SELECT a FROM b ORDER BY a NULLS FIRST" in { + singleQueryExample( + query = "SELECT a FROM b ORDER BY a NULLS FIRST", + expectedAst = + Project(Sort(NamedTable("b", Map.empty), Seq(SortOrder(Id("a"), Ascending, NullsFirst))), Seq(Id("a")))) + } + "SELECT a FROM b ORDER BY a DESC NULLS LAST" in { + singleQueryExample( + query = "SELECT a FROM b ORDER BY a DESC NULLS LAST", + expectedAst = + Project(Sort(NamedTable("b", Map.empty), Seq(SortOrder(Id("a"), Descending, NullsLast))), Seq(Id("a")))) + } + } + + "translate queries with LIMIT and OFFSET" should { + "SELECT a FROM b LIMIT 5" in { + singleQueryExample( + query = "SELECT a FROM b LIMIT 5", + expectedAst = Project(Limit(NamedTable("b", Map.empty), Literal(5)), Seq(Id("a")))) + } + "SELECT a FROM b LIMIT 5 OFFSET 10" in { + singleQueryExample( + query = "SELECT a FROM b LIMIT 5 OFFSET 10", + expectedAst = Project(Offset(Limit(NamedTable("b", Map.empty), Literal(5)), Literal(10)), Seq(Id("a")))) + } + "SELECT a FROM b OFFSET 10 FETCH FIRST 42" in { + singleQueryExample( + query = "SELECT a FROM b OFFSET 10 FETCH FIRST 42", + expectedAst = Project(Offset(NamedTable("b", Map.empty), Literal(10)), Seq(Id("a")))) + } + } + + "translate a query with PIVOT" in { + singleQueryExample( + query = "SELECT a FROM b PIVOT (SUM(a) FOR c IN ('foo', 'bar'))", + expectedAst = Project( + Aggregate( + child = NamedTable("b", Map.empty), + group_type = Pivot, + grouping_expressions = Seq(CallFunction("SUM", Seq(simplyNamedColumn("a")))), + pivot = Some(Pivot(simplyNamedColumn("c"), Seq(Literal("foo"), Literal("bar"))))), + Seq(Id("a")))) + } + + "translate a query with UNPIVOT" in { + singleQueryExample( + query = "SELECT a FROM b UNPIVOT (c FOR d IN (e, f))", + expectedAst = Project( + Unpivot( + child = NamedTable("b", Map.empty), + ids = Seq(simplyNamedColumn("e"), simplyNamedColumn("f")), + values = None, + variable_column_name = Id("c"), + value_column_name = Id("d")), + Seq(Id("a")))) + } + + "translate queries with WITH clauses" should { + "WITH a (b, c, d) AS (SELECT x, y, z FROM e) SELECT b, c, d 
FROM a" in { + singleQueryExample( + query = "WITH a (b, c, d) AS (SELECT x, y, z FROM e) SELECT b, c, d FROM a", + expectedAst = WithCTE( + Seq( + SubqueryAlias( + Project(namedTable("e"), Seq(Id("x"), Id("y"), Id("z"))), + Id("a"), + Seq(Id("b"), Id("c"), Id("d")))), + Project(namedTable("a"), Seq(Id("b"), Id("c"), Id("d"))))) + } + "WITH a (b, c, d) AS (SELECT x, y, z FROM e), aa (bb, cc) AS (SELECT xx, yy FROM f) SELECT b, c, d FROM a" in { + singleQueryExample( + query = + "WITH a (b, c, d) AS (SELECT x, y, z FROM e), aa (bb, cc) AS (SELECT xx, yy FROM f) SELECT b, c, d FROM a", + expectedAst = WithCTE( + Seq( + SubqueryAlias( + Project(namedTable("e"), Seq(Id("x"), Id("y"), Id("z"))), + Id("a"), + Seq(Id("b"), Id("c"), Id("d"))), + SubqueryAlias(Project(namedTable("f"), Seq(Id("xx"), Id("yy"))), Id("aa"), Seq(Id("bb"), Id("cc")))), + Project(namedTable("a"), Seq(Id("b"), Id("c"), Id("d"))))) + } + "WITH a (b, c, d) AS (SELECT x, y, z FROM e) SELECT x, y, z FROM e UNION SELECT b, c, d FROM a" in { + singleQueryExample( + query = "WITH a (b, c, d) AS (SELECT x, y, z FROM e) SELECT x, y, z FROM e UNION SELECT b, c, d FROM a", + expectedAst = WithCTE( + Seq( + SubqueryAlias( + Project(namedTable("e"), Seq(Id("x"), Id("y"), Id("z"))), + Id("a"), + Seq(Id("b"), Id("c"), Id("d")))), + SetOperation( + Project(namedTable("e"), Seq(Id("x"), Id("y"), Id("z"))), + Project(namedTable("a"), Seq(Id("b"), Id("c"), Id("d"))), + UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false))) + } + } + + "translate a query with WHERE, GROUP BY, HAVING, QUALIFY" in { + singleQueryExample( + query = """SELECT c2, SUM(c3) OVER (PARTITION BY c2) as r + | FROM t1 + | WHERE c3 < 4 + | GROUP BY c2, c3 + | HAVING AVG(c1) >= 5 + | QUALIFY MIN(r) > 6""".stripMargin, + expectedAst = Project( + Filter( + Filter( + Aggregate( + child = Filter(namedTable("t1"), LessThan(Id("c3"), Literal(4))), + group_type = GroupBy, + grouping_expressions = Seq(simplyNamedColumn("c2"), simplyNamedColumn("c3")), + pivot = None), + GreaterThanOrEqual(CallFunction("AVG", Seq(Id("c1"))), Literal(5))), + GreaterThan(CallFunction("MIN", Seq(Id("r"))), Literal(6))), + Seq(Id("c2"), Alias(Window(CallFunction("SUM", Seq(Id("c3"))), Seq(Id("c2")), Seq(), None), Id("r"))))) + } + + behave like setOperationsAreTranslated(_.queryExpression()) + + "translate Snowflake-specific set operators" should { + "SELECT a FROM t1 MINUS SELECT b FROM t2" in { + singleQueryExample( + "SELECT a FROM t1 MINUS SELECT b FROM t2", + SetOperation( + Project(namedTable("t1"), Seq(Id("a"))), + Project(namedTable("t2"), Seq(Id("b"))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + } + // Part of checking that UNION, EXCEPT and MINUS are processed with the same precedence: left-to-right + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 MINUS SELECT 4" should { + singleQueryExample( + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 MINUS SELECT 4", + SetOperation( + SetOperation( + SetOperation( + Project(NoTable, Seq(Literal(1, IntegerType))), + Project(NoTable, Seq(Literal(2, IntegerType))), + UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(3, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(4, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + } + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 MINUS SELECT 4" should { + 
singleQueryExample( + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 MINUS SELECT 4", + SetOperation( + SetOperation( + SetOperation( + Project(NoTable, Seq(Literal(1, IntegerType))), + Project(NoTable, Seq(Literal(2, IntegerType))), + UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(3, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(4, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + } + "SELECT 1 EXCEPT SELECT 2 MINUS SELECT 3 UNION SELECT 4" should { + singleQueryExample( + "SELECT 1 EXCEPT SELECT 2 MINUS SELECT 3 UNION SELECT 4", + SetOperation( + SetOperation( + SetOperation( + Project(NoTable, Seq(Literal(1, IntegerType))), + Project(NoTable, Seq(Literal(2, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(3, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(4, IntegerType))), + UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + } + "SELECT 1 MINUS SELECT 2 UNION SELECT 3 EXCEPT SELECT 4" should { + singleQueryExample( + "SELECT 1 MINUS SELECT 2 UNION SELECT 3 EXCEPT SELECT 4", + SetOperation( + SetOperation( + SetOperation( + Project(NoTable, Seq(Literal(1, IntegerType))), + Project(NoTable, Seq(Literal(2, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(3, IntegerType))), + UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + Project(NoTable, Seq(Literal(4, IntegerType))), + ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + } + // INTERSECT has higher precedence than UNION, EXCEPT and MINUS + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 INTERSECT SELECT 4" should { + singleQueryExample( + "SELECT 1 UNION SELECT 2 EXCEPT SELECT 3 MINUS SELECT 4 INTERSECT SELECT 5", + ir.SetOperation( + ir.SetOperation( + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(1, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(2, ir.IntegerType))), + ir.UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.Project(ir.NoTable, Seq(ir.Literal(3, ir.IntegerType))), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.SetOperation( + ir.Project(ir.NoTable, Seq(ir.Literal(4, ir.IntegerType))), + ir.Project(ir.NoTable, Seq(ir.Literal(5, ir.IntegerType))), + ir.IntersectSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + ir.ExceptSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false)) + } + } + + "translate batches of queries" in { + example( + """ + |CREATE TABLE t1 (x VARCHAR); + |SELECT x FROM t1; + |SELECT 3 FROM t3; + |""".stripMargin, + _.snowflakeFile(), + Batch( + Seq( + CreateTableCommand("t1", Seq(ColumnDeclaration("x", StringType))), + Project(namedTable("t1"), Seq(Id("x"))), + Project(namedTable("t3"), Seq(Literal(3)))))) + } + + // Tests below are just meant to verify that SnowflakeAstBuilder properly delegates DML commands + // (other than SELECT) to SnowflakeDMLBuilder + + "translate INSERT commands" in { + singleQueryExample( + "INSERT INTO t (c1, c2, c3) VALUES (1,2, 3), (4, 5, 6)", + InsertIntoTable( + namedTable("t"), 
+ Some(Seq(Id("c1"), Id("c2"), Id("c3"))), + Values(Seq(Seq(Literal(1), Literal(2), Literal(3)), Seq(Literal(4), Literal(5), Literal(6)))), + None, + None)) + } + + "translate DELETE commands" in { + singleQueryExample( + "DELETE FROM t WHERE t.c1 > 42", + DeleteFromTable(namedTable("t"), None, Some(GreaterThan(Dot(Id("t"), Id("c1")), Literal(42))), None, None)) + } + + "translate UPDATE commands" in { + singleQueryExample( + "UPDATE t1 SET c1 = 42;", + UpdateTable(namedTable("t1"), None, Seq(Assign(Column(None, Id("c1")), Literal(42))), None, None, None)) + } + + "survive an invalid command" in { + example( + """ + |CREATE TABLE t1 (x VARCHAR); + |SELECT x y z; + |SELECT 3 FROM t3; + |""".stripMargin, + _.snowflakeFile(), + Batch( + Seq( + CreateTableCommand("t1", Seq(ColumnDeclaration("x", StringType))), + UnresolvedRelation("Unparsable text: SELECTxyz", message = "Unparsed input - ErrorNode encountered"), + UnresolvedRelation( + """Unparsable text: SELECT + |Unparsable text: x + |Unparsable text: y + |Unparsable text: z + |Unparsable text: parser recovered by ignoring: SELECTxyz;""".stripMargin, + message = "Unparsed input - ErrorNode encountered"), + Project(namedTable("t3"), Seq(Literal(3))))), + failOnErrors = false) + + } + + "translate BANG to Unresolved Expression" in { + + example( + "!set error_flag = true;", + _.snowSqlCommand(), + UnresolvedCommand( + ruleText = "!set error_flag = true;", + ruleName = "snowSqlCommand", + tokenName = Some("SQLCOMMAND"), + message = "Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand")) + + example( + "!set dfsdfds", + _.snowSqlCommand(), + UnresolvedCommand( + ruleText = "!set dfsdfds", + ruleName = "snowSqlCommand", + tokenName = Some("SQLCOMMAND"), + message = "Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand")) + assertThrows[Exception] { + example( + "!", + _.snowSqlCommand(), + UnresolvedCommand( + ruleText = "!", + ruleName = "snowSqlCommand", + tokenName = Some("SQLCOMMAND"), + message = "Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand")) + } + assertThrows[Exception] { + example( + "!badcommand", + _.snowSqlCommand(), + UnresolvedCommand( + ruleText = "!badcommand", + ruleName = "snowSqlCommand", + tokenName = Some("SQLCOMMAND"), + message = "Unknown command in SnowflakeAstBuilder.visitSqlCommand")) + } + } + + "translate amps" should { + "select * from a where b = &ids" in { + singleQueryExample( + "select * from a where b = &ids", + // Note when we truly process &vars we should get Variable, not Id + Project(Filter(namedTable("a"), Equals(Id("b"), Id("$ids"))), Seq(Star()))) + } + } + + "translate with recursive" should { + "WITH RECURSIVE employee_hierarchy" in { + singleQueryExample( + """WITH RECURSIVE employee_hierarchy AS ( + | SELECT + | employee_id, + | manager_id, + | employee_name, + | 1 AS level + | FROM + | employees + | WHERE + | manager_id IS NULL + | UNION ALL + | SELECT + | e.employee_id, + | e.manager_id, + | e.employee_name, + | eh.level + 1 AS level + | FROM + | employees e + | INNER JOIN + | employee_hierarchy eh ON e.manager_id = eh.employee_id + |) + |SELECT * + |FROM employee_hierarchy + |ORDER BY level, employee_id;""".stripMargin, + WithRecursiveCTE( + Seq( + SubqueryAlias( + SetOperation( + Project( + Filter(NamedTable("employees"), IsNull(Id("manager_id"))), + Seq( + Id("employee_id"), + Id("manager_id"), + Id("employee_name"), + Alias(Literal(1, IntegerType), Id("level")))), + Project( + Join( + TableAlias(NamedTable("employees"), "e"), + 
TableAlias(NamedTable("employee_hierarchy"), "eh"), + join_condition = Some(Equals(Dot(Id("e"), Id("manager_id")), Dot(Id("eh"), Id("employee_id")))), + InnerJoin, + using_columns = Seq(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + Seq( + Dot(Id("e"), Id("employee_id")), + Dot(Id("e"), Id("manager_id")), + Dot(Id("e"), Id("employee_name")), + Alias(Add(Dot(Id("eh"), Id("level")), Literal(1, IntegerType)), Id("level")))), + UnionSetOp, + is_all = true, + by_name = false, + allow_missing_columns = false), + Id("employee_hierarchy"), + Seq.empty)), + Project( + Sort( + NamedTable("employee_hierarchy", Map.empty), + Seq(SortOrder(Id("level"), Ascending, NullsLast), SortOrder(Id("employee_id"), Ascending, NullsLast))), + Seq(Star(None))))) + } + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeCommandBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeCommandBuilderSpec.scala new file mode 100644 index 0000000000..79a532d620 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeCommandBuilderSpec.scala @@ -0,0 +1,92 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.intermediate.procedures.SetVariable +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class SnowflakeCommandBuilderSpec + extends AnyWordSpec + with SnowflakeParserTestCommon + with Matchers + with MockitoSugar + with IRHelpers { + + override protected def astBuilder: SnowflakeCommandBuilder = vc.commandBuilder + + "translate Declare to CreateVariable Expression" should { + "X NUMBER DEFAULT 0;" in { + example( + "X NUMBER DEFAULT 0;", + _.declareElement(), + CreateVariable(name = Id("X"), dataType = DecimalType(38, 0), defaultExpr = Some(Literal(0)), replace = false)) + } + "select_statement VARCHAR;" in { + example( + "select_statement VARCHAR;", + _.declareElement(), + CreateVariable(name = Id("select_statement"), dataType = StringType, defaultExpr = None, replace = false)) + } + "price NUMBER(13,2) DEFAULT 111.50;" in { + example( + "price NUMBER(13,2) DEFAULT 111.50;", + _.declareElement(), + CreateVariable( + name = Id("price"), + dataType = DecimalType(Some(13), Some(2)), + defaultExpr = Some(Literal(111.5f)), + replace = false)) + } + "query_statement RESULTSET := (SELECT col1 FROM some_table);" in { + example( + "query_statement RESULTSET := (SELECT col1 FROM some_table);", + _.declareElement(), + CreateVariable( + name = Id("query_statement"), + dataType = StructType(Seq()), + defaultExpr = Some( + ScalarSubquery( + Project(NamedTable("some_table", Map(), is_streaming = false), Seq(Id("col1", caseSensitive = false))))), + replace = false)) + } + } + + "translate Let to SetVariable expressions" should { + "LET X := 1;" in { + example("LET X := 1;", _.let(), SetVariable(name = Id("X"), dataType = None, value = Literal(1))) + } + "LET select_statement := 'SELECT * FROM table WHERE id = ' || id;" in { + example( + "LET select_statement := 'SELECT * FROM table WHERE id = ' || id;", + _.let(), + SetVariable( + name = Id("select_statement"), + dataType = None, + value = Concat(Seq(Literal("SELECT * FROM table WHERE id = "), Id("id"))))) + } + "LET price NUMBER(13,2) DEFAULT 111.50;" in { + example( + "LET price NUMBER(13,2) DEFAULT 111.50;", + _.let(), + SetVariable(name = Id("price"), 
dataType = Some(DecimalType(Some(13), Some(2))), value = Literal(111.5f))) + } + "LET price NUMBER(13,2) := 121.55;" in { + example( + "LET price NUMBER(13,2) := 121.55;", + _.let(), + SetVariable(name = Id("price"), dataType = Some(DecimalType(Some(13), Some(2))), value = Literal(121.55f))) + } + "LET query_statement RESULTSET := (SELECT col1 FROM some_table);" in { + example( + "LET query_statement RESULTSET := (SELECT col1 FROM some_table);", + _.let(), + SetVariable( + name = Id("query_statement"), + dataType = Some(StructType(Seq(StructField("col1", UnresolvedType)))), + value = ScalarSubquery( + Project(NamedTable("some_table", Map(), is_streaming = false), Seq(Id("col1", caseSensitive = false)))))) + } + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDDLBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDDLBuilderSpec.scala new file mode 100644 index 0000000000..2be1d3322f --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDDLBuilderSpec.scala @@ -0,0 +1,496 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{StringContext => _, _} +import org.antlr.v4.runtime.CommonToken +import org.mockito.Mockito._ +import org.scalatest.matchers.should +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class SnowflakeDDLBuilderSpec + extends AnyWordSpec + with SnowflakeParserTestCommon + with should.Matchers + with MockitoSugar + with IRHelpers { + + override protected def astBuilder: SnowflakeDDLBuilder = vc.ddlBuilder + + private def example(query: String, expectedAst: Catalog): Unit = example(query, _.ddlCommand(), expectedAst) + + "SnowflakeCommandBuilder" should { + "translate Java UDF create command" in { + + val javaCode = """class TestFunc { + | public static String echoVarchar(String x) { + | return x; + | } + |}""".stripMargin + + example( + query = s""" + |CREATE OR REPLACE FUNCTION echo_varchar(x varchar) + |RETURNS VARCHAR + |LANGUAGE JAVA + |CALLED ON NULL INPUT + |IMPORTS = ('@~/some-dir/some-lib.jar') + |HANDLER = 'TestFunc.echoVarchar' + |AS + |'$javaCode'; + |""".stripMargin, + expectedAst = CreateInlineUDF( + name = "echo_varchar", + returnType = StringType, + parameters = Seq(FunctionParameter("x", StringType, None)), + JavaRuntimeInfo( + runtimeVersion = None, + imports = Seq("@~/some-dir/some-lib.jar"), + handler = "TestFunc.echoVarchar"), + acceptsNullParameters = true, + comment = None, + body = javaCode)) + + } + + "translate Python UDF create command" in { + val pythonCode = """import numpy as np + |import pandas as pd + |import xgboost as xgb + |def udf(): + | return [np.__version__, pd.__version__, xgb.__version__] + |""".stripMargin + + example( + query = s"""CREATE OR REPLACE FUNCTION py_udf() + | RETURNS VARIANT + | LANGUAGE PYTHON + | RUNTIME_VERSION = '3.8' + | PACKAGES = ('numpy','pandas','xgboost==1.5.0') + | HANDLER = 'udf' + |AS $$$$ + |$pythonCode + |$$$$;""".stripMargin, + expectedAst = CreateInlineUDF( + name = "py_udf", + returnType = VariantType, + parameters = Seq(), + runtimeInfo = PythonRuntimeInfo( + runtimeVersion = Some("3.8"), + packages = Seq("numpy", "pandas", "xgboost==1.5.0"), + handler = "udf"), + acceptsNullParameters = false, + comment = None, + body = pythonCode.trim)) + } + + "translate JavaScript UDF create command" in { + val 
javascriptCode = """if (D <= 0) { + | return 1; + | } else { + | var result = 1; + | for (var i = 2; i <= D; i++) { + | result = result * i; + | } + | return result; + | }""".stripMargin + example( + query = s"""CREATE OR REPLACE FUNCTION js_factorial(d double) + | RETURNS double + | LANGUAGE JAVASCRIPT + | STRICT + | COMMENT = 'Compute factorial using JavaScript' + | AS ' + | $javascriptCode + | ';""".stripMargin, + expectedAst = CreateInlineUDF( + name = "js_factorial", + returnType = DoubleType, + parameters = Seq(FunctionParameter("d", DoubleType, None)), + runtimeInfo = JavaScriptRuntimeInfo, + acceptsNullParameters = false, + comment = Some("Compute factorial using JavaScript"), + body = javascriptCode)) + } + + "translate Scala UDF create command" in { + val scalaCode = """class Echo { + | def echoVarchar(x : String): String = { + | return x + | } + |}""".stripMargin + + example( + query = s"""CREATE OR REPLACE FUNCTION echo_varchar(x VARCHAR DEFAULT 'foo') + | RETURNS VARCHAR + | LANGUAGE SCALA + | CALLED ON NULL INPUT + | RUNTIME_VERSION = '2.12' + | HANDLER='Echo.echoVarchar' + | AS + | $$$$ + | $scalaCode + | $$$$;""".stripMargin, + expectedAst = CreateInlineUDF( + name = "echo_varchar", + returnType = StringType, + parameters = Seq(FunctionParameter("x", StringType, Some(Literal("foo")))), + runtimeInfo = ScalaRuntimeInfo(runtimeVersion = Some("2.12"), imports = Seq(), handler = "Echo.echoVarchar"), + acceptsNullParameters = true, + comment = None, + body = scalaCode)) + } + + "translate SQL UDF create command" in { + example( + query = """CREATE FUNCTION multiply1 (a number, b number) + | RETURNS number + | COMMENT='multiply two numbers' + | AS 'a * b';""".stripMargin, + expectedAst = CreateInlineUDF( + name = "multiply1", + returnType = DecimalType(38, 0), + parameters = + Seq(FunctionParameter("a", DecimalType(38, 0), None), FunctionParameter("b", DecimalType(38, 0), None)), + runtimeInfo = SQLRuntimeInfo(memoizable = false), + acceptsNullParameters = false, + comment = Some("multiply two numbers"), + body = "a * b")) + } + + "translate CREATE TABLE commands" should { + "CREATE TABLE s.t1 (x VARCHAR)" in { + example( + "CREATE TABLE s.t1 (x VARCHAR)", + CreateTableCommand(name = "s.t1", columns = Seq(ColumnDeclaration("x", StringType, None, Seq())))) + } + "CREATE TABLE s.t1 (x VARCHAR UNIQUE)" in { + example( + "CREATE TABLE s.t1 (x VARCHAR UNIQUE)", + CreateTableCommand(name = "s.t1", columns = Seq(ColumnDeclaration("x", StringType, None, Seq(Unique()))))) + } + "CREATE TABLE s.t1 (x VARCHAR NOT NULL)" in { + example( + "CREATE TABLE s.t1 (x VARCHAR NOT NULL)", + CreateTableCommand("s.t1", Seq(ColumnDeclaration("x", StringType, None, Seq(Nullability(false)))))) + } + "CREATE TABLE s.t1 (x VARCHAR PRIMARY KEY)" in { + example( + "CREATE TABLE s.t1 (x VARCHAR PRIMARY KEY)", + CreateTableCommand("s.t1", Seq(ColumnDeclaration("x", StringType, None, Seq(PrimaryKey()))))) + } + "CREATE TABLE s.t1 (x VARCHAR UNIQUE FOREIGN KEY REFERENCES s.t2 (y))" in { + example( + "CREATE TABLE s.t1 (x VARCHAR UNIQUE FOREIGN KEY REFERENCES s.t2 (y))", + CreateTableCommand( + "s.t1", + Seq(ColumnDeclaration("x", StringType, None, Seq(Unique(), ForeignKey("", "s.t2.y", "", Seq.empty)))))) + } + "more complex" in { + example( + query = """CREATE TABLE s.t1 ( + | id VARCHAR PRIMARY KEY NOT NULL, + | a VARCHAR(32) UNIQUE, + | b INTEGER, + | CONSTRAINT fkey FOREIGN KEY (a, b) REFERENCES s.t2 (x, y) + |) + |""".stripMargin, + expectedAst = CreateTableCommand( + name = "s.t1", + columns = Seq( + 
ColumnDeclaration("id", StringType, None, Seq(Nullability(false), PrimaryKey())), + ColumnDeclaration( + "a", + StringType, + None, + Seq(Unique(), NamedConstraint("fkey", ForeignKey("", "s.t2.x", "", Seq.empty)))), + ColumnDeclaration( + "b", + DecimalType(Some(38), Some(0)), + None, + Seq(NamedConstraint("fkey", ForeignKey("", "s.t2.y", "", Seq.empty))))))) + } + + "CREATE TABLE t1 AS SELECT c1, c2 FROM t2;" in { + example( + "CREATE TABLE t1 AS (SELECT * FROM t2);", + CreateTableParams( + CreateTableAsSelect("t1", Project(namedTable("t2"), Seq(Star(None))), None, None, None), + Map.empty[String, Seq[Constraint]], + Map.empty[String, Seq[GenericOption]], + Seq.empty[Constraint], + Seq.empty[Constraint], + None, + None)) + } + } + "translate ALTER TABLE commands" should { + "ALTER TABLE s.t1 ADD COLUMN c VARCHAR" in { + example( + "ALTER TABLE s.t1 ADD COLUMN c VARCHAR", + AlterTableCommand("s.t1", Seq(AddColumn(Seq(ColumnDeclaration("c", StringType)))))) + } + "ALTER TABLE s.t1 ADD CONSTRAINT pk PRIMARY KEY (a, b, c)" in { + example( + "ALTER TABLE s.t1 ADD CONSTRAINT pk PRIMARY KEY (a, b, c)", + AlterTableCommand( + "s.t1", + Seq( + AddConstraint("a", NamedConstraint("pk", PrimaryKey())), + AddConstraint("b", NamedConstraint("pk", PrimaryKey())), + AddConstraint("c", NamedConstraint("pk", PrimaryKey()))))) + } + "ALTER TABLE s.t1 ALTER (COLUMN a TYPE INT)" in { + example( + "ALTER TABLE s.t1 ALTER (COLUMN a TYPE INT)", + AlterTableCommand("s.t1", Seq(ChangeColumnDataType("a", DecimalType(Some(38), Some(0)))))) + } + "ALTER TABLE s.t1 ALTER (COLUMN a NOT NULL)" in { + example( + "ALTER TABLE s.t1 ALTER (COLUMN a NOT NULL)", + AlterTableCommand("s.t1", Seq(AddConstraint("a", Nullability(false))))) + } + "ALTER TABLE s.t1 ALTER (COLUMN a DROP NOT NULL)" in { + example( + "ALTER TABLE s.t1 ALTER (COLUMN a DROP NOT NULL)", + AlterTableCommand("s.t1", Seq(DropConstraint(Some("a"), Nullability(false))))) + } + "ALTER TABLE s.t1 DROP COLUMN a" in { + example("ALTER TABLE s.t1 DROP COLUMN a", AlterTableCommand("s.t1", Seq(DropColumns(Seq("a"))))) + } + "ALTER TABLE s.t1 DROP PRIMARY KEY" in { + example("ALTER TABLE s.t1 DROP PRIMARY KEY", AlterTableCommand("s.t1", Seq(DropConstraint(None, PrimaryKey())))) + } + "ALTER TABLE s.t1 DROP CONSTRAINT pk" in { + example("ALTER TABLE s.t1 DROP CONSTRAINT pk", AlterTableCommand("s.t1", Seq(DropConstraintByName("pk")))) + } + "ALTER TABLE s.t1 DROP UNIQUE (b, c)" in { + example( + "ALTER TABLE s.t1 DROP UNIQUE (b, c)", + AlterTableCommand("s.t1", Seq(DropConstraint(Some("b"), Unique()), DropConstraint(Some("c"), Unique())))) + } + "ALTER TABLE s.t1 RENAME COLUMN a TO aa" in { + example("ALTER TABLE s.t1 RENAME COLUMN a TO aa", AlterTableCommand("s.t1", Seq(RenameColumn("a", "aa")))) + } + "ALTER TABLE s.t1 RENAME CONSTRAINT pk TO pk_t1" in { + example( + "ALTER TABLE s.t1 RENAME CONSTRAINT pk TO pk_t1", + AlterTableCommand("s.t1", Seq(RenameConstraint("pk", "pk_t1")))) + } + } + + "translate Unresolved COMMAND" should { + "ALTER SESSION SET QUERY_TAG = 'TAG'" in { + example( + "ALTER SESSION SET QUERY_TAG = 'TAG';", + UnresolvedCommand( + ruleText = "ALTER SESSION SET QUERY_TAG = 'TAG'", + message = "Unknown ALTER command variant", + ruleName = "alterCommand", + tokenName = Some("ALTER"))) + } + + "ALTER STREAM mystream SET COMMENT = 'New comment for stream'" in { + example( + "ALTER STREAM mystream SET COMMENT = 'New comment for stream';", + UnresolvedCommand( + ruleText = "ALTER STREAM mystream SET COMMENT = 'New comment for stream'", + message = 
"Unknown ALTER command variant", + ruleName = "alterCommand", + tokenName = Some("ALTER"))) + } + + "CREATE STREAM mystream ON TABLE mytable" in { + example( + "CREATE STREAM mystream ON TABLE mytable;", + UnresolvedCommand( + ruleText = "CREATE STREAM mystream ON TABLE mytable", + message = "CREATE STREAM UNSUPPORTED", + ruleName = "createStream", + tokenName = Some("STREAM"))) + } + + "CREATE TASK t1 SCHEDULE = '30 MINUTE' AS INSERT INTO tbl(ts) VALUES(CURRENT_TIMESTAMP)" in { + example( + "CREATE TASK t1 SCHEDULE = '30 MINUTE' AS INSERT INTO tbl(ts) VALUES(CURRENT_TIMESTAMP);", + UnresolvedCommand( + ruleText = "CREATE TASK t1 SCHEDULE = '30 MINUTE' AS INSERT INTO tbl(ts) VALUES(CURRENT_TIMESTAMP)", + message = "CREATE TASK UNSUPPORTED", + ruleName = "createTask", + tokenName = Some("TASK"))) + } + } + + "wrap unknown AST in UnresolvedCommand" in { + vc.ddlBuilder.visit(parseString("CREATE USER homer", _.createCommand())) shouldBe a[UnresolvedCommand] + } + } + + "SnowflakeDDLBuilder.buildOutOfLineConstraint" should { + + "handle unexpected child" in { + val columnList = parseString("(a, b, c)", _.columnListInParentheses()) + val outOfLineConstraint = mock[OutOfLineConstraintContext] + when(outOfLineConstraint.columnListInParentheses(0)).thenReturn(columnList) + val dummyInputTextForOutOfLineConstraint = "dummy" + when(outOfLineConstraint.getText).thenReturn(dummyInputTextForOutOfLineConstraint) + val result = vc.ddlBuilder.buildOutOfLineConstraints(outOfLineConstraint) + result shouldBe Seq( + "a" -> UnresolvedConstraint(dummyInputTextForOutOfLineConstraint), + "b" -> UnresolvedConstraint(dummyInputTextForOutOfLineConstraint), + "c" -> UnresolvedConstraint(dummyInputTextForOutOfLineConstraint)) + verify(outOfLineConstraint).columnListInParentheses(0) + verify(outOfLineConstraint).UNIQUE() + verify(outOfLineConstraint).primaryKey() + verify(outOfLineConstraint).foreignKey() + verify(outOfLineConstraint).id() + verify(outOfLineConstraint, times(3)).getText + verifyNoMoreInteractions(outOfLineConstraint) + + } + } + + "SnowflakeDDLBuilder.buildInlineConstraint" should { + + "handle unexpected child" in { + val inlineConstraint = mock[InlineConstraintContext] + val dummyInputTextForInlineConstraint = "dummy" + when(inlineConstraint.getText).thenReturn(dummyInputTextForInlineConstraint) + val result = vc.ddlBuilder.buildInlineConstraint(inlineConstraint) + result shouldBe UnresolvedConstraint(dummyInputTextForInlineConstraint) + verify(inlineConstraint).UNIQUE() + verify(inlineConstraint).primaryKey() + verify(inlineConstraint).foreignKey() + verify(inlineConstraint).getText + verifyNoMoreInteractions(inlineConstraint) + } + } + "SnowflakeDDLBuilder.visitAlter_table" should { + "handle unexpected child" in { + val tableName = parseString("s.t1", _.dotIdentifier()) + val alterTable = mock[AlterTableContext] + val startTok = new CommonToken(ID, "s") + when(alterTable.dotIdentifier(0)).thenReturn(tableName) + when(alterTable.getStart).thenReturn(startTok) + when(alterTable.getStop).thenReturn(startTok) + when(alterTable.getRuleIndex).thenReturn(SnowflakeParser.RULE_alterTable) + val result = vc.ddlBuilder.visitAlterTable(alterTable) + result shouldBe UnresolvedCommand( + ruleText = "Mocked string", + message = "Unknown ALTER TABLE variant", + ruleName = "alterTable", + tokenName = Some("ID")) + verify(alterTable).dotIdentifier(0) + verify(alterTable).tableColumnAction() + verify(alterTable).constraintAction() + verify(alterTable).getRuleIndex + verify(alterTable, times(3)).getStart + 
verify(alterTable).getStop + verifyNoMoreInteractions(alterTable) + } + } + + "SnowflakeDDLBuilder.buildColumnActions" should { + "handle unexpected child" in { + val tableColumnAction = mock[TableColumnActionContext] + when(tableColumnAction.alterColumnClause()) + .thenReturn(java.util.Collections.emptyList[AlterColumnClauseContext]()) + val startTok = new CommonToken(ID, "s") + when(tableColumnAction.getStart).thenReturn(startTok) + when(tableColumnAction.getStop).thenReturn(startTok) + when(tableColumnAction.getRuleIndex).thenReturn(SnowflakeParser.RULE_tableColumnAction) + val result = vc.ddlBuilder.buildColumnActions(tableColumnAction) + result shouldBe Seq( + UnresolvedTableAlteration( + ruleText = "Mocked string", + message = "Unknown COLUMN action variant", + ruleName = "tableColumnAction", + tokenName = Some("ID"))) + verify(tableColumnAction).alterColumnClause() + verify(tableColumnAction).ADD() + verify(tableColumnAction).alterColumnClause() + verify(tableColumnAction).DROP() + verify(tableColumnAction).RENAME() + verify(tableColumnAction).getRuleIndex + verify(tableColumnAction, times(3)).getStart + verify(tableColumnAction).getStop + verifyNoMoreInteractions(tableColumnAction) + } + } + + "SnowflakeDDLBuilder.buildColumnAlterations" should { + "handle unexpected child" in { + val columnName = parseString("a", _.columnName()) + val alterColumnClause = mock[AlterColumnClauseContext] + when(alterColumnClause.columnName()).thenReturn(columnName) + val startTok = new CommonToken(ID, "s") + when(alterColumnClause.getStart).thenReturn(startTok) + when(alterColumnClause.getStop).thenReturn(startTok) + when(alterColumnClause.getRuleIndex).thenReturn(SnowflakeParser.RULE_alterColumnClause) + val result = vc.ddlBuilder.buildColumnAlterations(alterColumnClause) + result shouldBe UnresolvedTableAlteration( + ruleText = "Mocked string", + message = "Unknown ALTER COLUMN variant", + ruleName = "alterColumnClause", + tokenName = Some("ID")) + verify(alterColumnClause).columnName() + verify(alterColumnClause).dataType() + verify(alterColumnClause).DROP() + verify(alterColumnClause).NULL() + verify(alterColumnClause).getRuleIndex + verify(alterColumnClause, times(3)).getStart + verify(alterColumnClause).getStop + verifyNoMoreInteractions(alterColumnClause) + } + } + + "SnowflakeDDLBuilder.buildConstraintActions" should { + "handle unexpected child" in { + val constraintAction = mock[ConstraintActionContext] + val startTok = new CommonToken(ID, "s") + when(constraintAction.getStart).thenReturn(startTok) + when(constraintAction.getStop).thenReturn(startTok) + when(constraintAction.getRuleIndex).thenReturn(SnowflakeParser.RULE_constraintAction) + val result = vc.ddlBuilder.buildConstraintActions(constraintAction) + result shouldBe Seq( + UnresolvedTableAlteration( + ruleText = "Mocked string", + message = "Unknown CONSTRAINT variant", + ruleName = "constraintAction", + tokenName = Some("ID"))) + verify(constraintAction).ADD() + verify(constraintAction).DROP() + verify(constraintAction).RENAME() + verify(constraintAction).getRuleIndex + verify(constraintAction, times(3)).getStart + verify(constraintAction).getStop + verifyNoMoreInteractions(constraintAction) + } + } + + "SnowflakeDDLBuilder.buildDropConstraints" should { + "handle unexpected child" in { + val constraintAction = mock[ConstraintActionContext] + when(constraintAction.id()).thenReturn(java.util.Collections.emptyList[IdContext]) + val startTok = new CommonToken(ID, "s") + when(constraintAction.getStart).thenReturn(startTok) + 
when(constraintAction.getStop).thenReturn(startTok) + when(constraintAction.getRuleIndex).thenReturn(SnowflakeParser.RULE_constraintAction) + val result = vc.ddlBuilder.buildDropConstraints(constraintAction) + result shouldBe Seq( + UnresolvedTableAlteration( + ruleText = "Mocked string", + message = "Unknown DROP constraint variant", + ruleName = "constraintAction", + tokenName = Some("ID"))) + verify(constraintAction).columnListInParentheses() + verify(constraintAction).primaryKey() + verify(constraintAction).UNIQUE() + verify(constraintAction).id() + verify(constraintAction).getRuleIndex + verify(constraintAction, times(3)).getStart + verify(constraintAction).getStop + verifyNoMoreInteractions(constraintAction) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDMLBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDMLBuilderSpec.scala new file mode 100644 index 0000000000..09933c77be --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeDMLBuilderSpec.scala @@ -0,0 +1,189 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate._ +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeDMLBuilderSpec extends AnyWordSpec with SnowflakeParserTestCommon with IRHelpers { + + override protected def astBuilder: SnowflakeDMLBuilder = vc.dmlBuilder + + "SnowflakeDMLBuilder" should { + "translate INSERT statements" should { + "INSERT INTO foo SELECT * FROM bar LIMIT 100" in { + example( + "INSERT INTO foo SELECT * FROM bar LIMIT 100", + _.insertStatement(), + InsertIntoTable( + namedTable("foo"), + None, + Project(Limit(namedTable("bar"), Literal(100)), Seq(Star(None))), + None, + None, + overwrite = false)) + } + "INSERT OVERWRITE INTO foo SELECT * FROM bar LIMIT 100" in { + example( + "INSERT OVERWRITE INTO foo SELECT * FROM bar LIMIT 100", + _.insertStatement(), + InsertIntoTable( + namedTable("foo"), + None, + Project(Limit(namedTable("bar"), Literal(100)), Seq(Star(None))), + None, + None, + overwrite = true)) + } + "INSERT INTO foo VALUES (1, 2, 3), (4, 5, 6)" in { + example( + "INSERT INTO foo VALUES (1, 2, 3), (4, 5, 6)", + _.insertStatement(), + InsertIntoTable( + namedTable("foo"), + None, + Values(Seq(Seq(Literal(1), Literal(2), Literal(3)), Seq(Literal(4), Literal(5), Literal(6)))), + None, + None, + overwrite = false)) + } + "INSERT OVERWRITE INTO foo VALUES (1, 2, 3), (4, 5, 6)" in { + example( + "INSERT OVERWRITE INTO foo VALUES (1, 2, 3), (4, 5, 6)", + _.insertStatement(), + InsertIntoTable( + namedTable("foo"), + None, + Values(Seq(Seq(Literal(1), Literal(2), Literal(3)), Seq(Literal(4), Literal(5), Literal(6)))), + None, + None, + overwrite = true)) + } + } + + "translate DELETE statements" should { + "direct" in { + example( + "DELETE FROM t WHERE t.c1 > 42", + _.deleteStatement(), + DeleteFromTable(namedTable("t"), None, Some(GreaterThan(Dot(Id("t"), Id("c1")), Literal(42))), None, None)) + } + + "as merge" should { + "DELETE FROM t1 USING t2 WHERE t1.c1 = t2.c2" in { + example( + "DELETE FROM t1 USING t2 WHERE t1.c1 = t2.c2", + _.deleteStatement(), + MergeIntoTable( + namedTable("t1"), + namedTable("t2"), + Equals(Dot(Id("t1"), Id("c1")), Dot(Id("t2"), Id("c2"))), + matchedActions = Seq(DeleteAction(None)))) + } + "DELETE FROM table1 AS t1 USING (SELECT * FROM table2) AS t2 WHERE t1.c1 = t2.c2" in { + example( + "DELETE FROM table1 AS t1 USING (SELECT * FROM table2) AS t2 WHERE t1.c1 = 
t2.c2", + _.deleteStatement(), + MergeIntoTable( + TableAlias(namedTable("table1"), "t1", Seq()), + SubqueryAlias(Project(namedTable("table2"), Seq(Star(None))), Id("t2"), Seq()), + Equals(Dot(Id("t1"), Id("c1")), Dot(Id("t2"), Id("c2"))), + matchedActions = Seq(DeleteAction(None)))) + } + } + } + + "translate UPDATE statements" should { + "UPDATE t1 SET c1 = 42" in { + example( + "UPDATE t1 SET c1 = 42", + _.updateStatement(), + UpdateTable(namedTable("t1"), None, Seq(Assign(Column(None, Id("c1")), Literal(42))), None, None, None)) + } + "UPDATE t1 SET c1 = 42 WHERE c1 < 0" in { + example( + "UPDATE t1 SET c1 = 42 WHERE c1 < 0", + _.updateStatement(), + UpdateTable( + namedTable("t1"), + None, + Seq(Assign(Column(None, Id("c1")), Literal(42))), + Some(LessThan(Id("c1"), Literal(0))), + None, + None)) + } + "UPDATE table1 as t1 SET c1 = c2 + t2.c2 FROM table2 as t2 WHERE t1.c3 = t2.c3" in { + example( + "UPDATE table1 as t1 SET c1 = c2 + t2.c2 FROM table2 as t2 WHERE t1.c3 = t2.c3", + _.updateStatement(), + UpdateTable( + TableAlias(namedTable("table1"), "t1", Seq()), + Some( + crossJoin(TableAlias(namedTable("table1"), "t1", Seq()), TableAlias(namedTable("table2"), "t2", Seq()))), + Seq(Assign(Column(None, Id("c1")), Add(Id("c2"), Dot(Id("t2"), Id("c2"))))), + Some(Equals(Dot(Id("t1"), Id("c3")), Dot(Id("t2"), Id("c3")))), + None, + None)) + } + } + + "translate MERGE statements" should { + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN UPDATE SET c1 = 42" in { + example( + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN UPDATE SET c1 = 42", + _.mergeStatement(), + MergeIntoTable( + namedTable("t1"), + namedTable("t2"), + Equals(Dot(Id("t1"), Id("c1")), Dot(Id("t2"), Id("c2"))), + matchedActions = Seq(UpdateAction(None, Seq(Assign(Column(None, Id("c1")), Literal(42))))))) + } + + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN DELETE" in { + example( + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN DELETE", + _.mergeStatement(), + MergeIntoTable( + namedTable("t1"), + namedTable("t2"), + Equals(Dot(Id("t1"), Id("c1")), Dot(Id("t2"), Id("c2"))), + matchedActions = Seq(DeleteAction(None)))) + } + + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED AND t2.date = '01/01/2024' THEN DELETE" in { + example( + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED AND t2.date = '01/01/2024' THEN DELETE", + _.mergeStatement(), + MergeIntoTable( + namedTable("t1"), + namedTable("t2"), + Equals(Dot(Id("t1"), Id("c1")), Dot(Id("t2"), Id("c2"))), + matchedActions = Seq(DeleteAction(Some(Equals(Dot(Id("t2"), Id("date")), Literal("01/01/2024"))))))) + } + + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN DELETE WHEN NOT MATCHED THEN INSERT" in { + example( + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN DELETE WHEN NOT MATCHED THEN INSERT", + _.mergeStatement(), + MergeIntoTable( + namedTable("t1"), + namedTable("t2"), + Equals(Dot(Id("t1"), Id("c1")), Dot(Id("t2"), Id("c2"))), + matchedActions = Seq(DeleteAction(None)), + notMatchedActions = Seq(InsertAction(None, Seq.empty[Assign])))) + } + + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN UPDATE SET t1.c1 = 42" in { + example( + "MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN UPDATE SET t1.c1 = 42", + _.mergeStatement(), + MergeIntoTable( + namedTable("t1"), + namedTable("t2"), + Equals(Dot(Id("t1"), Id("c1")), Dot(Id("t2"), Id("c2"))), + matchedActions = + Seq(UpdateAction(None, Seq(Assign(Column(Some(ObjectReference(Id("t1"))), Id("c1")), 
Literal(42))))))) + } + + } + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExprSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExprSpec.scala new file mode 100644 index 0000000000..e875a0bb86 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExprSpec.scala @@ -0,0 +1,247 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeExprSpec extends AnyWordSpec with SnowflakeParserTestCommon with Matchers with IRHelpers { + + override protected def astBuilder: SnowflakeExpressionBuilder = vc.expressionBuilder + + private def example(input: String, expectedAst: Expression): Unit = exampleExpr(input, _.expr(), expectedAst) + + private def searchConditionExample(input: String, expectedAst: Expression): Unit = + exampleExpr(input, _.searchCondition(), expectedAst) + + "SnowflakeExpressionBuilder" should { + "do something" should { + val col = Dot(Dot(Dot(Id("d"), Id("s")), Id("t")), Id("column_1")) + + "d.s.sequence_1.NEXTVAL" in { + example("d.s.sequence_1.NEXTVAL", NextValue("d.s.sequence_1")) + } + "d.s.t.column_1[42]" in { + example( + "d.s.t.column_1[42]", + Dot(Dot(Dot(Id("d"), Id("s")), Id("t")), ArrayAccess(Id("column_1"), Literal(42)))) + } + "d.s.t.column_1:field_1.\"inner field\"" in { + example( + "d.s.t.column_1:field_1.\"inner field\"", + JsonAccess(col, Dot(Id("field_1"), Id("inner field", caseSensitive = true)))) + } + "d.s.t.column_1 COLLATE 'en_US-trim'" in { + example("d.s.t.column_1 COLLATE 'en_US-trim'", Collate(col, "en_US-trim")) + } + } + + "translate unary arithmetic operators" should { + "+column_1" in { + example("+column_1", UPlus(Id("column_1"))) + } + "+42" in { + example("+42", UPlus(Literal(42))) + } + "-column_1" in { + example("-column_1", UMinus(Id("column_1"))) + } + "-42" in { + example("-42", UMinus(Literal(42))) + } + "NOT true" in { + example("NOT true", Not(Literal.True)) + } + "NOT column_2" in { + example("NOT column_2", Not(Id("column_2"))) + } + } + + "translate binary arithmetic operators" should { + "1+1" in { + example("1+1", Add(Literal(1), Literal(1))) + } + "2 * column_1" in { + example("2 * column_1", Multiply(Literal(2), Id("column_1"))) + } + "column_1 - 1" in { + example("column_1 - 1", Subtract(Id("column_1"), Literal(1))) + } + "column_1/column_2" in { + example("column_1/column_2", Divide(Id("column_1"), Id("column_2"))) + } + "42 % 2" in { + example("42 % 2", Mod(Literal(42), Literal(2))) + } + "'foo' || column_1" in { + example("'foo' || column_1", Concat(Seq(Literal("foo"), Id("column_1")))) + } + } + + "translate IFF expression" in { + example("IFF (true, column_1, column_2)", If(Literal.True, Id("column_1"), Id("column_2"))) + } + + "translate array literals" should { + "[1, 2, 3]" in { + example("[1, 2, 3]", ArrayExpr(Seq(Literal(1), Literal(2), Literal(3)), IntegerType)) + } + "[1, 2, 'three']" in { + example("[1, 2, 'three']", ArrayExpr(Seq(Literal(1), Literal(2), Literal("three")), IntegerType)) + } + } + + "translate cast expressions" should { + "CAST (column_1 AS BOOLEAN)" in { + example("CAST (column_1 AS BOOLEAN)", Cast(Id("column_1"), BooleanType)) + } + "TRY_CAST (column_1 AS BOOLEAN)" in { + example("TRY_CAST (column_1 AS BOOLEAN)", TryCast(Id("column_1"), BooleanType)) + } + "TO_TIMESTAMP(1234567890)" in { + 
example("TO_TIMESTAMP(1234567890)", CallFunction("TO_TIMESTAMP", Seq(Literal(1234567890)))) + } + "TIME('00:00:00')" in { + example("TIME('00:00:00')", CallFunction("TO_TIME", Seq(Literal("00:00:00")))) + } + "TO_TIME(column_1)" in { + example("TO_TIME(column_1)", CallFunction("TO_TIME", Seq(Id("column_1")))) + } + "DATE(column_1)" in { + example("DATE(column_1)", CallFunction("TO_DATE", Seq(Id("column_1")))) + } + "TO_DATE('2024-05-15')" in { + example("TO_DATE('2024-05-15')", CallFunction("TO_DATE", Seq(Literal("2024-05-15")))) + } + "INTERVAL '1 hour'" in { + example("INTERVAL '1 hour'", Cast(Literal("1 hour"), IntervalType)) + } + "42::FLOAT" in { + example("42::FLOAT", Cast(Literal(42), DoubleType)) + } + "TO_CHAR(42)" in { + example("TO_CHAR(42)", CallFunction("TO_VARCHAR", Seq(Literal(42)))) + } + } + + "translate IN expressions" should { + "col1 IN (SELECT * FROM t)" in { + searchConditionExample( + "col1 IN (SELECT * FROM t)", + In(Id("col1"), Seq(ScalarSubquery(Project(namedTable("t"), Seq(Star(None))))))) + } + "col1 NOT IN (SELECT * FROM t)" in { + searchConditionExample( + "col1 NOT IN (SELECT * FROM t)", + Not(In(Id("col1"), Seq(ScalarSubquery(Project(namedTable("t"), Seq(Star(None)))))))) + } + "col1 IN (1, 2, 3)" in { + searchConditionExample("col1 IN (1, 2, 3)", In(Id("col1"), Seq(Literal(1), Literal(2), Literal(3)))) + } + "col1 NOT IN ('foo', 'bar')" in { + searchConditionExample("col1 NOT IN ('foo', 'bar')", Not(In(Id("col1"), Seq(Literal("foo"), Literal("bar"))))) + } + } + + "translate BETWEEN expressions" should { + "col1 BETWEEN 3.14 AND 42" in { + searchConditionExample("col1 BETWEEN 3.14 AND 42", Between(Id("col1"), Literal(3.14), Literal(42))) + } + "col1 NOT BETWEEN 3.14 AND 42" in { + searchConditionExample("col1 NOT BETWEEN 3.14 AND 42", Not(Between(Id("col1"), Literal(3.14), Literal(42)))) + } + } + + "translate LIKE expressions" should { + "col1 LIKE '%foo'" in { + searchConditionExample("col1 LIKE '%foo'", Like(Id("col1"), Literal("%foo"), None)) + } + "col1 ILIKE '%foo'" in { + searchConditionExample("col1 ILIKE '%foo'", ILike(Id("col1"), Literal("%foo"), None)) + } + "col1 NOT LIKE '%foo'" in { + searchConditionExample("col1 NOT LIKE '%foo'", Not(Like(Id("col1"), Literal("%foo"), None))) + } + "col1 NOT ILIKE '%foo'" in { + searchConditionExample("col1 NOT ILIKE '%foo'", Not(ILike(Id("col1"), Literal("%foo"), None))) + } + "col1 LIKE '%foo' ESCAPE '^'" in { + searchConditionExample("col1 LIKE '%foo' ESCAPE '^'", Like(Id("col1"), Literal("%foo"), Some(Literal('^')))) + } + "col1 ILIKE '%foo' ESCAPE '^'" in { + searchConditionExample("col1 ILIKE '%foo' ESCAPE '^'", ILike(Id("col1"), Literal("%foo"), Some(Literal('^')))) + } + "col1 NOT LIKE '%foo' ESCAPE '^'" in { + searchConditionExample( + "col1 NOT LIKE '%foo' ESCAPE '^'", + Not(Like(Id("col1"), Literal("%foo"), Some(Literal('^'))))) + } + "col1 NOT ILIKE '%foo' ESCAPE '^'" in { + searchConditionExample( + "col1 NOT ILIKE '%foo' ESCAPE '^'", + Not(ILike(Id("col1"), Literal("%foo"), Some(Literal('^'))))) + } + "col1 LIKE ANY ('%foo', 'bar%', '%qux%')" in { + searchConditionExample( + "col1 LIKE ANY ('%foo', 'bar%', '%qux%')", + LikeAny(Id("col1"), Seq(Literal("%foo"), Literal("bar%"), Literal("%qux%")))) + } + "col1 LIKE ALL ('%foo', 'bar^%', '%qux%') ESCAPE '^'" in { + searchConditionExample( + "col1 LIKE ALL ('%foo', 'bar^%', '%qux%') ESCAPE '^'", + LikeAll(Id("col1"), Seq(Literal("%foo"), Literal("bar\\^%"), Literal("%qux%")))) + } + "col1 ILIKE ANY ('%foo', 'bar^%', '%qux%') ESCAPE '^'" in { + 
searchConditionExample( + "col1 ILIKE ANY ('%foo', 'bar^%', '%qux%') ESCAPE '^'", + ILikeAny(Id("col1"), Seq(Literal("%foo"), Literal("bar\\^%"), Literal("%qux%")))) + } + "col1 ILIKE ALL ('%foo', 'bar%', '%qux%')" in { + searchConditionExample( + "col1 ILIKE ALL ('%foo', 'bar%', '%qux%')", + ILikeAll(Id("col1"), Seq(Literal("%foo"), Literal("bar%"), Literal("%qux%")))) + } + "col1 RLIKE '[a-z][A-Z]*'" in { + searchConditionExample("col1 RLIKE '[a-z][A-Z]*'", RLike(Id("col1"), Literal("[a-z][A-Z]*"))) + } + "col1 NOT RLIKE '[a-z][A-Z]*'" in { + searchConditionExample("col1 NOT RLIKE '[a-z][A-Z]*'", Not(RLike(Id("col1"), Literal("[a-z][A-Z]*")))) + } + } + + "translate IS [NOT] NULL expressions" should { + "col1 IS NULL" in { + searchConditionExample("col1 IS NULL", IsNull(Id("col1"))) + } + "col1 IS NOT NULL" in { + searchConditionExample("col1 IS NOT NULL", IsNotNull(Id("col1"))) + } + } + + "translate DISTINCT expressions" in { + searchConditionExample("DISTINCT col1", Distinct(Id("col1"))) + } + + "translate WITHIN GROUP expressions" in { + searchConditionExample( + "ARRAY_AGG(col1) WITHIN GROUP (ORDER BY col2)", + WithinGroup(CallFunction("ARRAY_AGG", Seq(Id("col1"))), Seq(SortOrder(Id("col2"), Ascending, NullsLast)))) + } + + "translate JSON path expressions" should { + "s.f.value:x.y.names[2]" in { + example( + "s.f.value:x.y.names[2]", + JsonAccess( + Dot(Dot(Id("s"), Id("f")), Id("value")), + Dot(Dot(Id("x"), Id("y")), ArrayAccess(Id("names"), Literal(2))))) + } + "x:inner_obj['nested_field']" in { + example("x:inner_obj['nested_field']", JsonAccess(Id("x"), JsonAccess(Id("inner_obj"), Id("nested_field")))) + } + } + + "translate JSON literals" in { + example("{'a': 1, 'b': 2}", StructExpr(Seq(Alias(Literal(1), Id("a")), Alias(Literal(2), Id("b"))))) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExpressionBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExpressionBuilderSpec.scala new file mode 100644 index 0000000000..420794d0b2 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeExpressionBuilderSpec.scala @@ -0,0 +1,536 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{ComparisonOperatorContext, ID, LiteralContext} +import org.antlr.v4.runtime.CommonToken +import org.mockito.Mockito._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class SnowflakeExpressionBuilderSpec + extends AnyWordSpec + with SnowflakeParserTestCommon + with Matchers + with MockitoSugar + with IRHelpers { + + override protected def astBuilder: SnowflakeExpressionBuilder = vc.expressionBuilder + + "SnowflakeExpressionBuilder" should { + "translate literals" should { + "null" in { + exampleExpr("null", _.literal(), Literal.Null) + } + "true" in { + exampleExpr("true", _.literal(), Literal.True) + } + "false" in { + exampleExpr("false", _.literal(), Literal.False) + } + "1" in { + exampleExpr("1", _.literal(), Literal(1)) + } + Int.MaxValue.toString in { + exampleExpr(Int.MaxValue.toString, _.literal(), Literal(Int.MaxValue)) + } + "-1" in { + exampleExpr("-1", _.literal(), Literal(-1)) + } + "1.1" in { + exampleExpr("1.1", _.literal(), Literal(1.1f)) + } + "1.1e2" in { + exampleExpr("1.1e2", _.literal(), Literal(110)) + } + Long.MaxValue.toString in { 
+ exampleExpr(Long.MaxValue.toString, _.literal(), Literal(Long.MaxValue)) + } + "1.1e-2" in { + exampleExpr("1.1e-2", _.literal(), Literal(0.011f)) + } + "0.123456789" in { + exampleExpr("0.123456789", _.literal(), Literal(0.123456789)) + } + "0.123456789e-1234" in { + exampleExpr("0.123456789e-1234", _.literal(), DecimalLiteral("0.123456789e-1234")) + } + "'foo'" in { + exampleExpr("'foo'", _.literal(), Literal("foo")) + } + "DATE'1970-01-01'" in { + exampleExpr("DATE'1970-01-01'", _.literal(), Literal(0, DateType)) + } + "TIMESTAMP'1970-01-01 00:00:00'" in { + exampleExpr("TIMESTAMP'1970-01-01 00:00:00'", _.literal(), Literal(0, TimestampType)) + } + } + + "translate ids (quoted or not)" should { + "foo" in { + exampleExpr("foo", _.id(), Id("foo")) + } + "\"foo\"" in { + exampleExpr("\"foo\"", _.id(), Id("foo", caseSensitive = true)) + } + "\"foo \"\"quoted bar\"\"\"" in { + exampleExpr("\"foo \"\"quoted bar\"\"\"", _.id(), Id("foo \"quoted bar\"", caseSensitive = true)) + } + } + + "translate column names" should { + "x" in { + exampleExpr("x", _.columnName(), simplyNamedColumn("x")) + } + "\"My Table\".x" in { + exampleExpr( + "\"My Table\".x", + _.columnName(), + Column(Some(ObjectReference(Id("My Table", caseSensitive = true))), Id("x"))) + } + } + + "translate column positions" should { + "$1" in { + exampleExpr("$1", _.columnElem(), Column(None, Position(1))) + } + } + + "translate aliases" should { + "x AS y" in { + exampleExpr("1 AS y", _.selectListElem(), Alias(Literal(1), Id("y"))) + } + "1 y" in { + exampleExpr("1 y", _.selectListElem(), Alias(Literal(1), Id("y"))) + } + } + + "translate simple numeric binary expressions" should { + "1 + 2" in { + exampleExpr("1 + 2", _.expr(), Add(Literal(1), Literal(2))) + } + "1 +2" in { + exampleExpr("1 +2", _.expr(), Add(Literal(1), Literal(2))) + } + "1 - 2" in { + exampleExpr("1 - 2", _.expr(), Subtract(Literal(1), Literal(2))) + } + "1 -2" in { + exampleExpr("1 -2", _.expr(), Subtract(Literal(1), Literal(2))) + } + "1 * 2" in { + exampleExpr("1 * 2", _.expr(), Multiply(Literal(1), Literal(2))) + } + "1 / 2" in { + exampleExpr("1 / 2", _.expr(), Divide(Literal(1), Literal(2))) + } + "1 % 2" in { + exampleExpr("1 % 2", _.expr(), Mod(Literal(1), Literal(2))) + } + "'A' || 'B'" in { + exampleExpr("'A' || 'B'", _.expr(), Concat(Seq(Literal("A"), Literal("B")))) + } + } + + "translate complex binary expressions" should { + "a + b * 2" in { + exampleExpr("a + b * 2", _.expr(), Add(Id("a"), Multiply(Id("b"), Literal(2)))) + } + "(a + b) * 2" in { + exampleExpr("(a + b) * 2", _.expr(), Multiply(Add(Id("a"), Id("b")), Literal(2))) + } + "a % 3 + b * 2 - c / 5" in { + exampleExpr( + "a % 3 + b * 2 - c / 5", + _.expr(), + Subtract(Add(Mod(Id("a"), Literal(3)), Multiply(Id("b"), Literal(2))), Divide(Id("c"), Literal(5)))) + } + "a || b || c" in { + exampleExpr("a || b || c", _.expr(), Concat(Seq(Concat(Seq(Id("a"), Id("b"))), Id("c")))) + } + } + + "correctly apply operator precedence and associativity" should { + "1 + -++-2" in { + exampleExpr("1 + -++-2", _.expr(), Add(Literal(1), UMinus(UPlus(UPlus(UMinus(Literal(2))))))) + } + "1 + -2 * 3" in { + exampleExpr("1 + -2 * 3", _.expr(), Add(Literal(1), Multiply(UMinus(Literal(2)), Literal(3)))) + } + "1 + -2 * 3 + 7 || 'leeds1' || 'leeds2' || 'leeds3'" in { + exampleExpr( + "1 + -2 * 3 + 7 || 'leeds1' || 'leeds2' || 'leeds3'", + _.expr(), + Concat( + Seq( + Concat( + Seq( + Concat( + Seq(Add(Add(Literal(1), Multiply(UMinus(Literal(2)), Literal(3))), Literal(7)), Literal("leeds1"))), + 
Literal("leeds2"))), + Literal("leeds3")))) + } + } + + "correctly respect explicit precedence with parentheses" should { + "(1 + 2) * 3" in { + exampleExpr("(1 + 2) * 3", _.expr(), Multiply(Add(Literal(1), Literal(2)), Literal(3))) + } + "1 + (2 * 3)" in { + exampleExpr("1 + (2 * 3)", _.expr(), Add(Literal(1), Multiply(Literal(2), Literal(3)))) + } + "(1 + 2) * (3 + 4)" in { + exampleExpr("(1 + 2) * (3 + 4)", _.expr(), Multiply(Add(Literal(1), Literal(2)), Add(Literal(3), Literal(4)))) + } + "1 + (2 * 3) + 4" in { + exampleExpr("1 + (2 * 3) + 4", _.expr(), Add(Add(Literal(1), Multiply(Literal(2), Literal(3))), Literal(4))) + } + "1 + (2 * 3 + 4)" in { + exampleExpr("1 + (2 * 3 + 4)", _.expr(), Add(Literal(1), Add(Multiply(Literal(2), Literal(3)), Literal(4)))) + } + "1 + (2 * (3 + 4))" in { + exampleExpr("1 + (2 * (3 + 4))", _.expr(), Add(Literal(1), Multiply(Literal(2), Add(Literal(3), Literal(4))))) + } + "(1 + (2 * (3 + 4)))" in { + exampleExpr("(1 + (2 * (3 + 4)))", _.expr(), Add(Literal(1), Multiply(Literal(2), Add(Literal(3), Literal(4))))) + } + } + + "translate functions with special syntax" should { + "EXTRACT(day FROM date1)" in { + exampleExpr( + "EXTRACT(day FROM date1)", + _.builtinFunction(), + CallFunction("EXTRACT", Seq(Id("day"), Id("date1")))) + } + + "EXTRACT('day' FROM date1)" in { + exampleExpr( + "EXTRACT('day' FROM date1)", + _.builtinFunction(), + CallFunction("EXTRACT", Seq(Id("day"), Id("date1")))) + } + + } + + "translate functions named with a keyword" should { + "LEFT(foo, bar)" in { + exampleExpr("LEFT(foo, bar)", _.standardFunction(), CallFunction("LEFT", Seq(Id("foo"), Id("bar")))) + } + "RIGHT(foo, bar)" in { + exampleExpr("RIGHT(foo, bar)", _.standardFunction(), CallFunction("RIGHT", Seq(Id("foo"), Id("bar")))) + } + } + + "translate aggregation functions" should { + "COUNT(x)" in { + exampleExpr("COUNT(x)", _.aggregateFunction(), CallFunction("COUNT", Seq(Id("x")))) + } + "AVG(x)" in { + exampleExpr("AVG(x)", _.aggregateFunction(), CallFunction("AVG", Seq(Id("x")))) + } + "SUM(x)" in { + exampleExpr("SUM(x)", _.aggregateFunction(), CallFunction("SUM", Seq(Id("x")))) + } + "MIN(x)" in { + exampleExpr("MIN(x)", _.aggregateFunction(), CallFunction("MIN", Seq(Id("x")))) + } + "COUNT(*)" in { + exampleExpr("COUNT(*)", _.aggregateFunction(), CallFunction("COUNT", Seq(Star(None)))) + } + "LISTAGG(x, ',')" in { + exampleExpr("LISTAGG(x, ',')", _.aggregateFunction(), CallFunction("LISTAGG", Seq(Id("x"), Literal(",")))) + } + "ARRAY_AGG(x)" in { + exampleExpr("ARRAY_AGG(x)", _.aggregateFunction(), CallFunction("ARRAY_AGG", Seq(Id("x")))) + } + } + + "translate a query with a window function" should { + "ROW_NUMBER() OVER (ORDER BY a DESC)" in { + exampleExpr( + "ROW_NUMBER() OVER (ORDER BY a DESC)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("ROW_NUMBER", Seq()), + partition_spec = Seq(), + sort_order = Seq(SortOrder(Id("a"), Descending, NullsFirst)), + frame_spec = None)) + } + "ROW_NUMBER() OVER (PARTITION BY a)" in { + exampleExpr( + "ROW_NUMBER() OVER (PARTITION BY a)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("ROW_NUMBER", Seq()), + partition_spec = Seq(Id("a")), + sort_order = Seq(), + frame_spec = None)) + } + "NTILE(42) OVER (PARTITION BY a ORDER BY b, c DESC, d)" in { + exampleExpr( + "NTILE(42) OVER (PARTITION BY a ORDER BY b, c DESC, d)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("NTILE", Seq(Literal(42))), + 
partition_spec = Seq(Id("a")), + sort_order = Seq( + SortOrder(Id("b"), Ascending, NullsLast), + SortOrder(Id("c"), Descending, NullsFirst), + SortOrder(Id("d"), Ascending, NullsLast)), + frame_spec = Some(WindowFrame(RowsFrame, UnboundedPreceding, UnboundedFollowing)))) + } + "LAST_VALUE(col_name) IGNORE NULLS OVER (PARTITION BY a ORDER BY b, c DESC, d)" in { + exampleExpr( + "LAST_VALUE(col_name) IGNORE NULLS OVER (PARTITION BY a ORDER BY b, c DESC, d)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("LAST_VALUE", Seq(Id("col_name"))), + partition_spec = Seq(Id("a")), + sort_order = Seq( + SortOrder(Id("b"), Ascending, NullsLast), + SortOrder(Id("c"), Descending, NullsFirst), + SortOrder(Id("d"), Ascending, NullsLast)), + frame_spec = Some(WindowFrame(RowsFrame, UnboundedPreceding, UnboundedFollowing)), + ignore_nulls = true)) + } + } + + "translate window frame specifications" should { + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)" in { + exampleExpr( + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("ROW_NUMBER", Seq()), + partition_spec = Seq(Id("a")), + sort_order = Seq(SortOrder(Id("a"), Ascending, NullsLast)), + frame_spec = Some(WindowFrame(RowsFrame, UnboundedPreceding, CurrentRow)))) + } + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)" in { + exampleExpr( + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("ROW_NUMBER", Seq()), + partition_spec = Seq(Id("a")), + sort_order = Seq(SortOrder(Id("a"), Ascending, NullsLast)), + frame_spec = Some(WindowFrame(RowsFrame, UnboundedPreceding, UnboundedFollowing)))) + } + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN 42 PRECEDING AND CURRENT ROW)" in { + exampleExpr( + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN 42 PRECEDING AND CURRENT ROW)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("ROW_NUMBER", Seq()), + partition_spec = Seq(Id("a")), + sort_order = Seq(SortOrder(Id("a"), Ascending, NullsLast)), + frame_spec = Some(WindowFrame(RowsFrame, PrecedingN(Literal(42)), CurrentRow)))) + } + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN 42 PRECEDING AND 42 FOLLOWING)" in { + exampleExpr( + "ROW_NUMBER() OVER(PARTITION BY a ORDER BY a ROWS BETWEEN 42 PRECEDING AND 42 FOLLOWING)", + _.rankingWindowedFunction(), + expectedAst = Window( + window_function = CallFunction("ROW_NUMBER", Seq()), + partition_spec = Seq(Id("a")), + sort_order = Seq(SortOrder(Id("a"), Ascending, NullsLast)), + frame_spec = Some(WindowFrame(RowsFrame, PrecedingN(Literal(42)), FollowingN(Literal(42)))))) + } + } + + "translate star-expressions" should { + "*" in { + exampleExpr("*", _.columnElemStar(), Star(None)) + } + "t.*" in { + exampleExpr("t.*", _.columnElemStar(), Star(Some(ObjectReference(Id("t"))))) + } + exampleExpr( + "db1.schema1.table1.*", + _.columnElemStar(), + Star(Some(ObjectReference(Id("db1"), Id("schema1"), Id("table1"))))) + } + + "translate scalar subquery" in { + exampleExpr( + query = "(SELECT col1 from table_expr)", + rule = _.expr(), + expectedAst = ScalarSubquery(Project(namedTable("table_expr"), Seq(Id("col1"))))) + } + } + + 
"SnowflakeExpressionBuilder.buildSortOrder" should { + + "translate ORDER BY a" in { + val tree = parseString("ORDER BY a", _.orderByClause()) + vc.expressionBuilder.buildSortOrder(tree) shouldBe Seq(SortOrder(Id("a"), Ascending, NullsLast)) + } + + "translate ORDER BY a ASC NULLS FIRST" in { + val tree = parseString("ORDER BY a ASC NULLS FIRST", _.orderByClause()) + vc.expressionBuilder.buildSortOrder(tree) shouldBe Seq(SortOrder(Id("a"), Ascending, NullsFirst)) + } + + "translate ORDER BY a DESC" in { + val tree = parseString("ORDER BY a DESC", _.orderByClause()) + vc.expressionBuilder.buildSortOrder(tree) shouldBe Seq(SortOrder(Id("a"), Descending, NullsFirst)) + } + + "translate ORDER BY a, b DESC" in { + val tree = parseString("ORDER BY a, b DESC", _.orderByClause()) + vc.expressionBuilder.buildSortOrder(tree) shouldBe Seq( + SortOrder(Id("a"), Ascending, NullsLast), + SortOrder(Id("b"), Descending, NullsFirst)) + } + + "translate ORDER BY a DESC NULLS LAST, b" in { + val tree = parseString("ORDER BY a DESC NULLS LAST, b", _.orderByClause()) + vc.expressionBuilder.buildSortOrder(tree) shouldBe Seq( + SortOrder(Id("a"), Descending, NullsLast), + SortOrder(Id("b"), Ascending, NullsLast)) + } + + "translate ORDER BY with many expressions" in { + val tree = parseString("ORDER BY a DESC, b, c ASC, d DESC NULLS LAST, e", _.orderByClause()) + vc.expressionBuilder.buildSortOrder(tree) shouldBe Seq( + SortOrder(Id("a"), Descending, NullsFirst), + SortOrder(Id("b"), Ascending, NullsLast), + SortOrder(Id("c"), Ascending, NullsLast), + SortOrder(Id("d"), Descending, NullsLast), + SortOrder(Id("e"), Ascending, NullsLast)) + } + + "translate EXISTS expressions" in { + exampleExpr("EXISTS (SELECT * FROM t)", _.predicate, Exists(Project(namedTable("t"), Seq(Star(None))))) + } + + // see https://github.com/databrickslabs/remorph/issues/273 + "translate NOT EXISTS expressions" ignore { + exampleExpr("NOT EXISTS (SELECT * FROM t)", _.expr(), Not(Exists(Project(namedTable("t"), Seq(Star(None)))))) + } + } + + "translate CASE expressions" should { + "CASE WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' END" in { + exampleExpr( + "CASE WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' END", + _.caseExpression(), + Case( + expression = None, + branches = scala.collection.immutable.Seq( + WhenBranch(Equals(Id("col1"), Literal(1)), Literal("one")), + WhenBranch(Equals(Id("col2"), Literal(2)), Literal("two"))), + otherwise = None)) + } + "CASE 'foo' WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' END" in { + exampleExpr( + "CASE 'foo' WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' END", + _.caseExpression(), + Case( + expression = Some(Literal("foo")), + branches = scala.collection.immutable.Seq( + WhenBranch(Equals(Id("col1"), Literal(1)), Literal("one")), + WhenBranch(Equals(Id("col2"), Literal(2)), Literal("two"))), + otherwise = None)) + } + "CASE WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' ELSE 'other' END" in { + exampleExpr( + "CASE WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' ELSE 'other' END", + _.caseExpression(), + Case( + expression = None, + branches = scala.collection.immutable.Seq( + WhenBranch(Equals(Id("col1"), Literal(1)), Literal("one")), + WhenBranch(Equals(Id("col2"), Literal(2)), Literal("two"))), + otherwise = Some(Literal("other")))) + } + "CASE 'foo' WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' ELSE 'other' END" in { + exampleExpr( + "CASE 'foo' WHEN col1 = 1 THEN 'one' WHEN col2 = 2 THEN 'two' ELSE 'other' END", + _.caseExpression(), + Case( + expression 
= Some(Literal("foo")), + branches = scala.collection.immutable.Seq( + WhenBranch(Equals(Id("col1"), Literal(1)), Literal("one")), + WhenBranch(Equals(Id("col2"), Literal(2)), Literal("two"))), + otherwise = Some(Literal("other")))) + } + } + + "SnowflakeExpressionBuilder.visit_Literal" should { + "handle unresolved child" in { + val literal = mock[LiteralContext] + vc.expressionBuilder.visitLiteral(literal) shouldBe Literal.Null + verify(literal).sign() + verify(literal).id() + verify(literal).TIMESTAMP() + verify(literal).string() + verify(literal).INT() + verify(literal).FLOAT() + verify(literal).REAL() + verify(literal).trueFalse() + verify(literal).NULL() + verify(literal).jsonLiteral() + verify(literal).arrayLiteral() + verifyNoMoreInteractions(literal) + } + } + + "SnowflakeExpressionBuilder.buildComparisonExpression" should { + "handle unresolved child" in { + val operator = mock[ComparisonOperatorContext] + val startTok = new CommonToken(ID, "%%%") + when(operator.getStart).thenReturn(startTok) + when(operator.getStop).thenReturn(startTok) + when(operator.getRuleIndex).thenReturn(SnowflakeParser.RULE_comparisonOperator) + vc.expressionBuilder.buildComparisonExpression(operator, null, null) shouldBe UnresolvedExpression( + ruleText = "Mocked string", + message = "Unknown comparison operator Mocked string in SnowflakeExpressionBuilder.buildComparisonExpression", + ruleName = "comparisonOperator", + tokenName = Some("ID")) + + verify(operator).EQ() + verify(operator).NE() + verify(operator).LTGT() + verify(operator).GT() + verify(operator).LT() + verify(operator).GE() + verify(operator).LE() + verify(operator).getRuleIndex + verify(operator, times(5)).getStart + verify(operator, times(2)).getStop + verifyNoMoreInteractions(operator) + } + } + + // Note that when we truly handle &vars, we will get Variable here and not 'Id' + // and the & parts will not be changed to ${} until we get to the final SQL generation, + // but we are in a half way house transition state + "variable substitution" should { + "&abc" in { + exampleExpr("&abc", _.expr(), Id("$abc")) + } + "&ab_c.bc_d" in { + exampleExpr("&ab_c.bc_d", _.expr(), Dot(Id("$ab_c"), Id("bc_d"))) + } + "&{ab_c}.&bc_d" in { + exampleExpr("&{ab_c}.&bc_d", _.expr(), Dot(Id("$ab_c"), Id("$bc_d"))) + } + } + + "translate :: casts" should { + "ARRAY_REMOVE([2, 3, 4.00::DOUBLE, 4, NULL], 4)" in { + exampleExpr( + "ARRAY_REMOVE([2, 3, 4.00::DOUBLE, 4, NULL], 4)", + _.expr(), + CallFunction( + "ARRAY_REMOVE", + Seq( + ArrayExpr( + Seq(Literal(2), Literal(3), Cast(Literal(4.00), DoubleType), Literal(4), Literal(null)), + IntegerType), + Literal(4)))) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParserTestCommon.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParserTestCommon.scala new file mode 100644 index 0000000000..fa60430263 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParserTestCommon.scala @@ -0,0 +1,22 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.parsers.{ErrorCollector, ParserTestCommon, ProductionErrorCollector} +import org.antlr.v4.runtime.{CharStream, TokenSource, TokenStream} +import org.scalatest.Assertions + +trait SnowflakeParserTestCommon extends ParserTestCommon[SnowflakeParser] { self: Assertions => + + protected val vc: SnowflakeVisitorCoordinator = + new SnowflakeVisitorCoordinator(SnowflakeParser.VOCABULARY, SnowflakeParser.ruleNames) + 
override final protected def makeLexer(chars: CharStream): TokenSource = new SnowflakeLexer(chars) + + override final protected def makeErrStrategy(): SnowflakeErrorStrategy = new SnowflakeErrorStrategy + + override final protected def makeErrListener(chars: String): ErrorCollector = + new ProductionErrorCollector(chars, "-- test string --") + + override final protected def makeParser(tokenStream: TokenStream): SnowflakeParser = { + val parser = new SnowflakeParser(tokenStream) + parser + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilderSpec.scala new file mode 100644 index 0000000000..ee493e8a8a --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilderSpec.scala @@ -0,0 +1,381 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.parsers.snowflake +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{JoinTypeContext, OuterJoinContext} +import org.antlr.v4.runtime.RuleContext +import org.mockito.Mockito._ +import org.scalatest.Assertion +import org.scalatest.Checkpoints.Checkpoint +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class SnowflakeRelationBuilderSpec + extends AnyWordSpec + with SnowflakeParserTestCommon + with Matchers + with MockitoSugar + with IRHelpers { + + override protected def astBuilder: SnowflakeRelationBuilder = vc.relationBuilder + + private def examples[R <: RuleContext]( + queries: Seq[String], + rule: SnowflakeParser => R, + expectedAst: LogicalPlan): Assertion = { + val cp = new Checkpoint() + queries.foreach(q => cp(example(q, rule, expectedAst))) + cp.reportAll() + succeed + } + + "SnowflakeRelationBuilder" should { + + "translate query with no FROM clause" in { + example("", _.selectOptionalClauses(), NoTable) + } + + "translate FROM clauses" should { + "FROM some_table" in { + example("FROM some_table", _.fromClause(), namedTable("some_table")) + } + "FROM t1, t2, t3" in { + example( + "FROM t1, t2, t3", + _.fromClause(), + Join( + Join( + namedTable("t1"), + namedTable("t2"), + None, + InnerJoin, + Seq(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + namedTable("t3"), + None, + InnerJoin, + Seq(), + JoinDataType(is_left_struct = false, is_right_struct = false))) + } + "FROM (SELECT * FROM t1) t2" in { + example( + "FROM (SELECT * FROM t1) t2", + _.fromClause(), + SubqueryAlias(Project(namedTable("t1"), Seq(Star(None))), Id("t2"), Seq())) + } + } + + "translate WHERE clauses" in { + example( + "FROM some_table WHERE 1=1", + _.selectOptionalClauses(), + Filter(namedTable("some_table"), Equals(Literal(1), Literal(1)))) + } + + "translate GROUP BY clauses" should { + "FROM some_table GROUP BY some_column" in { + example( + "FROM some_table GROUP BY some_column", + _.selectOptionalClauses(), + Aggregate( + namedTable("some_table"), + group_type = GroupBy, + grouping_expressions = Seq(simplyNamedColumn("some_column")), + pivot = None)) + } + "FROM t1 PIVOT (AVG(a) FOR d IN('x', 'y'))" in { + example( + "FROM t1 PIVOT (AVG(a) FOR d IN('x', 'y'))", + _.selectOptionalClauses(), + Aggregate( + namedTable("t1"), + group_type = Pivot, + grouping_expressions = Seq(CallFunction("AVG", Seq(simplyNamedColumn("a")))), + pivot = 
Some(Pivot(simplyNamedColumn("d"), Seq(Literal("x"), Literal("y")))))) + } + "FROM t1 PIVOT (COUNT(a) FOR d IN('x', 'y'))" in { + example( + "FROM t1 PIVOT (COUNT(a) FOR d IN('x', 'y'))", + _.selectOptionalClauses(), + Aggregate( + namedTable("t1"), + group_type = Pivot, + grouping_expressions = Seq(CallFunction("COUNT", Seq(simplyNamedColumn("a")))), + pivot = Some(Pivot(simplyNamedColumn("d"), Seq(Literal("x"), Literal("y")))))) + } + "FROM t1 PIVOT (MIN(a) FOR d IN('x', 'y'))" in { + example( + "FROM t1 PIVOT (MIN(a) FOR d IN('x', 'y'))", + _.selectOptionalClauses(), + Aggregate( + namedTable("t1"), + group_type = Pivot, + grouping_expressions = Seq(CallFunction("MIN", Seq(simplyNamedColumn("a")))), + pivot = Some(Pivot(simplyNamedColumn("d"), Seq(Literal("x"), Literal("y")))))) + } + } + + "translate ORDER BY clauses" should { + "FROM some_table ORDER BY some_column" in { + example( + "FROM some_table ORDER BY some_column", + _.selectOptionalClauses(), + Sort(namedTable("some_table"), Seq(SortOrder(Id("some_column"), Ascending, NullsLast)), is_global = false)) + } + "FROM some_table ORDER BY some_column ASC" in { + example( + "FROM some_table ORDER BY some_column ASC", + _.selectOptionalClauses(), + Sort(namedTable("some_table"), Seq(SortOrder(Id("some_column"), Ascending, NullsLast)), is_global = false)) + } + "FROM some_table ORDER BY some_column ASC NULLS FIRST" in { + example( + "FROM some_table ORDER BY some_column ASC NULLS FIRST", + _.selectOptionalClauses(), + Sort(namedTable("some_table"), Seq(SortOrder(Id("some_column"), Ascending, NullsFirst)), is_global = false)) + } + "FROM some_table ORDER BY some_column DESC" in { + example( + "FROM some_table ORDER BY some_column DESC", + _.selectOptionalClauses(), + Sort(namedTable("some_table"), Seq(SortOrder(Id("some_column"), Descending, NullsFirst)), is_global = false)) + } + "FROM some_table ORDER BY some_column DESC NULLS LAST" in { + example( + "FROM some_table ORDER BY some_column DESC NULLS LAST", + _.selectOptionalClauses(), + Sort(namedTable("some_table"), Seq(SortOrder(Id("some_column"), Descending, NullsLast)), is_global = false)) + } + "FROM some_table ORDER BY some_column DESC NULLS FIRST" in { + example( + "FROM some_table ORDER BY some_column DESC NULLS FIRST", + _.selectOptionalClauses(), + Sort(namedTable("some_table"), Seq(SortOrder(Id("some_column"), Descending, NullsFirst)), is_global = false)) + } + } + + "translate SAMPLE clauses" should { + "probabilistic" in { + examples( + Seq("t1 SAMPLE (1)", "t1 TABLESAMPLE (1)", "t1 SAMPLE BERNOULLI (1)", "t1 TABLESAMPLE BERNOULLI (1)"), + _.tableSource(), + TableSample(namedTable("t1"), RowSamplingProbabilistic(BigDecimal(1)), None)) + } + "fixed" in { + examples( + Seq( + "t1 SAMPLE (1 ROWS)", + "t1 TABLESAMPLE (1 ROWS)", + "t1 SAMPLE BERNOULLI (1 ROWS)", + "t1 TABLESAMPLE BERNOULLI (1 ROWS)"), + _.tableSource(), + TableSample(namedTable("t1"), RowSamplingFixedAmount(BigDecimal(1)), None)) + } + "block" in { + examples( + Seq("t1 SAMPLE BLOCK (1)", "t1 TABLESAMPLE BLOCK (1)", "t1 SAMPLE SYSTEM (1)", "t1 TABLESAMPLE SYSTEM (1)"), + _.tableSource(), + TableSample(namedTable("t1"), BlockSampling(BigDecimal(1)), None)) + } + "seed" in { + examples( + Seq("t1 SAMPLE (1) SEED (1234)", "t1 SAMPLE (1) REPEATABLE (1234)"), + _.tableSource(), + TableSample(namedTable("t1"), RowSamplingProbabilistic(BigDecimal(1)), Some(BigDecimal(1234)))) + } + } + + "translate combinations of the above" should { + "FROM some_table WHERE 1=1 GROUP BY some_column" in { + example( + "FROM some_table 
WHERE 1=1 GROUP BY some_column", + _.selectOptionalClauses(), + Aggregate( + child = Filter(namedTable("some_table"), Equals(Literal(1), Literal(1))), + group_type = GroupBy, + grouping_expressions = Seq(simplyNamedColumn("some_column")), + pivot = None)) + + } + "FROM some_table WHERE 1=1 GROUP BY some_column ORDER BY some_column NULLS FIRST" in { + example( + "FROM some_table WHERE 1=1 GROUP BY some_column ORDER BY some_column NULLS FIRST", + _.selectOptionalClauses(), + Sort( + Aggregate( + child = Filter(namedTable("some_table"), Equals(Literal(1), Literal(1))), + group_type = GroupBy, + grouping_expressions = Seq(simplyNamedColumn("some_column")), + pivot = None), + Seq(SortOrder(Id("some_column"), Ascending, NullsFirst)), + is_global = false)) + } + "FROM some_table WHERE 1=1 ORDER BY some_column NULLS FIRST" in { + example( + "FROM some_table WHERE 1=1 ORDER BY some_column NULLS FIRST", + _.selectOptionalClauses(), + Sort( + Filter(namedTable("some_table"), Equals(Literal(1), Literal(1))), + Seq(SortOrder(Id("some_column"), Ascending, NullsFirst)), + is_global = false)) + } + } + + "translate CTE definitions" should { + "WITH a AS (SELECT x, y FROM d)" in { + example( + "WITH a AS (SELECT x, y FROM d)", + _.withExpression(), + SubqueryAlias(Project(namedTable("d"), Seq(Id("x"), Id("y"))), Id("a"), Seq())) + } + "WITH a (b, c) AS (SELECT x, y FROM d)" in { + example( + "WITH a (b, c) AS (SELECT x, y FROM d)", + _.withExpression(), + SubqueryAlias(Project(namedTable("d"), Seq(Id("x"), Id("y"))), Id("a"), Seq(Id("b"), Id("c")))) + } + } + + "translate QUALIFY clauses" in { + example( + "FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", + _.selectOptionalClauses(), + Filter( + input = namedTable("qt"), + condition = Equals( + Window( + window_function = CallFunction("ROW_NUMBER", Seq()), + partition_spec = Seq(Id("p")), + sort_order = Seq(SortOrder(Id("o"), Ascending, NullsLast)), + frame_spec = None), + Literal(1)))) + } + + "translate SELECT DISTINCT clauses" in { + example( + "SELECT DISTINCT a, b AS bb FROM t", + _.selectStatement(), + Deduplicate( + namedTable("t"), + column_names = Seq(Id("a"), Alias(Id("b"), Id("bb"))), + all_columns_as_keys = false, + within_watermark = false)) + } + + "translate SELECT TOP clauses" should { + "SELECT TOP 42 a FROM t" in { + example( + "SELECT TOP 42 a FROM t", + _.selectStatement(), + Project(Limit(namedTable("t"), Literal(42)), Seq(Id("a")))) + } + "SELECT DISTINCT TOP 42 a FROM t" in { + example( + "SELECT DISTINCT TOP 42 a FROM t", + _.selectStatement(), + Limit( + Deduplicate(namedTable("t"), Seq(Id("a")), all_columns_as_keys = false, within_watermark = false), + Literal(42))) + } + } + + "translate VALUES clauses as object references" in { + example( + "VALUES ('a', 1), ('b', 2)", + _.objectRef(), + Values(Seq(Seq(Literal("a"), Literal(1)), Seq(Literal("b"), Literal(2))))) + } + + "do not confuse VALUES clauses with a single row with a function call" in { + example("VALUES (1, 2, 3)", _.objectRef(), Values(Seq(Seq(Literal(1), Literal(2), Literal(3))))) + } + + "translate table functions as object references" should { + "TABLE(some_func(some_arg))" in { + example( + "TABLE(some_func(some_arg))", + _.objectRef(), + TableFunction( + UnresolvedFunction( + "some_func", + Seq(Id("some_arg")), + is_distinct = false, + is_user_defined_function = false, + ruleText = "some_func(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function some_func is not convertible to Databricks SQL"))) + } + 
"TABLE(some_func(some_arg)) t(c1, c2, c3)" in { + example( + "TABLE(some_func(some_arg)) t(c1, c2, c3)", + _.objectRef(), + SubqueryAlias( + TableFunction( + UnresolvedFunction( + "some_func", + Seq(Id("some_arg")), + is_distinct = false, + is_user_defined_function = false, + ruleText = "some_func(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function some_func is not convertible to Databricks SQL")), + Id("t"), + Seq(Id("c1"), Id("c2"), Id("c3")))) + } + } + + "translate LATERAL FLATTEN object references" should { + "LATERAL FLATTEN (input => some_col, OUTER => true)" in { + example( + "LATERAL FLATTEN (input => some_col, OUTER => true)", + _.objectRef(), + Lateral( + TableFunction( + CallFunction( + "FLATTEN", + Seq( + snowflake.NamedArgumentExpression("INPUT", Id("some_col")), + snowflake.NamedArgumentExpression("OUTER", Literal.True)))))) + } + "LATERAL FLATTEN (input => some_col) AS t" in { + example( + "LATERAL FLATTEN (input => some_col) AS t", + _.objectRef(), + SubqueryAlias( + Lateral( + TableFunction(CallFunction("FLATTEN", Seq(snowflake.NamedArgumentExpression("INPUT", Id("some_col")))))), + Id("t"), + Seq())) + } + } + } + + "Unparsed input" should { + "be reported as UnresolvedRelation" in { + example( + "MATCH_RECOGNIZE()", + _.matchRecognize(), + UnresolvedRelation( + ruleText = "MATCH_RECOGNIZE()", + message = "Unimplemented visitor visitMatchRecognize in class SnowflakeRelationBuilder", + ruleName = "matchRecognize", + tokenName = Some("MATCH_RECOGNIZE"))) + } + } + + "SnowflakeRelationBuilder.translateJoinType" should { + "handle unresolved join type" in { + val outerJoin = mock[OuterJoinContext] + val joinType = mock[JoinTypeContext] + when(joinType.outerJoin()).thenReturn(outerJoin) + vc.relationBuilder.translateJoinType(joinType) shouldBe UnspecifiedJoin + verify(outerJoin).LEFT() + verify(outerJoin).RIGHT() + verify(outerJoin).FULL() + verify(joinType).outerJoin() + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeTypeBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeTypeBuilderSpec.scala new file mode 100644 index 0000000000..27e624a879 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeTypeBuilderSpec.scala @@ -0,0 +1,112 @@ +package com.databricks.labs.remorph.parsers.snowflake + +import com.databricks.labs.remorph.intermediate._ +import org.antlr.v4.runtime.tree.ParseTreeVisitor +import org.scalatest.Assertion +import org.scalatest.matchers.should +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeTypeBuilderSpec extends AnyWordSpec with SnowflakeParserTestCommon with should.Matchers { + + private def example(query: String, expectedDataType: DataType): Assertion = { + assert(new SnowflakeTypeBuilder().buildDataType(parseString(query, _.dataType())) === expectedDataType) + } + "SnowflakeTypeBuilder" should { + "precision types" in { + example("CHAR(1)", StringType) + example("CHARACTER(1)", StringType) + example("VARCHAR(10)", StringType) + example("DECIMAL(10, 2)", DecimalType(Some(10), Some(2))) + example("NUMBER(10, 2)", DecimalType(Some(10), Some(2))) + example("NUMERIC(10, 2)", DecimalType(Some(10), Some(2))) + } + + "ARRAY(INTEGER)" in { + example("ARRAY(INTEGER)", ArrayType(defaultNumber)) + } + "ARRAY" in { + example("ARRAY", ArrayType(UnresolvedType)) + } + "BIGINT" in { + example("BIGINT", defaultNumber) + } + "BINARY" in { + example("BINARY", BinaryType) + } + "BOOLEAN" in { + 
example("BOOLEAN", BooleanType) + } + "BYTEINT" in { + example("BYTEINT", defaultNumber) + } + "DATE" in { + example("DATE", DateType) + } + "DOUBLE" in { + example("DOUBLE", DoubleType) + } + "DOUBLE PRECISION" in { + example("DOUBLE PRECISION", DoubleType) + } + "FLOAT" in { + example("FLOAT", DoubleType) + } + "FLOAT4" in { + example("FLOAT4", DoubleType) + } + "FLOAT8" in { + example("FLOAT8", DoubleType) + } + "INT" in { + example("INT", defaultNumber) + } + "INTEGER" in { + example("INTEGER", defaultNumber) + } + "OBJECT" in { + example("OBJECT", UnparsedType("OBJECT")) + } + "REAL" in { + example("REAL", DoubleType) + } + "SMALLINT" in { + example("SMALLINT", defaultNumber) + } + "STRING" in { + example("STRING", StringType) + } + "TEXT" in { + example("TEXT", StringType) + } + "TIME" in { + example("TIME", TimestampType) + } + "TIMESTAMP" in { + example("TIMESTAMP", TimestampType) + } + "TIMESTAMP_LTZ" in { + example("TIMESTAMP_LTZ", TimestampType) + } + "TIMESTAMP_NTZ" in { + example("TIMESTAMP_NTZ", TimestampNTZType) + } + "TIMESTAMP_TZ" in { + example("TIMESTAMP_TZ", TimestampType) + } + "TINYINT" in { + example("TINYINT", TinyintType) + } + "VARBINARY" in { + example("VARBINARY", BinaryType) + } + "VARIANT" in { + example("VARIANT", VariantType) + } + } + + private def defaultNumber = { + DecimalType(Some(38), Some(0)) + } + + override protected def astBuilder: ParseTreeVisitor[_] = null +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasLCAsSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasLCAsSpec.scala new file mode 100644 index 0000000000..372cdb29cc --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/DealiasLCAsSpec.scala @@ -0,0 +1,28 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.{intermediate => ir} + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class DealiasLCAsSpec extends AnyWordSpec with Matchers { + + val dealiaser = new DealiasLCAs + + "DealiasLCAs" should { + + "dealias a LCA" in { + val plan = + ir.Project( + ir.Filter(ir.NoTable, ir.GreaterThan(ir.Id("abs"), ir.Literal(42))), + Seq(ir.Alias(ir.Abs(ir.Literal(-42)), ir.Id("abs")))) + + dealiaser.transformPlan(plan) shouldBe + ir.Project( + ir.Filter(ir.NoTable, ir.GreaterThan(ir.Abs(ir.Literal(-42)), ir.Literal(42))), + Seq(ir.Alias(ir.Abs(ir.Literal(-42)), ir.Id("abs")))) + + } + + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenNestedConcatSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenNestedConcatSpec.scala new file mode 100644 index 0000000000..c915429b60 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/FlattenNestedConcatSpec.scala @@ -0,0 +1,47 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalactic.source.Position +import org.scalatest.Assertion +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class FlattenNestedConcatSpec extends AnyWordSpec with Matchers { + + val optimizer = new FlattenNestedConcat + + implicit class Ops(e: ir.Expression) { + def becomes(expected: ir.Expression)(implicit pos: Position): Assertion = { + optimizer.flattenConcat.applyOrElse(e, (x: ir.Expression) => x) shouldBe expected + } + } + + 
"FlattenNestedConcat" should { + "CONCAT(a, b)" in { + ir.Concat(Seq(ir.Id("a"), ir.Id("b"))) becomes ir.Concat(Seq(ir.Id("a"), ir.Id("b"))) + } + + "CONCAT(CONCAT(a, b), c)" in { + ir.Concat(Seq(ir.Concat(Seq(ir.Id("a"), ir.Id("b"))), ir.Id("c"))) becomes ir.Concat( + Seq(ir.Id("a"), ir.Id("b"), ir.Id("c"))) + } + + "CONCAT(a, CONCAT(b, c))" in { + ir.Concat(Seq(ir.Id("a"), ir.Concat(Seq(ir.Id("b"), ir.Id("c"))))) becomes ir.Concat( + Seq(ir.Id("a"), ir.Id("b"), ir.Id("c"))) + } + + "CONCAT(CONCAT(a, b), CONCAT(c, d))" in { + ir.Concat(Seq(ir.Concat(Seq(ir.Id("a"), ir.Id("b"))), ir.Concat(Seq(ir.Id("c"), ir.Id("d"))))) becomes ir.Concat( + Seq(ir.Id("a"), ir.Id("b"), ir.Id("c"), ir.Id("d"))) + } + + "CONCAT(CONCAT(a, b), CONCAT(c, CONCAT(d, e)))" in { + ir.Concat( + Seq( + ir.Concat(Seq(ir.Id("a"), ir.Id("b"))), + ir.Concat(Seq(ir.Id("c"), ir.Concat(Seq(ir.Id("d"), ir.Id("e"))))))) becomes ir.Concat( + Seq(ir.Id("a"), ir.Id("b"), ir.Id("c"), ir.Id("d"), ir.Id("e"))) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeCallMapperSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeCallMapperSpec.scala new file mode 100644 index 0000000000..1483f8ea28 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/rules/SnowflakeCallMapperSpec.scala @@ -0,0 +1,318 @@ +package com.databricks.labs.remorph.parsers.snowflake.rules + +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.Assertion +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeCallMapperSpec extends AnyWordSpec with Matchers { + + private[this] val snowflakeCallMapper = new SnowflakeCallMapper + + implicit class CallMapperOps(fn: ir.Fn) { + def becomes(expected: ir.Expression): Assertion = { + snowflakeCallMapper.convert(fn) shouldBe expected + } + } + + "SnowflakeCallMapper" should { + "translate Snowflake functions" in { + + ir.CallFunction("ARRAY_CAT", Seq(ir.Noop)) becomes ir.Concat(Seq(ir.Noop)) + + ir.CallFunction("ARRAY_CONSTRUCT", Seq(ir.Noop)) becomes ir.CreateArray(Seq(ir.Noop)) + + ir.CallFunction("BOOLAND_AGG", Seq(ir.Noop)) becomes ir.BoolAnd(ir.Noop) + + ir.CallFunction("DATEADD", Seq(ir.Literal(1), ir.Literal(2))) becomes ir.DateAdd(ir.Literal(1), ir.Literal(2)) + + ir.CallFunction("EDITDISTANCE", Seq(ir.Literal(1), ir.Literal(2))) becomes ir.Levenshtein( + ir.Literal(1), + ir.Literal(2), + None) + + ir.CallFunction("IFNULL", Seq(ir.Noop)) becomes ir.Coalesce(Seq(ir.Noop)) + + ir.CallFunction("JSON_EXTRACT_PATH_TEXT", Seq(ir.Noop, ir.Literal("foo"))) becomes ir.GetJsonObject( + ir.Noop, + ir.Literal("$.foo")) + + ir.CallFunction("JSON_EXTRACT_PATH_TEXT", Seq(ir.Noop, ir.Id("foo"))) becomes ir.GetJsonObject( + ir.Noop, + ir.Concat(Seq(ir.Literal("$."), ir.Id("foo")))) + + ir.CallFunction("LEN", Seq(ir.Noop)) becomes ir.Length(ir.Noop) + + ir.CallFunction("LISTAGG", Seq(ir.Literal(1), ir.Literal(2))) becomes ir.ArrayJoin( + ir.CollectList(ir.Literal(1), None), + ir.Literal(2), + None) + + ir.CallFunction("MONTHNAME", Seq(ir.Noop)) becomes ir.DateFormatClass(ir.Noop, ir.Literal("MMM")) + + ir.CallFunction("OBJECT_KEYS", Seq(ir.Noop)) becomes ir.JsonObjectKeys(ir.Noop) + + ir.CallFunction("POSITION", Seq(ir.Noop)) becomes ir.CallFunction("LOCATE", Seq(ir.Noop)) + + ir.CallFunction("REGEXP_LIKE", Seq(ir.Literal(1), ir.Literal(2))) becomes ir.RLike(ir.Literal(1), ir.Literal(2)) + + ir.CallFunction("SPLIT_PART", 
Seq(ir.Literal("foo,bar"), ir.Literal(","), ir.Literal(0))) becomes ir + .StringSplitPart(ir.Literal("foo,bar"), ir.Literal(","), ir.Literal(1)) + + ir.CallFunction("SPLIT_PART", Seq(ir.Literal("foo,bar"), ir.Literal(","), ir.Literal(1))) becomes ir + .StringSplitPart(ir.Literal("foo,bar"), ir.Literal(","), ir.Literal(1)) + + ir.CallFunction("SPLIT_PART", Seq(ir.Literal("foo,bar"), ir.Literal(","), ir.Literal(4))) becomes ir + .StringSplitPart(ir.Literal("foo,bar"), ir.Literal(","), ir.Literal(4)) + + ir.CallFunction("SPLIT_PART", Seq(ir.Literal("foo,bar"), ir.Literal(","), ir.Id("c1"))) becomes ir + .StringSplitPart( + ir.Literal("foo,bar"), + ir.Literal(","), + ir.If(ir.Equals(ir.Id("c1"), ir.Literal(0)), ir.Literal(1), ir.Id("c1"))) + + ir.CallFunction("SQUARE", Seq(ir.Noop)) becomes ir.Pow(ir.Noop, ir.Literal(2)) + + ir.CallFunction("STRTOK_TO_ARRAY", Seq(ir.Literal("abc,def"), ir.Literal(","))) becomes ir.StringSplit( + ir.Literal("abc,def"), + ir.Literal("[,]"), + None) + + ir.CallFunction("STRTOK_TO_ARRAY", Seq(ir.Literal("abc,def"), ir.Id("c1"))) becomes ir.StringSplit( + ir.Literal("abc,def"), + ir.Concat(Seq(ir.Literal("["), ir.Id("c1"), ir.Literal("]"))), + None) + + ir.CallFunction("TO_DOUBLE", Seq(ir.Noop)) becomes ir.CallFunction("DOUBLE", Seq(ir.Noop)) + + ir.CallFunction("TO_NUMBER", Seq(ir.Literal("$123.5"), ir.Literal("$999.0"))) becomes ir.ToNumber( + ir.Literal("$123.5"), + ir.Literal("$999.0")) + + ir.CallFunction("TO_NUMBER", Seq(ir.Literal("$123.5"), ir.Literal("$999.0"), ir.Literal(26))) becomes ir.Cast( + ir.ToNumber(ir.Literal("$123.5"), ir.Literal("$999.0")), + ir.DecimalType(Some(26), None)) + + ir.CallFunction( + "TO_NUMBER", + Seq(ir.Literal("$123.5"), ir.Literal("$999.0"), ir.Literal(26), ir.Literal(4))) becomes ir.Cast( + ir.ToNumber(ir.Literal("$123.5"), ir.Literal("$999.0")), + ir.DecimalType(Some(26), Some(4))) + + ir.CallFunction("TO_NUMBER", Seq(ir.Literal("$123.5"), ir.Literal(26), ir.Literal(4))) becomes ir.Cast( + ir.Literal("$123.5"), + ir.DecimalType(Some(26), Some(4))) + + ir.CallFunction("TO_OBJECT", Seq(ir.Literal(1), ir.Literal(2))) becomes ir.StructsToJson( + ir.Literal(1), + Some(ir.Literal(2))) + + ir.CallFunction("TRY_TO_NUMBER", Seq(ir.Literal("$123.5"), ir.Literal("$999.0"), ir.Literal(26))) becomes ir.Cast( + ir.TryToNumber(ir.Literal("$123.5"), ir.Literal("$999.0")), + ir.DecimalType(Some(26), Some(0))) + + ir.CallFunction( + "TRY_TO_NUMBER", + Seq(ir.Literal("$123.5"), ir.Literal("$999.0"), ir.Literal(26), ir.Literal(4))) becomes ir.Cast( + ir.TryToNumber(ir.Literal("$123.5"), ir.Literal("$999.0")), + ir.DecimalType(Some(26), Some(4))) + + ir.CallFunction("TRY_TO_NUMBER", Seq(ir.Literal("$123.5"), ir.Literal(26), ir.Literal(4))) becomes ir.Cast( + ir.Literal("$123.5"), + ir.DecimalType(Some(26), Some(4))) + + ir.CallFunction( + "MONTHS_BETWEEN", + Seq(ir.Cast(ir.Literal("2021-01-01"), ir.DateType), ir.Cast(ir.Literal("2021-02-01"), ir.DateType))) becomes ir + .MonthsBetween( + ir.Cast(ir.Literal("2021-01-01"), ir.DateType), + ir.Cast(ir.Literal("2021-02-01"), ir.DateType), + ir.Literal.True) + + ir.CallFunction( + "MONTHS_BETWEEN", + Seq( + ir.Cast(ir.Literal("2020-05-01 10:00:00"), ir.TimestampType), + ir.Cast(ir.Literal("2020-04-15 08:00:00"), ir.TimestampType))) becomes ir.MonthsBetween( + ir.Cast(ir.Literal("2020-05-01 10:00:00"), ir.TimestampType), + ir.Cast(ir.Literal("2020-04-15 08:00:00"), ir.TimestampType), + ir.Literal.True) + + ir.CallFunction( + "ARRAY_SORT", + Seq( + ir.CreateArray( + Seq(ir.Literal(0), ir.Literal(2), 
ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+          ir.Literal.True,
+          ir.Literal.True)) becomes ir.SortArray(
+        ir.CreateArray(
+          Seq(ir.Literal(0), ir.Literal(2), ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+        None)
+
+      ir.CallFunction(
+        "ARRAY_SORT",
+        Seq(
+          ir.CreateArray(
+            Seq(ir.Literal(0), ir.Literal(2), ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+          ir.Literal.False,
+          ir.Literal.False)) becomes ir.SortArray(
+        ir.CreateArray(
+          Seq(ir.Literal(0), ir.Literal(2), ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+        Some(ir.Literal.False))
+
+      ir.CallFunction(
+        "ARRAY_SORT",
+        Seq(
+          ir.CreateArray(
+            Seq(ir.Literal(0), ir.Literal(2), ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+          ir.Literal.True,
+          ir.Literal.False)) becomes ir.ArraySort(
+        ir.CreateArray(
+          Seq(ir.Literal(0), ir.Literal(2), ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+        ir.LambdaFunction(
+          ir.Case(
+            None,
+            Seq(
+              ir.WhenBranch(ir.And(ir.IsNull(ir.Id("left")), ir.IsNull(ir.Id("right"))), ir.Literal(0)),
+              ir.WhenBranch(ir.IsNull(ir.Id("left")), ir.Literal(1)),
+              ir.WhenBranch(ir.IsNull(ir.Id("right")), ir.Literal(-1)),
+              ir.WhenBranch(ir.LessThan(ir.Id("left"), ir.Id("right")), ir.Literal(-1)),
+              ir.WhenBranch(ir.GreaterThan(ir.Id("left"), ir.Id("right")), ir.Literal(1))),
+            Some(ir.Literal(0))),
+          Seq(ir.UnresolvedNamedLambdaVariable(Seq("left")), ir.UnresolvedNamedLambdaVariable(Seq("right")))))
+
+      ir.CallFunction(
+        "ARRAY_SORT",
+        Seq(
+          ir.CreateArray(
+            Seq(ir.Literal(0), ir.Literal(2), ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+          ir.Literal.False,
+          ir.Literal.True)) becomes ir.ArraySort(
+        ir.CreateArray(
+          Seq(ir.Literal(0), ir.Literal(2), ir.Literal(4), ir.Literal.Null, ir.Literal(5), ir.Literal.Null)),
+        ir.LambdaFunction(
+          ir.Case(
+            None,
+            Seq(
+              ir.WhenBranch(ir.And(ir.IsNull(ir.Id("left")), ir.IsNull(ir.Id("right"))), ir.Literal(0)),
+              ir.WhenBranch(ir.IsNull(ir.Id("left")), ir.Literal(-1)),
+              ir.WhenBranch(ir.IsNull(ir.Id("right")), ir.Literal(1)),
+              ir.WhenBranch(ir.LessThan(ir.Id("left"), ir.Id("right")), ir.Literal(1)),
+              ir.WhenBranch(ir.GreaterThan(ir.Id("left"), ir.Id("right")), ir.Literal(-1))),
+            Some(ir.Literal(0))),
+          Seq(ir.UnresolvedNamedLambdaVariable(Seq("left")), ir.UnresolvedNamedLambdaVariable(Seq("right")))))
+
+      ir.CallFunction("SUBSTR", Seq(ir.Literal("Hello"), ir.Literal(1), ir.Literal(3))) becomes
+        ir.Substring(ir.Literal("Hello"), ir.Literal(1), Some(ir.Literal(3)))
+    }
+
+    "ARRAY_SLICE index shift" in {
+      ir.CallFunction("ARRAY_SLICE", Seq(ir.Id("arr1"), ir.IntLiteral(0), ir.IntLiteral(2))) becomes ir.Slice(
+        ir.Id("arr1"),
+        ir.IntLiteral(1),
+        ir.IntLiteral(2))
+
+      ir.CallFunction("ARRAY_SLICE", Seq(ir.Id("arr1"), ir.UMinus(ir.IntLiteral(2)), ir.IntLiteral(2))) becomes ir
+        .Slice(ir.Id("arr1"), ir.UMinus(ir.IntLiteral(2)), ir.IntLiteral(2))
+
+      ir.CallFunction("ARRAY_SLICE", Seq(ir.Id("arr1"), ir.Id("col1"), ir.IntLiteral(2))) becomes ir
+        .Slice(
+          ir.Id("arr1"),
+          ir.If(
+            ir.GreaterThanOrEqual(ir.Id("col1"), ir.IntLiteral(0)),
+            ir.Add(ir.Id("col1"), ir.IntLiteral(1)),
+            ir.Id("col1")),
+          ir.IntLiteral(2))
+    }
+    "REGEXP_SUBSTR" in {
+      ir.CallFunction("REGEXP_SUBSTR", Seq(ir.Literal("foo"), ir.Literal("f.."))) becomes
+        ir.RegExpExtract(ir.Literal("foo"), ir.Literal("f.."), Some(ir.Literal(0)))
+
+      ir.CallFunction("REGEXP_SUBSTR", Seq(ir.Literal("foo"), ir.Literal("f.."), ir.Literal(1))) becomes
+        
ir.RegExpExtract(ir.Substring(ir.Literal("foo"), ir.Literal(1)), ir.Literal("f.."), Some(ir.Literal(0))) + + ir.CallFunction("REGEXP_SUBSTR", Seq(ir.Literal("foo"), ir.Literal("f.."), ir.Literal(1), ir.Literal(1))) becomes + ir.ArrayAccess( + ir.RegExpExtractAll(ir.Substring(ir.Literal("foo"), ir.Literal(1)), ir.Literal("f.."), Some(ir.Literal(0))), + ir.Literal(0)) + + ir.CallFunction( + "REGEXP_SUBSTR", + Seq(ir.Literal("foo"), ir.Literal("f.."), ir.Literal(1), ir.Id("occurrence"))) becomes + ir.ArrayAccess( + ir.RegExpExtractAll(ir.Substring(ir.Literal("foo"), ir.Literal(1)), ir.Literal("f.."), Some(ir.Literal(0))), + ir.Subtract(ir.Id("occurrence"), ir.Literal(1))) + + ir.CallFunction( + "REGEXP_SUBSTR", + Seq(ir.Literal("foo"), ir.Literal("f.."), ir.Literal(1), ir.Id("occurrence"), ir.Literal("icmes"))) becomes + ir.ArrayAccess( + ir.RegExpExtractAll( + ir.Substring(ir.Literal("foo"), ir.Literal(1)), + ir.Literal("(?ms)f.."), + Some(ir.Literal(0))), + ir.Subtract(ir.Id("occurrence"), ir.Literal(1))) + + ir.CallFunction( + "REGEXP_SUBSTR", + Seq(ir.Literal("foo"), ir.Literal("f.."), ir.Literal(1), ir.Id("occurrence"), ir.Literal("icmesi"))) becomes + ir.ArrayAccess( + ir.RegExpExtractAll( + ir.Substring(ir.Literal("foo"), ir.Literal(1)), + ir.Literal("(?msi)f.."), + Some(ir.Literal(0))), + ir.Subtract(ir.Id("occurrence"), ir.Literal(1))) + + ir.CallFunction( + "REGEXP_SUBSTR", + Seq(ir.Literal("foo"), ir.Literal("f.."), ir.Literal(1), ir.Id("occurrence"), ir.Id("regex_params"))) becomes + ir.ArrayAccess( + ir.RegExpExtractAll( + ir.Substring(ir.Literal("foo"), ir.Literal(1)), + ir.ArrayAggregate( + ir.StringSplit(ir.Id("regex_params"), ir.Literal(""), None), + ir.Cast(ir.CreateArray(Seq()), ir.ArrayType(ir.StringType)), + ir.LambdaFunction( + ir.Case( + expression = None, + branches = Seq( + ir.WhenBranch( + ir.Equals(ir.Id("item"), ir.Literal("c")), + ir.ArrayFilter( + ir.Id("agg"), + ir.LambdaFunction( + ir.NotEquals(ir.Id("item"), ir.Literal("i")), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("item")))))), + ir.WhenBranch( + ir.In(ir.Id("item"), Seq(ir.Literal("i"), ir.Literal("s"), ir.Literal("m"))), + ir.ArrayAppend(ir.Id("agg"), ir.Id("item")))), + otherwise = Some(ir.Id("agg"))), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("agg")), ir.UnresolvedNamedLambdaVariable(Seq("item")))), + ir.LambdaFunction( + ir.Concat( + Seq( + ir.Literal("(?"), + ir.ArrayJoin(ir.Id("filtered"), ir.Literal("")), + ir.Literal(")"), + ir.Literal("f.."))), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("filtered"))))), + Some(ir.Literal(0))), + ir.Subtract(ir.Id("occurrence"), ir.Literal(1))) + + ir.CallFunction( + "REGEXP_SUBSTR", + Seq( + ir.Literal("foo"), + ir.Literal("(f.)."), + ir.Literal(1), + ir.Id("occurrence"), + ir.Literal("icmesi"), + ir.Literal(1))) becomes + ir.ArrayAccess( + ir.RegExpExtractAll( + ir.Substring(ir.Literal("foo"), ir.Literal(1)), + ir.Literal("(?msi)(f.)."), + Some(ir.Literal(1))), + ir.Subtract(ir.Id("occurrence"), ir.Literal(1))) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlAstBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlAstBuilderSpec.scala new file mode 100644 index 0000000000..fc7484896a --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlAstBuilderSpec.scala @@ -0,0 +1,983 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.intermediate._ +import com.databricks.labs.remorph.parsers.tsql +import 
com.databricks.labs.remorph.parsers.tsql.rules.TopPercent +import org.mockito.Mockito.{mock, when} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TSqlAstBuilderSpec extends AnyWordSpec with TSqlParserTestCommon with Matchers with IRHelpers { + + override protected def astBuilder: TSqlParserBaseVisitor[_] = vc.astBuilder + + private def example(query: String, expectedAst: LogicalPlan): Unit = + example(query, _.tSqlFile(), expectedAst) + + private def singleQueryExample(query: String, expectedAst: LogicalPlan): Unit = + example(query, _.tSqlFile(), Batch(Seq(expectedAst))) + + "tsql visitor" should { + + "accept empty child" in { + example(query = "", expectedAst = Batch(Seq.empty)) + } + + "translate a simple SELECT query" in { + example( + query = "SELECT a FROM dbo.table_x", + expectedAst = Batch(Seq(Project(namedTable("dbo.table_x"), Seq(simplyNamedColumn("a")))))) + + example( + query = "SELECT a FROM TABLE", + expectedAst = Batch(Seq(Project(namedTable("TABLE"), Seq(simplyNamedColumn("a")))))) + } + + "translate column aliases" in { + example( + query = "SELECT a AS b, J = BigCol FROM dbo.table_x", + expectedAst = Batch( + Seq( + Project( + namedTable("dbo.table_x"), + Seq(Alias(simplyNamedColumn("a"), Id("b")), Alias(simplyNamedColumn("BigCol"), Id("J"))))))) + } + + "accept constants in selects" in { + example( + query = "SELECT 42, 65535, 6.4, 0x5A, 2.7E9, 4.24523534425245E10, $40", + expectedAst = Batch( + Seq( + Project( + NoTable, + Seq( + Literal(42), + Literal(65535), + Literal(6.4f), + Literal("0x5A"), + Literal(2700000000L), + Literal(4.24523534425245e10), + Money(StringLiteral("$40"))))))) + } + + "translate collation specifiers" in { + example( + query = "SELECT a COLLATE Latin1_General_BIN FROM dbo.table_x", + expectedAst = + Batch(Seq(Project(namedTable("dbo.table_x"), Seq(Collate(simplyNamedColumn("a"), "Latin1_General_BIN")))))) + } + + "translate table source items with aliases" in { + example( + query = "SELECT a FROM dbo.table_x AS t", + expectedAst = Batch(Seq(Project(TableAlias(namedTable("dbo.table_x"), "t"), Seq(simplyNamedColumn("a")))))) + } + + "translate table sources involving *" in { + example( + query = "SELECT * FROM dbo.table_x", + expectedAst = Batch(Seq(Project(namedTable("dbo.table_x"), Seq(Star(None)))))) + + example( + query = "SELECT t.*", + expectedAst = Batch(Seq(Project(NoTable, Seq(Star(objectName = Some(ObjectReference(Id("t"))))))))) + + example( + query = "SELECT x..b.y.*", + expectedAst = + Batch(Seq(Project(NoTable, Seq(Star(objectName = Some(ObjectReference(Id("x"), Id("b"), Id("y"))))))))) + + // TODO: Add tests for OUTPUT clause once implemented - invalid semantics here to force coverage + example(query = "SELECT INSERTED.*", expectedAst = Batch(Seq(Project(NoTable, Seq(Inserted(Star(None))))))) + example(query = "SELECT DELETED.*", expectedAst = Batch(Seq(Project(NoTable, Seq(Deleted(Star(None))))))) + } + + "infer a cross join" in { + example( + query = "SELECT a, b, c FROM dbo.table_x, dbo.table_y", + expectedAst = Batch( + Seq(Project( + Join( + namedTable("dbo.table_x"), + namedTable("dbo.table_y"), + None, + CrossJoin, + Seq.empty, + JoinDataType(is_left_struct = false, is_right_struct = false)), + Seq(simplyNamedColumn("a"), simplyNamedColumn("b"), simplyNamedColumn("c")))))) + } + + val t1aCol = Column(Some(ObjectReference(Id("T1"))), Id("A")) + val t2aCol = Column(Some(ObjectReference(Id("T2"))), Id("A")) + val t3aCol = Column(Some(ObjectReference(Id("T3"))), Id("A")) + val 
t1bCol = Column(Some(ObjectReference(Id("T1"))), Id("B")) + val t2bCol = Column(Some(ObjectReference(Id("T2"))), Id("B")) + val t3bCol = Column(Some(ObjectReference(Id("T3"))), Id("B")) + + "translate a query with a JOIN" in { + + example( + query = "SELECT T1.A, T2.B FROM DBO.TABLE_X AS T1 INNER JOIN DBO.TABLE_Y AS T2 ON T1.A = T2.A AND T1.B = T2.B", + expectedAst = Batch( + Seq(Project( + Join( + TableAlias(namedTable("DBO.TABLE_X"), "T1"), + TableAlias(namedTable("DBO.TABLE_Y"), "T2"), + Some(And(Equals(t1aCol, t2aCol), Equals(t1bCol, t2bCol))), + InnerJoin, + List(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + List(t1aCol, t2bCol))))) + } + "translate a query with Multiple JOIN AND Condition" in { + + example( + query = "SELECT T1.A, T2.B FROM DBO.TABLE_X AS T1 INNER JOIN DBO.TABLE_Y AS T2 ON T1.A = T2.A " + + "LEFT JOIN DBO.TABLE_Z AS T3 ON T1.A = T3.A AND T1.B = T3.B", + expectedAst = Batch( + Seq(Project( + Join( + Join( + TableAlias(namedTable("DBO.TABLE_X"), "T1"), + TableAlias(namedTable("DBO.TABLE_Y"), "T2"), + Some(Equals(t1aCol, t2aCol)), + InnerJoin, + List(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + TableAlias(namedTable("DBO.TABLE_Z"), "T3"), + Some(And(Equals(t1aCol, t3aCol), Equals(t1bCol, t3bCol))), + LeftOuterJoin, + List(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + List(t1aCol, t2bCol))))) + } + "translate a query with Multiple JOIN OR Conditions" in { + example( + query = "SELECT T1.A, T2.B FROM DBO.TABLE_X AS T1 INNER JOIN DBO.TABLE_Y AS T2 ON T1.A = T2.A " + + "LEFT JOIN DBO.TABLE_Z AS T3 ON T1.A = T3.A OR T1.B = T3.B", + expectedAst = Batch( + Seq(Project( + Join( + Join( + TableAlias(namedTable("DBO.TABLE_X"), "T1"), + TableAlias(namedTable("DBO.TABLE_Y"), "T2"), + Some(Equals(t1aCol, t2aCol)), + InnerJoin, + List(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + TableAlias(namedTable("DBO.TABLE_Z"), "T3"), + Some(Or(Equals(t1aCol, t3aCol), Equals(t1bCol, t3bCol))), + LeftOuterJoin, + List(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + List(t1aCol, t2bCol))))) + } + "translate a query with a RIGHT OUTER JOIN" in { + example( + query = "SELECT T1.A FROM DBO.TABLE_X AS T1 RIGHT OUTER JOIN DBO.TABLE_Y AS T2 ON T1.A = T2.A", + expectedAst = Batch( + Seq(Project( + Join( + TableAlias(namedTable("DBO.TABLE_X"), "T1"), + TableAlias(namedTable("DBO.TABLE_Y"), "T2"), + Some(Equals(t1aCol, t2aCol)), + RightOuterJoin, + List(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + List(t1aCol))))) + } + "translate a query with a FULL OUTER JOIN" in { + example( + query = "SELECT T1.A FROM DBO.TABLE_X AS T1 FULL OUTER JOIN DBO.TABLE_Y AS T2 ON T1.A = T2.A", + expectedAst = Batch( + Seq(Project( + Join( + TableAlias(namedTable("DBO.TABLE_X"), "T1"), + TableAlias(namedTable("DBO.TABLE_Y"), "T2"), + Some(Equals(t1aCol, t2aCol)), + FullOuterJoin, + List(), + JoinDataType(is_left_struct = false, is_right_struct = false)), + List(t1aCol))))) + } + + "cover default case in translateJoinType" in { + val joinOnContextMock = mock(classOf[TSqlParser.JoinOnContext]) + + val outerJoinContextMock = mock(classOf[TSqlParser.OuterJoinContext]) + + // Set up the mock to return null for LEFT(), RIGHT(), and FULL() + when(outerJoinContextMock.LEFT()).thenReturn(null) + when(outerJoinContextMock.RIGHT()).thenReturn(null) + when(outerJoinContextMock.FULL()).thenReturn(null) + + when(joinOnContextMock.joinType()).thenReturn(null) + + val joinTypeContextMock = 
mock(classOf[TSqlParser.JoinTypeContext]) + when(joinTypeContextMock.outerJoin()).thenReturn(outerJoinContextMock) + when(joinTypeContextMock.INNER()).thenReturn(null) + when(joinOnContextMock.joinType()).thenReturn(joinTypeContextMock) + + val result = vc.relationBuilder.translateJoinType(joinOnContextMock) + result shouldBe UnspecifiedJoin + } + + "translate simple XML query and values" in { + example( + query = "SELECT xmlcolumn.query('/root/child') FROM tab", + expectedAst = Batch( + Seq(Project( + namedTable("tab"), + Seq(tsql + .TsqlXmlFunction(CallFunction("query", Seq(Literal("/root/child"))), simplyNamedColumn("xmlcolumn"))))))) + + example( + "SELECT xmlcolumn.value('path', 'type') FROM tab", + expectedAst = Batch( + Seq( + Project( + namedTable("tab"), + Seq(tsql.TsqlXmlFunction( + CallFunction("value", Seq(Literal("path"), Literal("type"))), + simplyNamedColumn("xmlcolumn"))))))) + + example( + "SELECT xmlcolumn.exist('/root/child[text()=\"Some Value\"]') FROM xmltable;", + expectedAst = Batch( + Seq(Project( + namedTable("xmltable"), + Seq(tsql.TsqlXmlFunction( + CallFunction("exist", Seq(Literal("/root/child[text()=\"Some Value\"]"))), + simplyNamedColumn("xmlcolumn"))))))) + + // TODO: Add nodes(), modify(), when we complete UPDATE and CROSS APPLY + } + + "translate all assignments to local variables as select list elements" in { + + example( + query = "SELECT @a = 1, @b = 2, @c = 3", + expectedAst = Batch( + Seq(Project( + NoTable, + Seq( + Assign(Identifier("@a", isQuoted = false), Literal(1)), + Assign(Identifier("@b", isQuoted = false), Literal(2)), + Assign(Identifier("@c", isQuoted = false), Literal(3))))))) + + example( + query = "SELECT @a += 1, @b -= 2", + expectedAst = Batch( + Seq(Project( + NoTable, + Seq( + Assign(Identifier("@a", isQuoted = false), Add(Identifier("@a", isQuoted = false), Literal(1))), + Assign(Identifier("@b", isQuoted = false), Subtract(Identifier("@b", isQuoted = false), Literal(2)))))))) + + example( + query = "SELECT @a *= 1, @b /= 2", + expectedAst = Batch( + Seq(Project( + NoTable, + Seq( + Assign(Identifier("@a", isQuoted = false), Multiply(Identifier("@a", isQuoted = false), Literal(1))), + Assign(Identifier("@b", isQuoted = false), Divide(Identifier("@b", isQuoted = false), Literal(2)))))))) + + example( + query = "SELECT @a %= myColumn", + expectedAst = Batch( + Seq( + Project( + NoTable, + Seq(Assign( + Identifier("@a", isQuoted = false), + Mod(Identifier("@a", isQuoted = false), simplyNamedColumn("myColumn")))))))) + + example( + query = "SELECT @a &= myColumn", + expectedAst = Batch( + Seq( + Project( + NoTable, + Seq(Assign( + Identifier("@a", isQuoted = false), + BitwiseAnd(Identifier("@a", isQuoted = false), simplyNamedColumn("myColumn")))))))) + + example( + query = "SELECT @a ^= myColumn", + expectedAst = Batch( + Seq( + Project( + NoTable, + Seq(Assign( + Identifier("@a", isQuoted = false), + BitwiseXor(Identifier("@a", isQuoted = false), simplyNamedColumn("myColumn")))))))) + + example( + query = "SELECT @a |= myColumn", + expectedAst = Batch( + Seq( + Project( + NoTable, + Seq(Assign( + Identifier("@a", isQuoted = false), + BitwiseOr(Identifier("@a", isQuoted = false), simplyNamedColumn("myColumn")))))))) + } + "translate scalar subqueries as expressions in select list" in { + example( + query = """SELECT + EmployeeID, + Name, + (SELECT AvgSalary FROM Employees) AS AverageSalary + FROM + Employees;""", + expectedAst = Batch( + Seq(Project( + namedTable("Employees"), + Seq( + simplyNamedColumn("EmployeeID"), + 
simplyNamedColumn("Name"), + Alias( + ScalarSubquery(Project(namedTable("Employees"), Seq(simplyNamedColumn("AvgSalary")))), + Id("AverageSalary"))))))) + } + } + + "SQL statements should support DISTINCT clauses" in { + example( + query = "SELECT DISTINCT * FROM Employees;", + expectedAst = Batch( + Seq( + Project( + Deduplicate(namedTable("Employees"), List(), all_columns_as_keys = true, within_watermark = false), + Seq(Star(None)))))) + example( + query = "SELECT DISTINCT a, b AS bb FROM t", + expectedAst = Batch( + Seq(Project( + Deduplicate(namedTable("t"), List(Id("a"), Id("bb")), all_columns_as_keys = false, within_watermark = false), + Seq(simplyNamedColumn("a"), Alias(simplyNamedColumn("b"), Id("bb"))))))) + } + + "SELECT NEXT VALUE FOR mySequence As nextVal" in { + example( + query = "SELECT NEXT VALUE FOR mySequence As nextVal", + expectedAst = Batch( + Seq(Project(NoTable, Seq(Alias(CallFunction("MONOTONICALLY_INCREASING_ID", List.empty), Id("nextVal"))))))) + } + + "SELECT NEXT VALUE FOR var.mySequence As nextVal" in { + example( + query = "SELECT NEXT VALUE FOR var.mySequence As nextVal", + expectedAst = Batch( + Seq(Project(NoTable, Seq(Alias(CallFunction("MONOTONICALLY_INCREASING_ID", List.empty), Id("nextVal"))))))) + } + + "SELECT NEXT VALUE FOR var.mySequence OVER (ORDER BY myColumn) As nextVal" in { + example( + query = "SELECT NEXT VALUE FOR var.mySequence OVER (ORDER BY myColumn) As nextVal ", + expectedAst = Batch( + Seq(Project( + NoTable, + Seq(Alias( + Window( + CallFunction("ROW_NUMBER", List.empty), + List.empty, + List(SortOrder(simplyNamedColumn("myColumn"), UnspecifiedSortDirection, SortNullsUnspecified)), + None), + Id("nextVal"))))))) + } + + "translate CTE select statements" in { + example( + query = "WITH cte AS (SELECT * FROM t) SELECT * FROM cte", + expectedAst = Batch( + Seq( + WithCTE( + Seq(SubqueryAlias(Project(namedTable("t"), Seq(Star(None))), Id("cte"), List.empty)), + Project(namedTable("cte"), Seq(Star(None))))))) + + example( + query = """WITH cteTable1 (col1, col2, col3count) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table1 + ), + cteTable2 (colx, coly, colxcount) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table2 + ) + SELECT col2, col1, col3count, colx, coly, colxcount + FROM cteTable""", + expectedAst = Batch( + Seq(WithCTE( + Seq( + SubqueryAlias( + Project( + namedTable("Table1"), + Seq( + simplyNamedColumn("col1"), + simplyNamedColumn("fred"), + Alias(CallFunction("COUNT", Seq(simplyNamedColumn("OrderDate"))), Id("counter")))), + Id("cteTable1"), + Seq(Id("col1"), Id("col2"), Id("col3count"))), + SubqueryAlias( + Project( + namedTable("Table2"), + Seq( + simplyNamedColumn("col1"), + simplyNamedColumn("fred"), + Alias(CallFunction("COUNT", Seq(simplyNamedColumn("OrderDate"))), Id("counter")))), + Id("cteTable2"), + Seq(Id("colx"), Id("coly"), Id("colxcount")))), + Project( + namedTable("cteTable"), + Seq( + simplyNamedColumn("col2"), + simplyNamedColumn("col1"), + simplyNamedColumn("col3count"), + simplyNamedColumn("colx"), + simplyNamedColumn("coly"), + simplyNamedColumn("colxcount"))))))) + } + + "translate a SELECT with a TOP clause" should { + "use LIMIT" in { + example( + query = "SELECT TOP 10 * FROM Employees;", + expectedAst = Batch(Seq(Project(Limit(namedTable("Employees"), Literal(10)), Seq(Star(None)))))) + } + + "use TOP PERCENT" in { + example( + query = "SELECT TOP 10 PERCENT * FROM Employees;", + expectedAst = Batch(Seq(Project(TopPercent(namedTable("Employees"), Literal(10)), 
Seq(Star(None)))))) + + example( + query = "SELECT TOP 10 PERCENT WITH TIES * FROM Employees;", + expectedAst = + Batch(Seq(Project(TopPercent(namedTable("Employees"), Literal(10), with_ties = true), Seq(Star(None)))))) + } + } + + "translate a SELECT statement with an ORDER BY and OFFSET" in { + example( + query = "SELECT * FROM Employees ORDER BY Salary OFFSET 10 ROWS", + expectedAst = Batch( + Seq(Project( + Offset( + Sort( + namedTable("Employees"), + Seq(SortOrder(simplyNamedColumn("Salary"), Ascending, SortNullsUnspecified)), + is_global = false), + Literal(10)), + Seq(Star(None)))))) + + example( + query = "SELECT * FROM Employees ORDER BY Salary OFFSET 10 ROWS FETCH NEXT 5 ROWS ONLY", + expectedAst = Batch( + Seq(Project( + Limit( + Offset( + Sort( + namedTable("Employees"), + Seq(SortOrder(simplyNamedColumn("Salary"), Ascending, SortNullsUnspecified)), + is_global = false), + Literal(10)), + Literal(5)), + Seq(Star(None)))))) + } + + "translate SELECT with a combination of DISTINCT, ORDER BY, and OFFSET" in { + example( + query = "SELECT DISTINCT * FROM Employees ORDER BY Salary OFFSET 10 ROWS", + expectedAst = Batch( + Seq(Project( + Deduplicate( + Offset( + Sort( + namedTable("Employees"), + Seq(SortOrder(simplyNamedColumn("Salary"), Ascending, SortNullsUnspecified)), + is_global = false), + Literal(10)), + List(), + all_columns_as_keys = true, + within_watermark = false), + Seq(Star(None)))))) + + example( + query = "SELECT DISTINCT * FROM Employees ORDER BY Salary OFFSET 10 ROWS FETCH NEXT 5 ROWS ONLY", + expectedAst = Batch( + List(Project( + Deduplicate( + Limit( + Offset( + Sort( + namedTable("Employees"), + List(SortOrder(simplyNamedColumn("Salary"), Ascending, SortNullsUnspecified)), + is_global = false), + Literal(10)), + Literal(5)), + List(), + all_columns_as_keys = true, + within_watermark = false), + List(Star(None)))))) + } + + "translate a query with PIVOT" in { + singleQueryExample( + query = "SELECT a FROM b PIVOT (SUM(a) FOR c IN ('foo', 'bar')) AS Source", + expectedAst = Project( + Aggregate( + child = namedTable("b"), + group_type = Pivot, + grouping_expressions = Seq(CallFunction("SUM", Seq(simplyNamedColumn("a")))), + pivot = Some(Pivot(simplyNamedColumn("c"), Seq(Literal("foo"), Literal("bar"))))), + Seq(simplyNamedColumn("a")))) + } + + "translate a query with UNPIVOT" in { + singleQueryExample( + query = "SELECT a FROM b UNPIVOT (c FOR d IN (e, f)) AsSource", + expectedAst = Project( + Unpivot( + child = namedTable("b"), + ids = Seq(simplyNamedColumn("e"), simplyNamedColumn("f")), + values = None, + variable_column_name = Id("c"), + value_column_name = Id("d")), + Seq(simplyNamedColumn("a")))) + } + + "translate a query with an explicit CROSS JOIN" in { + singleQueryExample( + query = "SELECT a FROM b CROSS JOIN c", + expectedAst = Project( + Join( + namedTable("b"), + namedTable("c"), + None, + CrossJoin, + Seq.empty, + JoinDataType(is_left_struct = false, is_right_struct = false)), + Seq(simplyNamedColumn("a")))) + } + + "translate a query with an explicit OUTER APPLY" in { + singleQueryExample( + query = "SELECT a FROM b OUTER APPLY c", + expectedAst = Project( + Join( + namedTable("b"), + namedTable("c"), + None, + OuterApply, + Seq.empty, + JoinDataType(is_left_struct = false, is_right_struct = false)), + Seq(simplyNamedColumn("a")))) + } + + "translate a query with an explicit CROSS APPLY" in { + singleQueryExample( + query = "SELECT a FROM b CROSS APPLY c", + expectedAst = Project( + Join( + namedTable("b"), + namedTable("c"), + None, + 
CrossApply, + Seq.empty, + JoinDataType(is_left_struct = false, is_right_struct = false)), + Seq(simplyNamedColumn("a")))) + } + + "parse and ignore IR for the FOR clause in a SELECT statement" in { + example( + query = "SELECT * FROM DAYS FOR XML RAW", + expectedAst = Batch(Seq(Project(namedTable("DAYS"), Seq(Star(None)))))) + } + + "parse and collect the options in the OPTION clause in a SELECT statement" in { + example( + query = """SELECT * FROM t FOR XML RAW + OPTION ( + MAXRECURSION 10, + OPTIMIZE [FOR] UNKNOWN, + SOMETHING ON, + SOMETHINGELSE OFF, + SOMEOTHER AUTO, + SOMEstrOpt = 'STRINGOPTION')""", + expectedAst = Batch( + Seq(WithOptions( + Project(namedTable("t"), Seq(Star(None))), + Options( + Map("MAXRECURSION" -> Literal(10), "OPTIMIZE" -> Column(None, Id("FOR", true))), + Map("SOMESTROPT" -> "STRINGOPTION"), + Map("SOMETHING" -> true, "SOMETHINGELSE" -> false), + List("SOMEOTHER")))))) + } + + "parse and collect table hints for named table select statements in all variants" in { + example( + query = "SELECT * FROM t WITH (NOLOCK)", + expectedAst = Batch(Seq(Project(TableWithHints(namedTable("t"), Seq(FlagHint("NOLOCK"))), Seq(Star(None)))))) + example( + query = "SELECT * FROM t WITH (FORCESEEK)", + expectedAst = + Batch(Seq(Project(TableWithHints(namedTable("t"), Seq(ForceSeekHint(None, None))), Seq(Star(None)))))) + example( + query = "SELECT * FROM t WITH (FORCESEEK(1 (Col1, Col2)))", + expectedAst = Batch( + Seq( + Project( + TableWithHints(namedTable("t"), Seq(ForceSeekHint(Some(Literal(1)), Some(Seq(Id("Col1"), Id("Col2")))))), + Seq(Star(None)))))) + example( + query = "SELECT * FROM t WITH (INDEX = (Bill, Ted))", + expectedAst = Batch( + Seq(Project(TableWithHints(namedTable("t"), Seq(IndexHint(Seq(Id("Bill"), Id("Ted"))))), Seq(Star(None)))))) + example( + query = "SELECT * FROM t WITH (FORCESEEK, INDEX = (Bill, Ted))", + expectedAst = Batch( + Seq( + Project( + TableWithHints(namedTable("t"), Seq(ForceSeekHint(None, None), IndexHint(Seq(Id("Bill"), Id("Ted"))))), + Seq(Star(None)))))) + } + + "translate INSERT statements" in { + example( + query = "INSERT INTO t (a, b) VALUES (1, 2)", + expectedAst = Batch( + Seq( + InsertIntoTable( + namedTable("t"), + Some(Seq(Id("a"), Id("b"))), + DerivedRows(Seq(Seq(Literal(1), Literal(2)))), + None, + None, + overwrite = false)))) + } + "translate INSERT statement with @LocalVar" in { + example( + query = "INSERT INTO @LocalVar (a, b) VALUES (1, 2)", + expectedAst = Batch( + Seq( + InsertIntoTable( + LocalVarTable(Id("@LocalVar")), + Some(Seq(Id("a"), Id("b"))), + DerivedRows(Seq(Seq(Literal(1), Literal(2)))), + None, + None, + overwrite = false)))) + } + + "translate insert statements with VALU(pa, irs)" in { + example( + query = "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)", + expectedAst = Batch( + Seq( + InsertIntoTable( + namedTable("t"), + Some(Seq(Id("a"), Id("b"))), + DerivedRows(Seq(Seq(Literal(1), Literal(2)), Seq(Literal(3), Literal(4)))), + None, + None, + overwrite = false)))) + } + + "translate insert statements with (OPTIONS)" in { + example( + query = "INSERT INTO t WITH (TABLOCK) (a, b) VALUES (1, 2)", + expectedAst = Batch( + Seq(InsertIntoTable( + TableWithHints(namedTable("t"), List(FlagHint("TABLOCK"))), + Some(Seq(Id("a"), Id("b"))), + DerivedRows(Seq(Seq(Literal(1), Literal(2)))), + None, + None, + overwrite = false)))) + } + + "translate insert statement with DEFAULT VALUES" in { + example( + query = "INSERT INTO t DEFAULT VALUES", + expectedAst = Batch(Seq(InsertIntoTable(namedTable("t"), None, 
DefaultValues(), None, None, overwrite = false)))) + } + + "translate INSERT statement with OUTPUT clause" in { + example( + query = "INSERT INTO t (a, b) OUTPUT INSERTED.a as a_lias, INSERTED.b INTO Inserted(a, b) VALUES (1, 2)", + expectedAst = Batch( + List(InsertIntoTable( + namedTable("t"), + Some(List(Id("a"), Id("b"))), + DerivedRows(List(List(Literal(1), Literal(2)))), + Some(tsql.Output( + Some(namedTable("Inserted")), + List( + Alias(Column(Some(ObjectReference(Id("INSERTED"))), Id("a")), Id("a_lias")), + Column(Some(ObjectReference(Id("INSERTED"))), Id("b"))), + Some(List(simplyNamedColumn("a"), simplyNamedColumn("b"))))), + None, + overwrite = false)))) + } + + "translate insert statements with CTE" in { + example( + query = "WITH wtab AS (SELECT * FROM t) INSERT INTO t (a, b) select * from wtab", + expectedAst = Batch( + Seq(WithCTE( + Seq(SubqueryAlias(Project(namedTable("t"), Seq(Star(None))), Id("wtab"), List.empty)), + InsertIntoTable( + namedTable("t"), + Some(Seq(Id("a"), Id("b"))), + Project(namedTable("wtab"), Seq(Star(None))), + None, + None, + overwrite = false))))) + } + + "translate insert statements with SELECT" in { + example( + query = """ + INSERT INTO ConsolidatedRecords (ID, Name) + SELECT ID, Name + FROM ( + SELECT ID, Name + FROM TableA + UNION + SELECT ID, Name + FROM TableB) + AS DerivedTable;""", + expectedAst = Batch( + Seq(InsertIntoTable( + namedTable("ConsolidatedRecords"), + Some(Seq(Id("ID"), Id("Name"))), + Project( + TableAlias( + SetOperation( + Project(namedTable("TableA"), Seq(simplyNamedColumn("ID"), simplyNamedColumn("Name"))), + Project(namedTable("TableB"), Seq(simplyNamedColumn("ID"), simplyNamedColumn("Name"))), + UnionSetOp, + is_all = false, + by_name = false, + allow_missing_columns = false), + "DerivedTable"), + Seq(simplyNamedColumn("ID"), simplyNamedColumn("Name"))), + None, + None, + overwrite = false)))) + } + + "should translate UPDATE statements" in { + example( + query = "UPDATE t SET a = 1, b = 2", + expectedAst = Batch( + Seq( + UpdateTable( + NamedTable("t", Map(), is_streaming = false), + None, + Seq(Assign(Column(None, Id("a")), Literal(1)), Assign(Column(None, Id("b")), Literal(2))), + None, + None, + None)))) + + example( + query = "UPDATE t SET a = 1, b = 2 OUTPUT INSERTED.a as a_lias, INSERTED.b INTO Inserted(a, b)", + expectedAst = Batch( + Seq(UpdateTable( + NamedTable("t", Map(), is_streaming = false), + None, + Seq(Assign(Column(None, Id("a")), Literal(1)), Assign(Column(None, Id("b")), Literal(2))), + None, + Some(tsql.Output( + Some(NamedTable("Inserted", Map(), is_streaming = false)), + Seq( + Alias(Column(Some(ObjectReference(Id("INSERTED"))), Id("a")), Id("a_lias")), + Column(Some(ObjectReference(Id("INSERTED"))), Id("b"))), + Some(Seq(Column(None, Id("a")), Column(None, Id("b")))))), + None)))) + + example( + query = "UPDATE t SET a = 1, b = 2 FROM t1 WHERE t.a = t1.a", + expectedAst = Batch( + Seq(UpdateTable( + NamedTable("t", Map(), is_streaming = false), + Some(NamedTable("t1", Map(), is_streaming = false)), + Seq(Assign(Column(None, Id("a")), Literal(1)), Assign(Column(None, Id("b")), Literal(2))), + Some( + Equals(Column(Some(ObjectReference(Id("t"))), Id("a")), Column(Some(ObjectReference(Id("t1"))), Id("a")))), + None, + None)))) + + example( + query = "UPDATE t SET a = 1, udf.Transform(b) FROM t1 WHERE t.a = t1.a OPTION (KEEP PLAN)", + expectedAst = Batch( + Seq(UpdateTable( + NamedTable("t", Map(), is_streaming = false), + Some(NamedTable("t1", Map(), is_streaming = false)), + Seq( + 
Assign(Column(None, Id("a")), Literal(1)), + UnresolvedFunction( + "udf.Transform", + Seq(Column(None, Id("b"))), + is_distinct = false, + is_user_defined_function = true, + ruleText = "udf.Transform(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function udf.Transform is not convertible to Databricks SQL")), + Some( + Equals(Column(Some(ObjectReference(Id("t"))), Id("a")), Column(Some(ObjectReference(Id("t1"))), Id("a")))), + None, + Some(Options(Map("KEEP" -> Column(None, Id("PLAN"))), Map.empty, Map.empty, List.empty)))))) + } + + "translate DELETE statements" in { + example( + query = "DELETE FROM t", + expectedAst = Batch(Seq(DeleteFromTable(NamedTable("t", Map(), is_streaming = false), None, None, None, None)))) + + example( + query = "DELETE FROM t OUTPUT DELETED.a as a_lias, DELETED.b INTO Deleted(a, b)", + expectedAst = Batch( + Seq(DeleteFromTable( + NamedTable("t", Map(), is_streaming = false), + None, + None, + Some(tsql.Output( + Some(NamedTable("Deleted", Map(), is_streaming = false)), + Seq( + Alias(Column(Some(ObjectReference(Id("DELETED"))), Id("a")), Id("a_lias")), + Column(Some(ObjectReference(Id("DELETED"))), Id("b"))), + Some(Seq(Column(None, Id("a")), Column(None, Id("b")))))), + None)))) + + example( + query = "DELETE FROM t FROM t1 WHERE t.a = t1.a", + expectedAst = Batch( + Seq(DeleteFromTable( + NamedTable("t", Map(), is_streaming = false), + Some(NamedTable("t1", Map(), is_streaming = false)), + Some( + Equals(Column(Some(ObjectReference(Id("t"))), Id("a")), Column(Some(ObjectReference(Id("t1"))), Id("a")))), + None, + None)))) + + example( + query = "DELETE FROM t FROM t1 WHERE t.a = t1.a OPTION (KEEP PLAN)", + expectedAst = Batch( + Seq(DeleteFromTable( + NamedTable("t", Map(), is_streaming = false), + Some(NamedTable("t1", Map(), is_streaming = false)), + Some( + Equals(Column(Some(ObjectReference(Id("t"))), Id("a")), Column(Some(ObjectReference(Id("t1"))), Id("a")))), + None, + Some(Options(Map("KEEP" -> Column(None, Id("PLAN"))), Map.empty, Map.empty, List.empty)))))) + } + + "translate MERGE statements" in { + example( + query = """ + |MERGE INTO t USING s + | ON t.a = s.a + | WHEN MATCHED THEN UPDATE SET t.b = s.b + | WHEN NOT MATCHED THEN INSERT (a, b) VALUES (s.a, s.b)""".stripMargin, + expectedAst = Batch( + Seq(MergeIntoTable( + NamedTable("t", Map(), is_streaming = false), + NamedTable("s", Map(), is_streaming = false), + Equals(Column(Some(ObjectReference(Id("t"))), Id("a")), Column(Some(ObjectReference(Id("s"))), Id("a"))), + Seq( + UpdateAction( + None, + Seq(Assign( + Column(Some(ObjectReference(Id("t"))), Id("b")), + Column(Some(ObjectReference(Id("s"))), Id("b")))))), + Seq(InsertAction( + None, + Seq( + Assign(Column(None, Id("a")), Column(Some(ObjectReference(Id("s"))), Id("a"))), + Assign(Column(None, Id("b")), Column(Some(ObjectReference(Id("s"))), Id("b")))))), + List.empty)))) + } + + "translate MERGE statements with options" in { + example( + query = """ + |MERGE INTO t USING s + | ON t.a = s.a + | WHEN MATCHED THEN UPDATE SET t.b = s.b + | WHEN NOT MATCHED THEN INSERT (a, b) VALUES (s.a, s.b) + | OPTION ( KEEPFIXED PLAN, FAST 666, MAX_GRANT_PERCENT = 30, FLAME ON, FLAME OFF, QUICKLY) """.stripMargin, + expectedAst = Batch( + Seq(WithModificationOptions( + MergeIntoTable( + NamedTable("t", Map(), is_streaming = false), + NamedTable("s", Map(), is_streaming = false), + Equals(Column(Some(ObjectReference(Id("t"))), Id("a")), Column(Some(ObjectReference(Id("s"))), Id("a"))), + Seq( + UpdateAction( + None, + Seq(Assign( + 
Column(Some(ObjectReference(Id("t"))), Id("b")), + Column(Some(ObjectReference(Id("s"))), Id("b")))))), + Seq(InsertAction( + None, + Seq( + Assign(Column(None, Id("a")), Column(Some(ObjectReference(Id("s"))), Id("a"))), + Assign(Column(None, Id("b")), Column(Some(ObjectReference(Id("s"))), Id("b")))))), + List.empty), + Options( + Map("KEEPFIXED" -> Column(None, Id("PLAN")), "FAST" -> Literal(666), "MAX_GRANT_PERCENT" -> Literal(30)), + Map(), + Map("FLAME" -> false, "QUICKLY" -> true), + List()))))) + example( + query = """ + |WITH s (a, b, col3count) + | AS + | ( + | SELECT col1, fred, COUNT(OrderDate) AS counter + | FROM Table1 + | ) + | MERGE INTO t WITH (NOLOCK, READCOMMITTED) USING s + | ON t.a = s.a + | WHEN MATCHED THEN UPDATE SET t.b = s.b + | WHEN NOT MATCHED BY TARGET THEN DELETE + | WHEN NOT MATCHED BY SOURCE THEN INSERT (a, b) VALUES (s.a, s.b)""".stripMargin, + expectedAst = Batch( + Seq(WithCTE( + Seq(SubqueryAlias( + Project( + namedTable("Table1"), + Seq( + simplyNamedColumn("col1"), + simplyNamedColumn("fred"), + Alias(CallFunction("COUNT", Seq(simplyNamedColumn("OrderDate"))), Id("counter")))), + Id("s"), + Seq(Id("a"), Id("b"), Id("col3count")))), + MergeIntoTable( + TableWithHints( + NamedTable("t", Map(), is_streaming = false), + Seq(FlagHint("NOLOCK"), FlagHint("READCOMMITTED"))), + NamedTable("s", Map(), is_streaming = false), + Equals(Column(Some(ObjectReference(Id("t"))), Id("a")), Column(Some(ObjectReference(Id("s"))), Id("a"))), + Seq( + UpdateAction( + None, + Seq(Assign( + Column(Some(ObjectReference(Id("t"))), Id("b")), + Column(Some(ObjectReference(Id("s"))), Id("b")))))), + Seq(DeleteAction(None)), + Seq(InsertAction( + None, + Seq( + Assign(Column(None, Id("a")), Column(Some(ObjectReference(Id("s"))), Id("a"))), + Assign(Column(None, Id("b")), Column(Some(ObjectReference(Id("s"))), Id("b"))))))))))) + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDDLBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDDLBuilderSpec.scala new file mode 100644 index 0000000000..98cd5de366 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlDDLBuilderSpec.scala @@ -0,0 +1,179 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TSqlDDLBuilderSpec extends AnyWordSpec with TSqlParserTestCommon with Matchers with ir.IRHelpers { + + override protected def astBuilder: TSqlParserBaseVisitor[_] = vc.astBuilder + + private def singleQueryExample(query: String, expectedAst: ir.LogicalPlan): Unit = + example(query, _.tSqlFile(), ir.Batch(Seq(expectedAst))) + + "tsql DDL visitor" should { + + "translate a simple CREATE TABLE" in { + singleQueryExample( + "CREATE TABLE some_table (a INT, b VARCHAR(10))", + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq.empty, + Seq.empty, + None, + Some(Seq.empty))) + } + + "translate a CREATE TABLE with a primary key, foreign key and a Unique column" in { + singleQueryExample( + "CREATE TABLE some_table (a INT PRIMARY KEY, b VARCHAR(10) UNIQUE, FOREIGN KEY (b) REFERENCES other_table(b))", + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, 
+ ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq(ir.PrimaryKey()), "b" -> Seq(ir.Unique())), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq(ir.ForeignKey("b", "other_table", "b", Seq.empty)), + Seq.empty, + None, + Some(Seq.empty))) + } + + "translate a CREATE TABLE with a CHECK constraint and column options" in { + singleQueryExample( + "CREATE TABLE some_table (a INT SPARSE, b VARCHAR(10), CONSTRAINT c1 CHECK (a > 0))", + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Map("a" -> Seq(ir.OptionUnresolved("Unsupported Option: SPARSE")), "b" -> Seq.empty), + Seq( + ir.NamedConstraint( + "c1", + ir.CheckConstraint(ir.GreaterThan(ir.Column(None, ir.Id("a")), ir.Literal(0, ir.IntegerType))))), + Seq.empty, + None, + Some(Seq.empty))) + } + + "translate a CREATE TABLE with a DEFAULT constraint" in { + singleQueryExample( + "CREATE TABLE some_table (a INT DEFAULT 0, b VARCHAR(10) DEFAULT 'foo')", + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map( + "a" -> Seq(ir.DefaultValueConstraint(ir.Literal(0, ir.IntegerType))), + "b" -> Seq(ir.DefaultValueConstraint(ir.Literal("foo", ir.StringType)))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq.empty, + Seq.empty, + None, + Some(Seq.empty))) + } + + "translate a CREATE TABLE with a complex FK constraint" in { + singleQueryExample( + "CREATE TABLE some_table (a INT, b VARCHAR(10), CONSTRAINT c1 FOREIGN KEY (a, b) REFERENCES other_table(c, d))", + ir.CreateTableParams( + ir.CreateTable( + "some_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("a", ir.IntegerType), ir.StructField("b", ir.VarcharType(Some(10)))))), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Map("a" -> Seq.empty, "b" -> Seq.empty), + Seq(ir.NamedConstraint("c1", ir.ForeignKey("a, b", "other_table", "c, d", Seq.empty))), + Seq.empty, + None, + Some(Seq.empty))) + } + + "translate a CREATE TABLE with various column level constraints" in { + singleQueryExample( + "CREATE TABLE example_table (id INT PRIMARY KEY, name VARCHAR(50) NOT NULL, age INT CHECK (age >= 18)," + + "email VARCHAR(100) UNIQUE, department_id INT FOREIGN KEY REFERENCES departments(id));", + ir.CreateTableParams( + ir.CreateTable( + "example_table", + None, + None, + None, + ir.StructType(Seq( + ir.StructField("id", ir.IntegerType), + ir.StructField("name", ir.VarcharType(Some(50)), nullable = false), + ir.StructField("age", ir.IntegerType), + ir.StructField("email", ir.VarcharType(Some(100))), + ir.StructField("department_id", ir.IntegerType)))), + Map( + "name" -> Seq.empty, + "email" -> Seq(ir.Unique()), + "department_id" -> Seq.empty, + "age" -> Seq.empty, + "id" -> Seq(ir.PrimaryKey())), + Map( + "name" -> Seq.empty, + "email" -> Seq.empty, + "department_id" -> Seq.empty, + "age" -> Seq.empty, + "id" -> Seq.empty), + Seq( + ir.CheckConstraint(ir.GreaterThanOrEqual(ir.Column(None, ir.Id("age")), ir.Literal(18, ir.IntegerType))), + ir.ForeignKey("department_id", "departments", "id", Seq.empty)), + Seq.empty, + None, + Some(Seq.empty))) + } + + "translate a CREATE TABLE with a named NULL constraint" in { + singleQueryExample( + "CREATE TABLE example_table (id VARCHAR(10) CONSTRAINT c1 NOT NULL);", + 
ir.CreateTableParams( + ir.CreateTable( + "example_table", + None, + None, + None, + ir.StructType(Seq(ir.StructField("id", ir.VarcharType(Some(10)))))), + Map("id" -> Seq.empty), + Map("id" -> Seq.empty), + Seq(ir.NamedConstraint("c1", ir.CheckConstraint(ir.IsNotNull(ir.Column(None, ir.Id("id")))))), + Seq.empty, + None, + Some(Seq.empty))) + } + } + + "translate a CREATE TABLE with unsupported table level TSQL options" in { + singleQueryExample( + "CREATE TABLE example_table (id INT) WITH (LEDGER = ON);", + ir.CreateTableParams( + ir.CreateTable("example_table", None, None, None, ir.StructType(Seq(ir.StructField("id", ir.IntegerType)))), + Map("id" -> Seq.empty), + Map("id" -> Seq.empty), + Seq.empty, + Seq.empty, + None, + Some(Seq(ir.OptionUnresolved("LEDGER = ON"))))) + + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorHandlerSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorHandlerSpec.scala new file mode 100644 index 0000000000..e99f420e4a --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorHandlerSpec.scala @@ -0,0 +1,197 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.intermediate.ParsingError +import com.databricks.labs.remorph.parsers.{DefaultErrorCollector, EmptyErrorCollector, ProductionErrorCollector} +import org.antlr.v4.runtime.CommonToken +import org.apache.logging.log4j.core.appender.AbstractAppender +import org.apache.logging.log4j.core.config.{Configuration, Configurator} +import org.apache.logging.log4j.core.layout.PatternLayout +import org.apache.logging.log4j.core.{LogEvent, LoggerContext} +import org.apache.logging.log4j.{Level, LogManager, Logger} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.io.{ByteArrayOutputStream, PrintStream} +import scala.collection.mutable.ListBuffer + +class ListBufferAppender(appenderName: String, layout: PatternLayout) + extends AbstractAppender(appenderName, null, layout, false, null) { + val logMessages: ListBuffer[String] = ListBuffer() + + override def append(event: LogEvent): Unit = { + logMessages += event.getMessage.getFormattedMessage + } + + def getLogMessages: List[String] = logMessages.toList +} + +class TSqlErrorHandlerSpec extends AnyFlatSpec with Matchers { + + "ErrorCollector" should "collect syntax errors correctly" in { + val errorCollector = new ProductionErrorCollector("sourceCode", "fileName") + val token = new CommonToken(1) + errorCollector.syntaxError(null, token, 1, 1, "msg", null) + errorCollector.errors.head shouldBe ParsingError( + 1, + 1, + "msg", + 1, + null, + "unresolved token name", + "unresolved rule name") + } + + it should "format errors correctly" in { + val errorCollector = new ProductionErrorCollector("sourceCode", "fileName") + val token = new CommonToken(1) + token.setLine(1) + token.setCharPositionInLine(1) + token.setText("text") + errorCollector.errors += ParsingError(1, 1, "msg", 4, "text", "unresolved token name", "unresolved rule name") + errorCollector.formatErrors.head should include("Token: text") + } + + it should "convert errors to JSON correctly" in { + val errorCollector = new ProductionErrorCollector("sourceCode", "fileName") + val token = new CommonToken(1) + token.setLine(1) + token.setCharPositionInLine(1) + token.setText("text") + errorCollector.errors += ParsingError(1, 1, "msg", 4, "text", "unresolved token name", "unresolved rule name") + errorCollector.errorsAsJson 
should include("starting at 1:1") + errorCollector.errorsAsJson should include("'unresolved rule name'") + errorCollector.errorsAsJson should include("'text'") + errorCollector.errorsAsJson should include("(unresolved token name)") + errorCollector.errorsAsJson should include("msg") + } + + it should "count errors correctly" in { + val errorCollector = new ProductionErrorCollector("sourceCode", "fileName") + val token = new CommonToken(1) + token.setLine(1) + token.setCharPositionInLine(1) + token.setText("text") + errorCollector.errors += ParsingError(1, 1, "msg", 4, "text", "unresolved token name", "unresolved rule name") + errorCollector.errorCount shouldBe 1 + } + + it should "window long lines correctly" in { + val longLine = "a" * 40 + "error" + "a" * 40 + val errorCollector = new ProductionErrorCollector(longLine, "fileName") + val token = new CommonToken(1) + token.setLine(1) + token.setCharPositionInLine(40) + token.setStartIndex(40) + token.setStopIndex(44) + token.setText("error") + errorCollector.errors += ParsingError(1, 40, "msg", 5, "error", "unresolved token name", "unresolved rule name") + + // Call the method + val formattedErrors = errorCollector.formatErrors + + // Check the windowing + val s = "..." + "a" * 34 + "error" + "a" * 35 + "...\n" + " " * 37 + "^^^^^" + formattedErrors.head should include(s) + } + + it should "log errors correctly" in { + val errorCollector = new ProductionErrorCollector("sourceCode", "fileName") + val token = new CommonToken(1) + token.setLine(1) + token.setCharPositionInLine(1) + token.setText("text") + errorCollector.errors += ParsingError(1, 1, "msg", 4, "text", "unresolved token name", "unresolved rule name") + + // Capture the logs + val logger: Logger = LogManager.getLogger("com.databricks.labs.remorph.parsers.ErrorCollector") + Configurator.setLevel(logger.getName, Level.ERROR) + val layout = PatternLayout.createDefaultLayout() + + // Create a custom appender to capture the logs + val appenderName = "CaptureAppender" + val appender = new ListBufferAppender(appenderName, layout) + appender.start() + + val context: LoggerContext = LogManager.getContext(false).asInstanceOf[LoggerContext] + val config: Configuration = context.getConfiguration + + config.addAppender(appender) + // Get the logger config + val loggerConfig = config.getLoggerConfig(logger.getName) + + // Add the appender to the logger and set additivity to false + loggerConfig.addAppender(appender, null, null) + loggerConfig.setAdditive(false) + + // Update the logger context with the new configuration + context.updateLoggers() + + // Call the method + errorCollector.logErrors() + + // Check the logs + val logs = appender.getLogMessages + logs.exists(log => log.contains("File: fileName, Line: 1")) shouldBe true + } + + it should "capture syntax errors correctly using syntaxError method" in { + val errorCollector = new ProductionErrorCollector("sourceCode", "fileName") + errorCollector.reset() + val token = new CommonToken(1) + token.setLine(10) + token.setCharPositionInLine(5) + token.setStartIndex(42) + token.setStopIndex(50) + token.setText("errorText") + + errorCollector.syntaxError(null, token, 10, 5, "Syntax error message", null) + + errorCollector.errorCount shouldBe 1 + errorCollector.errors should contain( + ParsingError(10, 5, "Syntax error message", 9, "errorText", "unresolved token name", "unresolved rule name")) + } + + def captureStdErr[T](block: => T): (T, String) = { + val originalErr = System.err + val errContent = new ByteArrayOutputStream() + val printStream 
= new PrintStream(errContent) + System.setErr(printStream) + try { + val result = block + printStream.flush() + (result, errContent.toString) + } finally { + System.setErr(originalErr) + } + } + + it should "capture syntax errors correctly using the default syntaxError method" in { + val errorCollector = new DefaultErrorCollector() + errorCollector.reset() + val token = new CommonToken(1) + token.setLine(10) + token.setCharPositionInLine(5) + token.setText("errorText") + val (_, capturedErr) = captureStdErr { + errorCollector.syntaxError( + null, + token, + 10, + 5, + "Ignore this Syntax error message - it is supposed to be here", + null) + } + errorCollector.errorCount shouldBe 1 + capturedErr should include("Ignore this Syntax error message") + } + + it should "have sensible defaults" in { + val errorCollector = new EmptyErrorCollector() + errorCollector.errorCount shouldBe 0 + errorCollector.logErrors() + errorCollector.reset() + errorCollector.errorCount shouldBe 0 + errorCollector.formatErrors shouldBe List() + errorCollector.errorsAsJson shouldBe "{}" + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorStategySpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorStategySpec.scala new file mode 100644 index 0000000000..32727d7961 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlErrorStategySpec.scala @@ -0,0 +1,43 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.intermediate.IRHelpers +import org.scalatest.Assertion +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TSqlErrorStategySpec extends AnyWordSpec with TSqlParserTestCommon with Matchers with IRHelpers { + override protected def astBuilder: TSqlParserBaseVisitor[_] = vc.astBuilder + + private def checkError(query: String, errContains: String): Assertion = + checkError(query, _.tSqlFile(), errContains) + + "TSqlErrorStrategy" should { + "process an invalid match parser exception" in { + checkError(query = "SELECT * FROM", errContains = "was unexpected") + } + "process an extraneous input exception" in { + checkError(query = "*", errContains = "unexpected extra input") + } + "process a missing input exception" in { + checkError(query = "SELECT * FROM FRED As X Y ", errContains = "unexpected extra input") + } + } + + "TSqlErrorStrategy" should { + "produce human readable messages" in { + checkError( + query = "SELECT * FROM FRED As X Y ", + errContains = "unexpected extra input 'Y' while parsing a T-SQL batch") + + checkError( + query = "*", + errContains = "unexpected extra input '*' while parsing a T-SQL batch\n" + + "expecting one of: End of batch, Identifier, Jinja Template Element, Select Statement, ") + + checkError( + query = "SELECT * FROM", + errContains = "'' was unexpected while parsing a table source in a FROM clause " + + "in a SELECT statement\nexpecting one of:") + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionBuilderSpec.scala new file mode 100644 index 0000000000..4b61157587 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionBuilderSpec.scala @@ -0,0 +1,765 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser +import 
com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.ID +import com.databricks.labs.remorph.{intermediate => ir} +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.antlr.v4.runtime.{CommonToken, Token} +import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.Mockito.{mock, when} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TSqlExpressionBuilderSpec extends AnyWordSpec with TSqlParserTestCommon with Matchers with ir.IRHelpers { + + override protected def astBuilder: TSqlParserBaseVisitor[_] = vc.expressionBuilder + + "TSqlExpressionBuilder" should { + "translate literals" in { + exampleExpr("null", _.expression(), ir.Literal.Null) + exampleExpr("1", _.expression(), ir.Literal(1)) + exampleExpr("-1", _.expression(), ir.UMinus(ir.Literal(1))) // TODO: add optimizer + exampleExpr("+1", _.expression(), ir.UPlus(ir.Literal(1))) + exampleExpr("1.1", _.expression(), ir.Literal(1.1f)) + exampleExpr("'foo'", _.expression(), ir.Literal("foo")) + } + "translate scientific notation" in { + exampleExpr("1.1e2", _.expression(), ir.Literal(110)) + exampleExpr("1.1e-2", _.expression(), ir.Literal(0.011f)) + exampleExpr("1e2", _.expression(), ir.Literal(100)) + exampleExpr("0.123456789", _.expression(), ir.Literal(0.123456789)) + exampleExpr("0.123456789e-1234", _.expression(), ir.DecimalLiteral("0.123456789e-1234")) + } + "translate simple primitives" in { + exampleExpr("DEFAULT", _.expression(), Default()) + exampleExpr("@LocalId", _.expression(), ir.Identifier("@LocalId", isQuoted = false)) + } + + "translate simple numeric binary expressions" in { + exampleExpr("1 + 2", _.expression(), ir.Add(ir.Literal(1), ir.Literal(2))) + exampleExpr("1 +2", _.expression(), ir.Add(ir.Literal(1), ir.Literal(2))) + exampleExpr("1 - 2", _.expression(), ir.Subtract(ir.Literal(1), ir.Literal(2))) + exampleExpr("1 -2", _.expression(), ir.Subtract(ir.Literal(1), ir.Literal(2))) + exampleExpr("1 * 2", _.expression(), ir.Multiply(ir.Literal(1), ir.Literal(2))) + exampleExpr("1 / 2", _.expression(), ir.Divide(ir.Literal(1), ir.Literal(2))) + exampleExpr("1 % 2", _.expression(), ir.Mod(ir.Literal(1), ir.Literal(2))) + exampleExpr("'A' || 'B'", _.expression(), ir.Concat(Seq(ir.Literal("A"), ir.Literal("B")))) + exampleExpr("4 ^ 2", _.expression(), ir.BitwiseXor(ir.Literal(4), ir.Literal(2))) + } + "translate complex binary expressions" in { + exampleExpr( + "a + b * 2", + _.expression(), + ir.Add(simplyNamedColumn("a"), ir.Multiply(simplyNamedColumn("b"), ir.Literal(2)))) + exampleExpr( + "(a + b) * 2", + _.expression(), + ir.Multiply(ir.Add(simplyNamedColumn("a"), simplyNamedColumn("b")), ir.Literal(2))) + exampleExpr( + "a & b | c", + _.expression(), + ir.BitwiseOr(ir.BitwiseAnd(simplyNamedColumn("a"), simplyNamedColumn("b")), simplyNamedColumn("c"))) + exampleExpr( + "(a & b) | c", + _.expression(), + ir.BitwiseOr(ir.BitwiseAnd(simplyNamedColumn("a"), simplyNamedColumn("b")), simplyNamedColumn("c"))) + exampleExpr( + "a + b * 2", + _.expression(), + ir.Add(simplyNamedColumn("a"), ir.Multiply(simplyNamedColumn("b"), ir.Literal(2)))) + exampleExpr( + "(a + b) * 2", + _.expression(), + ir.Multiply(ir.Add(simplyNamedColumn("a"), simplyNamedColumn("b")), ir.Literal(2))) + exampleExpr( + "a & b | c", + _.expression(), + ir.BitwiseOr(ir.BitwiseAnd(simplyNamedColumn("a"), simplyNamedColumn("b")), simplyNamedColumn("c"))) + exampleExpr( + "(a & b) | c", + _.expression(), + ir.BitwiseOr(ir.BitwiseAnd(simplyNamedColumn("a"), simplyNamedColumn("b")), 
simplyNamedColumn("c"))) + exampleExpr( + "a % 3 + b * 2 - c / 5", + _.expression(), + ir.Subtract( + ir.Add(ir.Mod(simplyNamedColumn("a"), ir.Literal(3)), ir.Multiply(simplyNamedColumn("b"), ir.Literal(2))), + ir.Divide(simplyNamedColumn("c"), ir.Literal(5)))) + exampleExpr( + query = "a || b || c", + _.expression(), + ir.Concat(Seq(ir.Concat(Seq(simplyNamedColumn("a"), simplyNamedColumn("b"))), simplyNamedColumn("c")))) + } + "correctly apply operator precedence and associativity" in { + exampleExpr( + "1 + -++-2", + _.expression(), + ir.Add(ir.Literal(1), ir.UMinus(ir.UPlus(ir.UPlus(ir.UMinus(ir.Literal(2))))))) + exampleExpr( + "1 + ~ 2 * 3", + _.expression(), + ir.Add(ir.Literal(1), ir.Multiply(ir.BitwiseNot(ir.Literal(2)), ir.Literal(3)))) + exampleExpr( + "1 + -2 * 3", + _.expression(), + ir.Add(ir.Literal(1), ir.Multiply(ir.UMinus(ir.Literal(2)), ir.Literal(3)))) + exampleExpr( + "1 + -2 * 3 + 7 & 66", + _.expression(), + ir.BitwiseAnd( + ir.Add(ir.Add(ir.Literal(1), ir.Multiply(ir.UMinus(ir.Literal(2)), ir.Literal(3))), ir.Literal(7)), + ir.Literal(66))) + exampleExpr( + "1 + -2 * 3 + 7 ^ 66", + _.expression(), + ir.BitwiseXor( + ir.Add(ir.Add(ir.Literal(1), ir.Multiply(ir.UMinus(ir.Literal(2)), ir.Literal(3))), ir.Literal(7)), + ir.Literal(66))) + exampleExpr( + "1 + -2 * 3 + 7 | 66", + _.expression(), + ir.BitwiseOr( + ir.Add(ir.Add(ir.Literal(1), ir.Multiply(ir.UMinus(ir.Literal(2)), ir.Literal(3))), ir.Literal(7)), + ir.Literal(66))) + exampleExpr( + "1 + -2 * 3 + 7 + ~66", + _.expression(), + ir.Add( + ir.Add(ir.Add(ir.Literal(1), ir.Multiply(ir.UMinus(ir.Literal(2)), ir.Literal(3))), ir.Literal(7)), + ir.BitwiseNot(ir.Literal(66)))) + exampleExpr( + "1 + -2 * 3 + 7 | 1980 || 'leeds1' || 'leeds2' || 'leeds3'", + _.expression(), + ir.Concat( + Seq( + ir.Concat(Seq( + ir.Concat( + Seq( + ir.BitwiseOr( + ir.Add(ir.Add(ir.Literal(1), ir.Multiply(ir.UMinus(ir.Literal(2)), ir.Literal(3))), ir.Literal(7)), + ir.Literal(1980)), + ir.Literal("leeds1"))), + ir.Literal("leeds2"))), + ir.Literal("leeds3")))) + } + "correctly respect explicit precedence with parentheses" in { + exampleExpr("(1 + 2) * 3", _.expression(), ir.Multiply(ir.Add(ir.Literal(1), ir.Literal(2)), ir.Literal(3))) + exampleExpr("1 + (2 * 3)", _.expression(), ir.Add(ir.Literal(1), ir.Multiply(ir.Literal(2), ir.Literal(3)))) + exampleExpr( + "(1 + 2) * (3 + 4)", + _.expression(), + ir.Multiply(ir.Add(ir.Literal(1), ir.Literal(2)), ir.Add(ir.Literal(3), ir.Literal(4)))) + exampleExpr( + "1 + (2 * 3) + 4", + _.expression(), + ir.Add(ir.Add(ir.Literal(1), ir.Multiply(ir.Literal(2), ir.Literal(3))), ir.Literal(4))) + exampleExpr( + "1 + (2 * 3 + 4)", + _.expression(), + ir.Add(ir.Literal(1), ir.Add(ir.Multiply(ir.Literal(2), ir.Literal(3)), ir.Literal(4)))) + exampleExpr( + "1 + (2 * (3 + 4))", + _.expression(), + ir.Add(ir.Literal(1), ir.Multiply(ir.Literal(2), ir.Add(ir.Literal(3), ir.Literal(4))))) + exampleExpr( + "(1 + (2 * (3 + 4)))", + _.expression(), + ir.Add(ir.Literal(1), ir.Multiply(ir.Literal(2), ir.Add(ir.Literal(3), ir.Literal(4))))) + } + + "correctly resolve dot delimited plain references" in { + exampleExpr("a", _.expression(), simplyNamedColumn("a")) + exampleExpr("a.b", _.expression(), ir.Column(Some(ir.ObjectReference(ir.Id("a"))), ir.Id("b"))) + exampleExpr("a.b.c", _.expression(), ir.Column(Some(ir.ObjectReference(ir.Id("a"), ir.Id("b"))), ir.Id("c"))) + } + + "correctly resolve RAW identifiers" in { + exampleExpr("RAW", _.expression(), simplyNamedColumn("RAW")) + } + + "correctly resolve # 
identifiers" in { + exampleExpr("#RAW", _.expression(), simplyNamedColumn("#RAW")) + } + + "correctly resolve \" quoted identifiers" in { + exampleExpr("\"a\"", _.expression(), ir.Column(None, ir.Id("a", caseSensitive = true))) + } + + "correctly resolve [] quoted identifiers" in { + exampleExpr("[a]", _.expression(), ir.Column(None, ir.Id("a", caseSensitive = true))) + } + + "correctly resolve [] quoted dot identifiers" in { + exampleExpr( + "[a].[b]", + _.expression(), + ir.Column(Some(ir.ObjectReference(ir.Id("a", caseSensitive = true))), ir.Id("b", caseSensitive = true))) + } + + "correctly resolve [] quoted triple dot identifiers" in { + exampleExpr( + "[a].[b].[c]", + _.expression(), + ir.Column( + Some(ir.ObjectReference(ir.Id("a", caseSensitive = true), ir.Id("b", caseSensitive = true))), + ir.Id("c", caseSensitive = true))) + } + + "correctly resolve keywords used as identifiers" in { + exampleExpr("ABORT", _.expression(), simplyNamedColumn("ABORT")) + } + + "translate a simple column" in { + exampleExpr("a", _.selectListElem(), simplyNamedColumn("a")) + exampleExpr("#a", _.selectListElem(), simplyNamedColumn("#a")) + exampleExpr("[a]", _.selectListElem(), ir.Column(None, ir.Id("a", caseSensitive = true))) + exampleExpr("\"a\"", _.selectListElem(), ir.Column(None, ir.Id("a", caseSensitive = true))) + exampleExpr("RAW", _.selectListElem(), simplyNamedColumn("RAW")) + } + + "translate a column with a table" in { + exampleExpr("table_x.a", _.selectListElem(), ir.Column(Some(ir.ObjectReference(ir.Id("table_x"))), ir.Id("a"))) + } + + "translate a column with a schema" in { + exampleExpr( + "schema1.table_x.a", + _.selectListElem(), + ir.Column(Some(ir.ObjectReference(ir.Id("schema1"), ir.Id("table_x"))), ir.Id("a"))) + } + + "translate a column with a database" in { + exampleExpr( + "database1.schema1.table_x.a", + _.selectListElem(), + ir.Column(Some(ir.ObjectReference(ir.Id("database1"), ir.Id("schema1"), ir.Id("table_x"))), ir.Id("a"))) + } + + "translate a column with a server" in { + exampleExpr( + "server1..schema1.table_x.a", + _.fullColumnName(), + ir.Column(Some(ir.ObjectReference(ir.Id("server1"), ir.Id("schema1"), ir.Id("table_x"))), ir.Id("a"))) + } + + "translate a column without a table reference" in { + exampleExpr("a", _.fullColumnName(), simplyNamedColumn("a")) + } + + "return ir.Dot for otherwise unhandled DotExpr" in { + val mockDotExprCtx = mock(classOf[TSqlParser.ExprDotContext]) + val mockExpressionCtx = mock(classOf[TSqlParser.ExpressionContext]) + val mockVisitor = mock(classOf[TSqlExpressionBuilder]) + + when(mockDotExprCtx.expression(anyInt())).thenReturn(mockExpressionCtx) + when(mockExpressionCtx.accept(mockVisitor)).thenReturn(ir.Literal("a")) + val result = astBuilder.visitExprDot(mockDotExprCtx) + + result shouldBe a[ir.Dot] + } + + "translate search conditions" in { + exampleExpr("a = b", _.searchCondition(), ir.Equals(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a > b", _.searchCondition(), ir.GreaterThan(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a < b", _.searchCondition(), ir.LessThan(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a >= b", _.searchCondition(), ir.GreaterThanOrEqual(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a !< b", _.searchCondition(), ir.GreaterThanOrEqual(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a <= b", _.searchCondition(), ir.LessThanOrEqual(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a !> b", 
_.searchCondition(), ir.LessThanOrEqual(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a > = b", _.searchCondition(), ir.GreaterThanOrEqual(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a < = b", _.searchCondition(), ir.LessThanOrEqual(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a <> b", _.searchCondition(), ir.NotEquals(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("a != b", _.searchCondition(), ir.NotEquals(simplyNamedColumn("a"), simplyNamedColumn("b"))) + exampleExpr("NOT a = b", _.searchCondition(), ir.Not(ir.Equals(simplyNamedColumn("a"), simplyNamedColumn("b")))) + exampleExpr( + "a = b AND c = e", + _.searchCondition(), + ir.And( + ir.Equals(simplyNamedColumn("a"), simplyNamedColumn("b")), + ir.Equals(simplyNamedColumn("c"), simplyNamedColumn("e")))) + exampleExpr( + "a = b OR c = e", + _.searchCondition(), + ir.Or( + ir.Equals(simplyNamedColumn("a"), simplyNamedColumn("b")), + ir.Equals(simplyNamedColumn("c"), simplyNamedColumn("e")))) + exampleExpr( + "a = b AND c = x OR e = f", + _.searchCondition(), + ir.Or( + ir.And( + ir.Equals(simplyNamedColumn("a"), simplyNamedColumn("b")), + ir.Equals(simplyNamedColumn("c"), simplyNamedColumn("x"))), + ir.Equals(simplyNamedColumn("e"), simplyNamedColumn("f")))) + exampleExpr( + "a = b AND (c = x OR e = f)", + _.searchCondition(), + ir.And( + ir.Equals(simplyNamedColumn("a"), simplyNamedColumn("b")), + ir.Or( + ir.Equals(simplyNamedColumn("c"), simplyNamedColumn("x")), + ir.Equals(simplyNamedColumn("e"), simplyNamedColumn("f"))))) + } + + "handle non special functions used in dot operators" in { + exampleExpr( + "a.b()", + _.expression(), + ir.Dot( + simplyNamedColumn("a"), + ir.UnresolvedFunction( + "b", + List(), + is_distinct = false, + is_user_defined_function = false, + ruleText = "b(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function b is not convertible to Databricks SQL"))) + exampleExpr( + "a.b.c()", + _.expression(), + ir.Dot( + simplyNamedColumn("a"), + ir.Dot( + simplyNamedColumn("b"), + ir.UnresolvedFunction( + "c", + List(), + is_distinct = false, + is_user_defined_function = false, + ruleText = "c(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function c is not convertible to Databricks SQL")))) + exampleExpr( + "a.b.c.FLOOR(c)", + _.expression(), + ir.Dot( + simplyNamedColumn("a"), + ir.Dot( + simplyNamedColumn("b"), + ir.Dot(simplyNamedColumn("c"), ir.CallFunction("FLOOR", Seq(simplyNamedColumn("c"))))))) + } + + "handle unknown functions used with dots" in { + exampleExpr( + "a.UNKNOWN_FUNCTION()", + _.expression(), + ir.Dot( + simplyNamedColumn("a"), + ir.UnresolvedFunction( + "UNKNOWN_FUNCTION", + List(), + is_distinct = false, + is_user_defined_function = false, + ruleText = "UNKNOWN_FUNCTION(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function UNKNOWN_FUNCTION is not convertible to Databricks SQL"))) + } + + "cover case that cannot happen with dot" in { + + val mockCtx = mock(classOf[TSqlParser.ExprDotContext]) + val expressionMockColumn = mock(classOf[TSqlParser.ExpressionContext]) + when(mockCtx.expression(0)).thenReturn(expressionMockColumn) + when(expressionMockColumn.accept(any())).thenReturn(simplyNamedColumn("a")) + val expressionMockFunc = mock(classOf[TSqlParser.ExpressionContext]) + when(mockCtx.expression(1)).thenReturn(expressionMockFunc) + when(expressionMockFunc.accept(any())).thenReturn(ir.CallFunction("UNKNOWN_FUNCTION", List())) + val result = 
vc.expressionBuilder.visitExprDot(mockCtx) + result shouldBe a[ir.Dot] + } + + "translate case/when/else expressions" in { + // Case with an initial expression and an else clause + exampleExpr( + "CASE a WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END", + _.expression(), + ir.Case( + Some(simplyNamedColumn("a")), + Seq(ir.WhenBranch(ir.Literal(1), ir.Literal("one")), ir.WhenBranch(ir.Literal(2), ir.Literal("two"))), + Some(ir.Literal("other")))) + + // Case without an initial expression and with an else clause + exampleExpr( + "CASE WHEN a = 1 THEN 'one' WHEN a = 2 THEN 'two' ELSE 'other' END", + _.expression(), + ir.Case( + None, + Seq( + ir.WhenBranch(ir.Equals(simplyNamedColumn("a"), ir.Literal(1)), ir.Literal("one")), + ir.WhenBranch(ir.Equals(simplyNamedColumn("a"), ir.Literal(2)), ir.Literal("two"))), + Some(ir.Literal("other")))) + + // Case with an initial expression and without an else clause + exampleExpr( + "CASE a WHEN 1 THEN 'one' WHEN 2 THEN 'two' END", + _.expression(), + ir.Case( + Some(simplyNamedColumn("a")), + Seq(ir.WhenBranch(ir.Literal(1), ir.Literal("one")), ir.WhenBranch(ir.Literal(2), ir.Literal("two"))), + None)) + + // Case without an initial expression and without an else clause + exampleExpr( + "CASE WHEN a = 1 AND b < 7 THEN 'one' WHEN a = 2 THEN 'two' END", + _.expression(), + ir.Case( + None, + Seq( + ir.WhenBranch( + ir.And( + ir.Equals(simplyNamedColumn("a"), ir.Literal(1)), + ir.LessThan(simplyNamedColumn("b"), ir.Literal(7))), + ir.Literal("one")), + ir.WhenBranch(ir.Equals(simplyNamedColumn("a"), ir.Literal(2)), ir.Literal("two"))), + None)) + } + + "translate the $ACTION special column reference" in { + exampleExpr("$ACTION", _.expression(), ir.DollarAction) + } + + "translate a timezone reference" in { + exampleExpr("a AT TIME ZONE 'UTC'", _.expression(), ir.Timezone(simplyNamedColumn("a"), ir.Literal("UTC"))) + } + + "return UnresolvedExpression for unsupported SelectListElem" in { + + val mockCtx = mock(classOf[TSqlParser.SelectListElemContext]) + + // Ensure that both asterisk() and expressionElem() methods return null + when(mockCtx.asterisk()).thenReturn(null) + when(mockCtx.expressionElem()).thenReturn(null) + val startTok = new CommonToken(ID, "s") + when(mockCtx.getStart).thenReturn(startTok) + when(mockCtx.getStop).thenReturn(startTok) + when(mockCtx.getRuleIndex).thenReturn(SnowflakeParser.RULE_constraintAction) + + // Call the method with the mock instance + val result = vc.expressionBuilder.buildSelectListElem(mockCtx) + + // Verify the result + result shouldBe a[Seq[_]] + } + + "cover default case in buildLocalAssign via visitSelectListElem" in { + val selectListElemContextMock = mock(classOf[TSqlParser.SelectListElemContext]) + val eofToken = new CommonToken(Token.EOF) + selectListElemContextMock.op = eofToken + when(selectListElemContextMock.LOCAL_ID()).thenReturn(new TerminalNodeImpl(eofToken)) + when(selectListElemContextMock.asterisk()).thenReturn(null) + when(selectListElemContextMock.getText).thenReturn("") + + val expressionContextMock = mock(classOf[TSqlParser.ExpressionContext]) + when(expressionContextMock.accept(any())).thenReturn(null) + when(selectListElemContextMock.expression()).thenReturn(expressionContextMock) + + val result = vc.expressionBuilder.buildSelectListElem(selectListElemContextMock) + + result shouldBe a[Seq[_]] + } + + "translate CAST(a AS tinyint)" in { + exampleExpr("CAST(a AS tinyint)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.ByteType(size = Some(1)))) + } + "translate CAST(a AS smallint)" 
in { + exampleExpr("CAST(a AS smallint)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.ShortType)) + } + "translate CAST(a AS INT)" in { + exampleExpr("CAST(a AS INT)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.IntegerType)) + } + "translate CAST(a AS bigint)" in { + exampleExpr("CAST(a AS bigint)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.LongType)) + } + "translate CAST(a AS bit)" in { + exampleExpr("CAST(a AS bit)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.BooleanType)) + } + "translate CAST(a AS money)" in { + exampleExpr( + "CAST(a AS money)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(19), Some(4)))) + } + "translate CAST(a AS smallmoney)" in { + exampleExpr( + "CAST(a AS smallmoney)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), Some(4)))) + } + "translate CAST(a AS float)" in { + exampleExpr("CAST(a AS float)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.FloatType)) + } + "translate CAST(a AS real)" in { + exampleExpr("CAST(a AS real)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.DoubleType)) + } + "translate CAST(a AS date)" in { + exampleExpr("CAST(a AS date)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.DateType)) + } + "translate CAST(a AS time)" in { + exampleExpr("CAST(a AS time)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.TimeType)) + } + "translate CAST(a AS datetime)" in { + exampleExpr("CAST(a AS datetime)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.TimestampType)) + } + "translate CAST(a AS datetime2)" in { + exampleExpr("CAST(a AS datetime2)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.TimestampType)) + } + "translate CAST(a AS datetimeoffset)" in { + exampleExpr("CAST(a AS datetimeoffset)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.StringType)) + } + "translate CAST(a AS smalldatetime)" in { + exampleExpr("CAST(a AS smalldatetime)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.TimestampType)) + } + "translate CAST(a AS char)" in { + exampleExpr("CAST(a AS char)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.CharType(size = None))) + } + "translate CAST(a AS varchar)" in { + exampleExpr("CAST(a AS varchar)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = None))) + } + "translate CAST(a AS nchar)" in { + exampleExpr("CAST(a AS nchar)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.CharType(size = None))) + } + "translate CAST(a AS nvarchar)" in { + exampleExpr("CAST(a AS nvarchar)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = None))) + } + "translate CAST(a AS text)" in { + exampleExpr("CAST(a AS text)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.VarcharType(None))) + } + "translate CAST(a AS ntext)" in { + exampleExpr("CAST(a AS ntext)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.VarcharType(None))) + } + "translate CAST(a AS image)" in { + exampleExpr("CAST(a AS image)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.BinaryType)) + } + "translate CAST(a AS decimal)" in { + exampleExpr("CAST(a AS decimal)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.DecimalType(None, None))) + } + "translate CAST(a AS numeric)" in { + exampleExpr("CAST(a AS numeric)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.DecimalType(None, None))) + } + "translate CAST(a AS binary)" in { + exampleExpr("CAST(a AS binary)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.BinaryType)) + } + "translate CAST(a AS varbinary)" in { + exampleExpr("CAST(a 
AS varbinary)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.BinaryType)) + } + "translate CAST(a AS json)" in { + exampleExpr("CAST(a AS json)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.VarcharType(None))) + } + "translate CAST(a AS uniqueidentifier)" in { + exampleExpr( + "CAST(a AS uniqueidentifier)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = Some(16)))) + } + + "translate CAST pseudo function calls with length arguments" in { + exampleExpr("CAST(a AS char(10))", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.CharType(size = Some(10)))) + exampleExpr( + "CAST(a AS varchar(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = Some(10)))) + exampleExpr("CAST(a AS nchar(10))", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.CharType(size = Some(10)))) + exampleExpr( + "CAST(a AS nvarchar(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = Some(10)))) + } + + "translate CAST pseudo function calls with scale arguments" in { + exampleExpr( + "CAST(a AS decimal(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), None))) + exampleExpr( + "CAST(a AS numeric(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), None))) + } + + "translate CAST pseudo function calls with precision and scale arguments" in { + exampleExpr( + "CAST(a AS decimal(10, 2))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), Some(2)))) + exampleExpr( + "CAST(a AS numeric(10, 2))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), Some(2)))) + } + + "translate TRY_CAST pseudo function calls with simple scalars" in { + exampleExpr( + "TRY_CAST(a AS tinyint)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.ByteType(size = Some(1)), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS smallint)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.ShortType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS INT)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.IntegerType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS bigint)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.LongType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS bit)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.BooleanType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS money)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(19), Some(4)), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS smallmoney)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), Some(4)), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS float)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.FloatType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS real)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DoubleType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS date)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DateType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS time)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.TimeType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS datetime)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.TimestampType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS datetime2)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.TimestampType, returnNullOnError = true)) 
+ exampleExpr( + "TRY_CAST(a AS datetimeoffset)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.StringType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS smalldatetime)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.TimestampType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS char)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.CharType(size = None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS varchar)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS nchar)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.CharType(size = None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS nvarchar)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS text)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS ntext)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS image)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.BinaryType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS decimal)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(None, None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS numeric)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(None, None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS binary)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.BinaryType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS varbinary)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.BinaryType, returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS json)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS uniqueidentifier)", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = Some(16)), returnNullOnError = true)) + } + + "translate TRY_CAST pseudo function calls with length arguments" in { + exampleExpr( + "TRY_CAST(a AS char(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.CharType(size = Some(10)), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS varchar(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = Some(10)), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS nchar(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.CharType(size = Some(10)), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS nvarchar(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.VarcharType(size = Some(10)), returnNullOnError = true)) + } + + "translate TRY_CAST pseudo function calls with scale arguments" in { + exampleExpr( + "TRY_CAST(a AS decimal(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), None), returnNullOnError = true)) + exampleExpr( + "TRY_CAST(a AS numeric(10))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), None), returnNullOnError = true)) + } + + "translate TRY_CAST pseudo function calls with precision and scale arguments" in { + exampleExpr( + "TRY_CAST(a AS decimal(10, 2))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), Some(2)), returnNullOnError = true)) + exampleExpr( + 
"TRY_CAST(a AS numeric(10, 2))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.DecimalType(Some(10), Some(2)), returnNullOnError = true)) + } + + "translate identity to UnparsedType" in { + // TODO: Resolve what to do with IDENTITY + // IDENTITY it isn't actually castable but we have not implemented CREATE TABLE yet, so cover here for now + // then examine what happens in snowflake + exampleExpr( + "CAST(a AS col1 IDENTITY(10, 2))", + _.expression(), + ir.Cast(simplyNamedColumn("a"), ir.UnparsedType("col1IDENTITY(10,2)"))) + } + + "translate unknown types to UnParsedType" in { + exampleExpr("CAST(a AS sometype)", _.expression(), ir.Cast(simplyNamedColumn("a"), ir.UnparsedType("sometype"))) + } + + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionGeneratorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionGeneratorTest.scala new file mode 100644 index 0000000000..84b3f062c9 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionGeneratorTest.scala @@ -0,0 +1,27 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.generators.{GeneratorContext, GeneratorTestCommon} +import com.databricks.labs.remorph.generators.sql.{ExpressionGenerator, LogicalPlanGenerator, OptionGenerator} +import com.databricks.labs.remorph.intermediate.{Batch, Expression} +import com.databricks.labs.remorph.{Generating, intermediate => ir} +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +// Only add tests here that require the TSqlCallMapper, or in the future any other transformer/rule +// that is specific to T-SQL. Otherwise they belong in ExpressionGeneratorTest. + +class TSqlExpressionGeneratorTest + extends AnyWordSpec + with GeneratorTestCommon[ir.Expression] + with MockitoSugar + with ir.IRHelpers { + + override protected val generator = new ExpressionGenerator() + + private[this] val optionGenerator = new OptionGenerator(generator) + + private[this] val logical = new LogicalPlanGenerator(generator, optionGenerator) + + override protected def initialState(t: Expression): Generating = + Generating(optimizedPlan = Batch(Seq.empty), currentNode = t, ctx = GeneratorContext(logical)) +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionBuilderSpec.scala new file mode 100644 index 0000000000..6ff75af7ce --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionBuilderSpec.scala @@ -0,0 +1,57 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.FunctionDefinition +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks + +class TSqlFunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenPropertyChecks { + + private[this] val functionBuilder = new TSqlFunctionBuilder + + // While this appears to be somewhat redundant, it will catch any changes in the functionArity method + // that happen through typos or other mistakes such as deletion. 
+ "TSqlFunctionBuilder" should "return correct arity for each function" in { + + val functions = Table( + ("functionName", "expectedArity"), // Header + + // TSql specific + (s"$$PARTITION", Some(FunctionDefinition.notConvertible(0))), + ("@@CURSOR_ROWS", Some(FunctionDefinition.notConvertible(0))), + ("@@DBTS", Some(FunctionDefinition.notConvertible(0))), + ("@@FETCH_STATUS", Some(FunctionDefinition.notConvertible(0))), + ("@@LANGID", Some(FunctionDefinition.notConvertible(0))), + ("@@LANGUAGE", Some(FunctionDefinition.notConvertible(0))), + ("@@LOCKTIMEOUT", Some(FunctionDefinition.notConvertible(0))), + ("@@MAX_CONNECTIONS", Some(FunctionDefinition.notConvertible(0))), + ("@@MAX_PRECISION", Some(FunctionDefinition.notConvertible(0))), + ("@@NESTLEVEL", Some(FunctionDefinition.notConvertible(0))), + ("@@OPTIONS", Some(FunctionDefinition.notConvertible(0))), + ("@@REMSERVER", Some(FunctionDefinition.notConvertible(0))), + ("@@SERVERNAME", Some(FunctionDefinition.notConvertible(0))), + ("@@SERVICENAME", Some(FunctionDefinition.notConvertible(0))), + ("@@SPID", Some(FunctionDefinition.notConvertible(0))), + ("@@TEXTSIZE", Some(FunctionDefinition.notConvertible(0))), + ("@@VERSION", Some(FunctionDefinition.notConvertible(0))), + ("COLLATIONPROPERTY", Some(FunctionDefinition.notConvertible(2))), + ("CONTAINSTABLE", Some(FunctionDefinition.notConvertible(0))), + ("FREETEXTTABLE", Some(FunctionDefinition.notConvertible(0))), + ("HIERARCHYID", Some(FunctionDefinition.notConvertible(0))), + ("MODIFY", Some(FunctionDefinition.xml(1))), + ("SEMANTICKEYPHRASETABLE", Some(FunctionDefinition.notConvertible(0))), + ("SEMANTICSIMILARITYDETAILSTABLE", Some(FunctionDefinition.notConvertible(0))), + ("SEMANTICSSIMILARITYTABLE", Some(FunctionDefinition.notConvertible(0)))) + + forAll(functions) { (functionName: String, expectedArity: Option[FunctionDefinition]) => + functionBuilder.functionDefinition(functionName) shouldEqual expectedArity + } + } + + "TSqlFunctionBuilder rename strategy" should "handle default case" in { + + val result = functionBuilder.rename("UNKNOWN_FUNCTION", List.empty) + assert(result == ir.CallFunction("UNKNOWN_FUNCTION", List.empty)) + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionSpec.scala new file mode 100644 index 0000000000..263b6dd76b --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlFunctionSpec.scala @@ -0,0 +1,409 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TSqlFunctionSpec extends AnyWordSpec with TSqlParserTestCommon with Matchers with ir.IRHelpers { + + override protected def astBuilder: TSqlParserBaseVisitor[_] = vc.expressionBuilder + + "translate functions with no parameters" in { + exampleExpr("APP_NAME()", _.expression(), ir.CallFunction("APP_NAME", List())) + exampleExpr("SCOPE_IDENTITY()", _.expression(), ir.CallFunction("SCOPE_IDENTITY", List())) + } + + "translate functions with variable numbers of parameters" in { + exampleExpr( + "CONCAT('a', 'b', 'c')", + _.expression(), + ir.CallFunction("CONCAT", Seq(ir.Literal("a"), ir.Literal("b"), ir.Literal("c")))) + + exampleExpr( + "CONCAT_WS(',', 'a', 'b', 'c')", + _.expression(), + ir.CallFunction("CONCAT_WS", List(ir.Literal(","), ir.Literal("a"), ir.Literal("b"), ir.Literal("c")))) + } 
+ + "translate functions with functions as parameters" in { + exampleExpr( + "CONCAT(Greatest(42, 2, 4, \"ali\"), 'c')", + _.expression(), + ir.CallFunction( + "CONCAT", + List( + ir.CallFunction( + "Greatest", + List(ir.Literal(42), ir.Literal(2), ir.Literal(4), ir.Column(None, ir.Id("ali", caseSensitive = true)))), + ir.Literal("c")))) + } + + "translate functions with complicated expressions as parameters" in { + exampleExpr( + "CONCAT('a', 'b' || 'c', Greatest(42, 2, 4, \"ali\"))", + _.standardFunction(), + ir.CallFunction( + "CONCAT", + List( + ir.Literal("a"), + ir.Concat(Seq(ir.Literal("b"), ir.Literal("c"))), + ir.CallFunction( + "Greatest", + List(ir.Literal(42), ir.Literal(2), ir.Literal(4), ir.Column(None, ir.Id("ali", caseSensitive = true))))))) + } + + "translate unknown functions as unresolved" in { + exampleExpr( + "UNKNOWN_FUNCTION()", + _.expression(), + ir.UnresolvedFunction( + "UNKNOWN_FUNCTION", + List(), + is_distinct = false, + is_user_defined_function = false, + ruleText = "UNKNOWN_FUNCTION(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function UNKNOWN_FUNCTION is not convertible to Databricks SQL")) + } + + "translate functions with invalid function argument counts" in { + exampleExpr( + "USER_NAME('a', 'b', 'c', 'd')", // USER_NAME function only accepts 0 or 1 argument + _.expression(), + ir.UnresolvedFunction( + "USER_NAME", + Seq(ir.Literal("a"), ir.Literal("b"), ir.Literal("c"), ir.Literal("d")), + is_distinct = false, + is_user_defined_function = false, + has_incorrect_argc = true, + ruleText = "USER_NAME(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Invocation of USER_NAME has incorrect argument count")) + + exampleExpr( + "FLOOR()", // FLOOR requires 1 argument + _.expression(), + ir.UnresolvedFunction( + "FLOOR", + List(), + is_distinct = false, + is_user_defined_function = false, + has_incorrect_argc = true, + ruleText = "FLOOR(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Invocation of FLOOR has incorrect argument count")) + } + + "translate functions that we know cannot be converted" in { + // Later, we will register a semantic or lint error + exampleExpr( + "CONNECTIONPROPERTY('property')", + _.expression(), + ir.UnresolvedFunction( + "CONNECTIONPROPERTY", + List(ir.Literal("property")), + is_distinct = false, + is_user_defined_function = false, + ruleText = "CONNECTIONPROPERTY(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function CONNECTIONPROPERTY is not convertible to Databricks SQL")) + } + + "translate windowing functions in all forms" in { + exampleExpr( + """SUM(salary) OVER (PARTITION BY department ORDER BY employee_id + RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)""", + _.expression(), + ir.Window( + ir.CallFunction("SUM", Seq(simplyNamedColumn("salary"))), + Seq(simplyNamedColumn("department")), + Seq(ir.SortOrder(simplyNamedColumn("employee_id"), ir.UnspecifiedSortDirection, ir.SortNullsUnspecified)), + Some(ir.WindowFrame(ir.RangeFrame, ir.UnboundedPreceding, ir.CurrentRow)))) + exampleExpr( + "SUM(salary) OVER (PARTITION BY department ORDER BY employee_id ROWS UNBOUNDED PRECEDING)", + _.expression(), + ir.Window( + ir.CallFunction("SUM", Seq(simplyNamedColumn("salary"))), + Seq(simplyNamedColumn("department")), + Seq(ir.SortOrder(simplyNamedColumn("employee_id"), ir.UnspecifiedSortDirection, ir.SortNullsUnspecified)), + Some(ir.WindowFrame(ir.RowsFrame, ir.UnboundedPreceding, ir.NoBoundary)))) + + exampleExpr( + "SUM(salary) OVER (PARTITION BY 
department ORDER BY employee_id ROWS 66 PRECEDING)", + _.expression(), + ir.Window( + ir.CallFunction("SUM", Seq(simplyNamedColumn("salary"))), + Seq(simplyNamedColumn("department")), + Seq(ir.SortOrder(simplyNamedColumn("employee_id"), ir.UnspecifiedSortDirection, ir.SortNullsUnspecified)), + Some(ir.WindowFrame(ir.RowsFrame, ir.PrecedingN(ir.Literal(66)), ir.NoBoundary)))) + + exampleExpr( + query = """ + AVG(salary) OVER (PARTITION BY department_id ORDER BY employee_id ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """, + _.expression(), + ir.Window( + ir.CallFunction("AVG", Seq(simplyNamedColumn("salary"))), + Seq(simplyNamedColumn("department_id")), + Seq(ir.SortOrder(simplyNamedColumn("employee_id"), ir.UnspecifiedSortDirection, ir.SortNullsUnspecified)), + Some(ir.WindowFrame(ir.RowsFrame, ir.UnboundedPreceding, ir.CurrentRow)))) + + exampleExpr( + query = """ + SUM(sales) OVER (ORDER BY month ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) + """, + _.expression(), + ir.Window( + ir.CallFunction("SUM", Seq(simplyNamedColumn("sales"))), + List(), + Seq(ir.SortOrder(simplyNamedColumn("month"), ir.UnspecifiedSortDirection, ir.SortNullsUnspecified)), + Some(ir.WindowFrame(ir.RowsFrame, ir.CurrentRow, ir.FollowingN(ir.Literal(2)))))) + + exampleExpr( + "ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary DESC)", + _.selectListElem(), + ir.Window( + ir.CallFunction("ROW_NUMBER", Seq.empty), + Seq(simplyNamedColumn("department")), + Seq(ir.SortOrder(simplyNamedColumn("salary"), ir.Descending, ir.SortNullsUnspecified)), + None)) + + exampleExpr( + "ROW_NUMBER() OVER (PARTITION BY department)", + _.selectListElem(), + ir.Window(ir.CallFunction("ROW_NUMBER", Seq.empty), Seq(simplyNamedColumn("department")), List(), None)) + + } + + "translate functions with DISTINCT arguments" in { + exampleExpr( + "COUNT(DISTINCT salary)", + _.expression(), + ir.CallFunction("COUNT", Seq(ir.Distinct(simplyNamedColumn("salary"))))) + } + + "translate COUNT(*)" in { + exampleExpr("COUNT(*)", _.expression(), ir.CallFunction("COUNT", Seq(ir.Star(None)))) + } + + "translate special keyword functions" in { + exampleExpr( + // TODO: Returns UnresolvedFunction as it is not convertible - create UnsupportedFunctionRule + "@@CURSOR_ROWS", + _.expression(), + ir.UnresolvedFunction( + "@@CURSOR_ROWS", + List(), + is_distinct = false, + is_user_defined_function = false, + ruleText = "@@CURSOR_ROWS(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function @@CURSOR_ROWS is not convertible to Databricks SQL")) + + exampleExpr( + // TODO: Returns UnresolvedFunction as it is not convertible - create UnsupportedFunctionRule + "@@FETCH_STATUS", + _.expression(), + ir.UnresolvedFunction( + "@@FETCH_STATUS", + List(), + is_distinct = false, + is_user_defined_function = false, + ruleText = "@@FETCH_STATUS(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function @@FETCH_STATUS is not convertible to Databricks SQL")) + + exampleExpr("SESSION_USER", _.expression(), ir.CallFunction("SESSION_USER", List())) + + exampleExpr("USER", _.expression(), ir.CallFunction("USER", List())) + } + + "translate analytic windowing functions in all forms" in { + + exampleExpr( + query = "FIRST_VALUE(Salary) OVER (PARTITION BY DepartmentID ORDER BY Salary DESC)", + _.expression(), + ir.Window( + ir.CallFunction("FIRST_VALUE", Seq(simplyNamedColumn("Salary"))), + Seq(simplyNamedColumn("DepartmentID")), + Seq(ir.SortOrder(simplyNamedColumn("Salary"), ir.Descending, ir.SortNullsUnspecified)), + None)) + + 
exampleExpr( + query = """ + LAST_VALUE(salary) OVER (PARTITION BY department_id ORDER BY employee_id DESC) + """, + _.expression(), + ir.Window( + ir.CallFunction("LAST_VALUE", Seq(simplyNamedColumn("salary"))), + Seq(simplyNamedColumn("department_id")), + Seq(ir.SortOrder(simplyNamedColumn("employee_id"), ir.Descending, ir.SortNullsUnspecified)), + None)) + + exampleExpr( + query = "PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY Salary ASC) OVER (PARTITION BY DepartmentID)", + _.expression(), + ir.Window( + ir.WithinGroup( + ir.CallFunction("PERCENTILE_CONT", Seq(ir.Literal(0.5f))), + Seq(ir.SortOrder(simplyNamedColumn("Salary"), ir.Ascending, ir.SortNullsUnspecified))), + Seq(simplyNamedColumn("DepartmentID")), + List(), + None)) + + exampleExpr( + query = """ + LEAD(salary, 1) OVER (PARTITION BY department_id ORDER BY employee_id DESC) + """, + _.expression(), + ir.Window( + ir.CallFunction("LEAD", Seq(simplyNamedColumn("salary"), ir.Literal(1))), + Seq(simplyNamedColumn("department_id")), + Seq(ir.SortOrder(simplyNamedColumn("employee_id"), ir.Descending, ir.SortNullsUnspecified)), + None)) + + exampleExpr( + query = """ + LEAD(salary, 1) IGNORE NULLS OVER (PARTITION BY department_id ORDER BY employee_id DESC) + """, + _.expression(), + ir.Window( + ir.CallFunction("LEAD", Seq(simplyNamedColumn("salary"), ir.Literal(1))), + Seq(simplyNamedColumn("department_id")), + Seq(ir.SortOrder(simplyNamedColumn("employee_id"), ir.Descending, ir.SortNullsUnspecified)), + None, + ignore_nulls = true)) + + } + + "translate 'functions' with non-standard syntax" in { + exampleExpr( + query = "NEXT VALUE FOR mySequence", + _.expression(), + ir.CallFunction("MONOTONICALLY_INCREASING_ID", List.empty)) + } + + "translate JSON_ARRAY in various forms" in { + exampleExpr( + query = "JSON_ARRAY(1, 2, 3 ABSENT ON NULL)", + _.expression(), + ir.CallFunction( + "TO_JSON", + Seq( + ir.ValueArray(Seq(ir.FilterExpr( + Seq(ir.Literal(1), ir.Literal(2), ir.Literal(3)), + ir.LambdaFunction( + ir.Not(ir.IsNull(ir.UnresolvedNamedLambdaVariable(Seq("x")))), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("x")))))))))) + + exampleExpr( + query = "JSON_ARRAY(4, 5, 6)", + _.expression(), + ir.CallFunction( + "TO_JSON", + Seq( + ir.ValueArray(Seq(ir.FilterExpr( + Seq(ir.Literal(4), ir.Literal(5), ir.Literal(6)), + ir.LambdaFunction( + ir.Not(ir.IsNull(ir.UnresolvedNamedLambdaVariable(Seq("x")))), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("x")))))))))) + + exampleExpr( + query = "JSON_ARRAY(1, 2, 3 NULL ON NULL)", + _.expression(), + ir.CallFunction("TO_JSON", Seq(ir.ValueArray(Seq(ir.Literal(1), ir.Literal(2), ir.Literal(3)))))) + + exampleExpr( + query = "JSON_ARRAY(1, col1, x.col2 NULL ON NULL)", + _.expression(), + ir.CallFunction( + "TO_JSON", + Seq( + ir.ValueArray( + Seq( + ir.Literal(1), + simplyNamedColumn("col1"), + ir.Column(Some(ir.ObjectReference(ir.Id("x"))), ir.Id("col2"))))))) + } + + "translate JSON_OBJECT in various forms" in { + exampleExpr( + query = "JSON_OBJECT('one': 1, 'two': 2, 'three': 3 ABSENT ON NULL)", + _.expression(), + ir.CallFunction( + "TO_JSON", + Seq( + ir.FilterStruct( + ir.NamedStruct( + keys = Seq(ir.Literal("one"), ir.Literal("two"), ir.Literal("three")), + values = Seq(ir.Literal(1), ir.Literal(2), ir.Literal(3))), + ir.LambdaFunction( + ir.Not(ir.IsNull(ir.UnresolvedNamedLambdaVariable(Seq("v")))), + Seq(ir.UnresolvedNamedLambdaVariable(Seq("k", "v")))))))) + + exampleExpr( + query = "JSON_OBJECT('a': a, 'b': b, 'c': c NULL ON NULL)", + _.expression(), + ir.CallFunction( + "TO_JSON", + Seq( 
+ ir.NamedStruct( + Seq(ir.Literal("a"), ir.Literal("b"), ir.Literal("c")), + Seq(simplyNamedColumn("a"), simplyNamedColumn("b"), simplyNamedColumn("c")))))) + } + + "translate functions using ALL" in { + exampleExpr(query = "COUNT(ALL goals)", _.expression(), ir.CallFunction("COUNT", Seq(simplyNamedColumn("goals")))) + } + + "translate freetext functions as inconvertible" in { + exampleExpr( + query = "FREETEXTTABLE(table, col, 'search')", + _.expression(), + ir.UnresolvedFunction( + "FREETEXTTABLE", + List.empty, + is_distinct = false, + is_user_defined_function = false, + ruleText = "FREETEXTTABLE(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function FREETEXTTABLE is not convertible to Databricks SQL")) + } + + "translate $PARTITION functions as inconvertible" in { + exampleExpr( + query = "$PARTITION.partitionFunction(col)", + _.expression(), + ir.UnresolvedFunction( + "$PARTITION", + List.empty, + is_distinct = false, + is_user_defined_function = false, + ruleText = "$PARTITION(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function $PARTITION is not convertible to Databricks SQL")) + } + + "translate HIERARCHYID static method as inconvertible" in { + exampleExpr( + query = "HIERARCHYID::Parse('1/2/3')", + _.expression(), + ir.UnresolvedFunction( + "HIERARCHYID", + List.empty, + is_distinct = false, + is_user_defined_function = false, + ruleText = "HIERARCHYID(...)", + ruleName = "N/A", + tokenName = Some("N/A"), + message = "Function HIERARCHYID is not convertible to Databricks SQL")) + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlLexerSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlLexerSpec.scala new file mode 100644 index 0000000000..a71b1661ae --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlLexerSpec.scala @@ -0,0 +1,39 @@ +package com.databricks.labs.remorph.parsers.tsql + +import org.antlr.v4.runtime.{CharStreams, Token} +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.wordspec.AnyWordSpec + +class TSqlLexerSpec extends AnyWordSpec with Matchers with TableDrivenPropertyChecks { + + private[this] val lexer = new TSqlLexer(null) + + // TODO: Expand this test to cover all token types, and maybe all tokens + "TSqlLexer" should { + "parse string literals and ids" in { + + val testInput = Table( + ("child", "expected"), // Headers + + ("'And it''s raining'", TSqlLexer.STRING), + ("""'Tab\oir'""", TSqlLexer.STRING), + ("""'Tab\'oir'""", TSqlLexer.STRING), + ("'hello'", TSqlLexer.STRING), + (""""quoted""id"""", TSqlLexer.DOUBLE_QUOTE_ID), + ("\"quote\"\"andunquote\"\"\"", TSqlLexer.DOUBLE_QUOTE_ID), + ("_!Jinja0001", TSqlLexer.JINJA_REF)) + + forAll(testInput) { (input: String, expected: Int) => + val inputString = CharStreams.fromString(input) + + lexer.setInputStream(inputString) + val tok: Token = lexer.nextToken() + tok.getType shouldBe expected + tok.getText shouldBe input + } + } + + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlParserTestCommon.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlParserTestCommon.scala new file mode 100644 index 0000000000..e1975216dd --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlParserTestCommon.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.parsers.tsql + +import com.databricks.labs.remorph.parsers.{ErrorCollector, 
ParserTestCommon, ProductionErrorCollector} +import org.antlr.v4.runtime.{CharStream, Lexer, TokenStream} +import org.scalatest.Assertions + +trait TSqlParserTestCommon extends ParserTestCommon[TSqlParser] { self: Assertions => + + protected val vc: TSqlVisitorCoordinator = new TSqlVisitorCoordinator(TSqlParser.VOCABULARY, TSqlParser.ruleNames) + override final protected def makeLexer(chars: CharStream): Lexer = new TSqlLexer(chars) + + override final protected def makeErrStrategy(): TSqlErrorStrategy = new TSqlErrorStrategy + + override protected def makeErrListener(chars: String): ErrorCollector = + new ProductionErrorCollector(chars, "-- test string --") + + override final protected def makeParser(tokenStream: TokenStream): TSqlParser = { + val parser = new TSqlParser(tokenStream) + parser + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlRelationBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlRelationBuilderSpec.scala new file mode 100644 index 0000000000..5a76c0a7d4 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/TSqlRelationBuilderSpec.scala @@ -0,0 +1,169 @@ +package com.databricks.labs.remorph.parsers +package tsql + +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class TSqlRelationBuilderSpec + extends AnyWordSpec + with TSqlParserTestCommon + with SetOperationBehaviors[TSqlParser] + with Matchers + with MockitoSugar + with ir.IRHelpers { + + override protected def astBuilder: TSqlRelationBuilder = vc.relationBuilder + + "TSqlRelationBuilder" should { + + "translate query with no FROM clause" in { + example("", _.selectOptionalClauses(), ir.NoTable) + } + + "translate FROM clauses" should { + "FROM some_table" in { + example("FROM some_table", _.fromClause(), namedTable("some_table")) + } + "FROM some_schema.some_table" in { + example("FROM some_schema.some_table", _.fromClause(), namedTable("some_schema.some_table")) + } + "FROM some_server..some_table" in { + example("FROM some_server..some_table", _.fromClause(), namedTable("some_server..some_table")) + } + "FROM t1, t2, t3" in { + example( + "FROM t1, t2, t3", + _.fromClause(), + ir.Join( + ir.Join( + namedTable("t1"), + namedTable("t2"), + None, + ir.CrossJoin, + Seq(), + ir.JoinDataType(is_left_struct = false, is_right_struct = false)), + namedTable("t3"), + None, + ir.CrossJoin, + Seq(), + ir.JoinDataType(is_left_struct = false, is_right_struct = false))) + } + } + + "FROM some_table WHERE 1=1" in { + example( + "FROM some_table WHERE 1=1", + _.selectOptionalClauses(), + ir.Filter(namedTable("some_table"), ir.Equals(ir.Literal(1), ir.Literal(1)))) + } + + "FROM some_table GROUP BY some_column" in { + example( + "FROM some_table GROUP BY some_column", + _.selectOptionalClauses(), + ir.Aggregate( + child = namedTable("some_table"), + group_type = ir.GroupBy, + grouping_expressions = Seq(simplyNamedColumn("some_column")), + pivot = None)) + } + + "translate ORDER BY clauses" should { + "FROM some_table ORDER BY some_column" in { + example( + "FROM some_table ORDER BY some_column", + _.selectOptionalClauses(), + ir.Sort( + namedTable("some_table"), + Seq(ir.SortOrder(simplyNamedColumn("some_column"), ir.Ascending, ir.SortNullsUnspecified)), + is_global = false)) + } + "FROM some_table ORDER BY some_column ASC" in { + example( + "FROM some_table ORDER BY some_column ASC", + 
_.selectOptionalClauses(), + ir.Sort( + namedTable("some_table"), + Seq(ir.SortOrder(simplyNamedColumn("some_column"), ir.Ascending, ir.SortNullsUnspecified)), + is_global = false)) + } + "FROM some_table ORDER BY some_column DESC" in { + example( + "FROM some_table ORDER BY some_column DESC", + _.selectOptionalClauses(), + ir.Sort( + namedTable("some_table"), + Seq(ir.SortOrder(simplyNamedColumn("some_column"), ir.Descending, ir.SortNullsUnspecified)), + is_global = false)) + } + } + + "translate combinations of the above" should { + "FROM some_table WHERE 1=1 GROUP BY some_column" in { + example( + "FROM some_table WHERE 1=1 GROUP BY some_column", + _.selectOptionalClauses(), + ir.Aggregate( + child = ir.Filter(namedTable("some_table"), ir.Equals(ir.Literal(1), ir.Literal(1))), + group_type = ir.GroupBy, + grouping_expressions = Seq(simplyNamedColumn("some_column")), + pivot = None)) + } + "FROM some_table WHERE 1=1 GROUP BY some_column ORDER BY some_column" in { + example( + "FROM some_table WHERE 1=1 GROUP BY some_column ORDER BY some_column", + _.selectOptionalClauses(), + ir.Sort( + ir.Aggregate( + child = ir.Filter(namedTable("some_table"), ir.Equals(ir.Literal(1), ir.Literal(1))), + group_type = ir.GroupBy, + grouping_expressions = Seq(simplyNamedColumn("some_column")), + pivot = None), + Seq(ir.SortOrder(simplyNamedColumn("some_column"), ir.Ascending, ir.SortNullsUnspecified)), + is_global = false)) + } + } + + "WITH a (b, c) AS (SELECT x, y FROM d)" in { + example( + "WITH a (b, c) AS (SELECT x, y FROM d)", + _.withExpression(), + ir.SubqueryAlias( + ir.Project(namedTable("d"), Seq(simplyNamedColumn("x"), simplyNamedColumn("y"))), + ir.Id("a"), + Seq(ir.Id("b"), ir.Id("c")))) + } + + "SELECT DISTINCT a, b AS bb FROM t" in { + example( + "SELECT DISTINCT a, b AS bb FROM t", + _.selectStatement(), + ir.Project( + ir.Deduplicate( + namedTable("t"), + column_names = Seq(ir.Id("a"), ir.Id("bb")), + all_columns_as_keys = false, + within_watermark = false), + Seq(simplyNamedColumn("a"), ir.Alias(simplyNamedColumn("b"), ir.Id("bb"))))) + } + + behave like setOperationsAreTranslated(_.queryExpression()) + + "SELECT a, b AS bb FROM (SELECT x, y FROM d) AS t (aliasA, 'aliasB')" in { + example( + "SELECT a, b AS bb FROM (SELECT x, y FROM d) AS t (aliasA, 'aliasB')", + _.selectStatement(), + ir.Project( + ir.TableAlias( + ColumnAliases( + ir.Project( + ir.NamedTable("d", Map(), is_streaming = false), + Seq(ir.Column(None, ir.Id("x")), ir.Column(None, ir.Id("y")))), + Seq(ir.Id("aliasA"), ir.Id("aliasB"))), + "t"), + Seq(ir.Column(None, ir.Id("a")), ir.Alias(ir.Column(None, ir.Id("b")), ir.Id("bb"))))) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/PullLimitUpwardsTest.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/PullLimitUpwardsTest.scala new file mode 100644 index 0000000000..9661678a0b --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/PullLimitUpwardsTest.scala @@ -0,0 +1,44 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.parsers.PlanComparison +import com.databricks.labs.remorph.intermediate._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class PullLimitUpwardsTest extends AnyWordSpec with PlanComparison with Matchers with IRHelpers { + "from project" in { + val out = PullLimitUpwards.apply(Project(Limit(namedTable("a"), Literal(10)), Seq(Star(None)))) + comparePlans(out, 
Limit(Project(namedTable("a"), Seq(Star())), Literal(10))) + } + + "from project with filter" in { + val out = PullLimitUpwards.apply( + Filter( + Project(Limit(namedTable("a"), Literal(10)), Seq(Star(None))), + GreaterThan(UnresolvedAttribute("b"), Literal(1)))) + comparePlans( + out, + Limit( + Filter(Project(namedTable("a"), Seq(Star())), GreaterThan(UnresolvedAttribute("b"), Literal(1))), + Literal(10))) + } + + "from project with filter order by" in { + val out = PullLimitUpwards.apply( + Sort( + Filter( + Project(Limit(namedTable("a"), Literal(10)), Seq(Star(None))), + GreaterThan(UnresolvedAttribute("b"), Literal(1))), + Seq(SortOrder(UnresolvedAttribute("b"))), + is_global = false)) + comparePlans( + out, + Limit( + Sort( + Filter(Project(namedTable("a"), Seq(Star())), GreaterThan(UnresolvedAttribute("b"), Literal(1))), + Seq(SortOrder(UnresolvedAttribute("b"))), + is_global = false), + Literal(10))) + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TSqlCallMapperSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TSqlCallMapperSpec.scala new file mode 100644 index 0000000000..342be3ac07 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TSqlCallMapperSpec.scala @@ -0,0 +1,107 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.{intermediate => ir} +import org.scalatest.Assertion +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TSqlCallMapperSpec extends AnyWordSpec with Matchers with ir.IRHelpers { + + private[this] val tsqlCallMapper = new TSqlCallMapper + + implicit class CallMapperOps(fn: ir.Fn) { + def becomes(expected: ir.Expression): Assertion = { + tsqlCallMapper.convert(fn) shouldBe expected + } + } + + "CHECKSUM_AGG" should { + "transpile to MD5 function" in { + ir.CallFunction("CHECKSUM_AGG", Seq(ir.Id("col1"))) becomes ir.Md5( + ir.ConcatWs(Seq(ir.StringLiteral(","), ir.CollectList(ir.Id("col1"))))) + } + } + + "SET_BIT" should { + "transpile to bitwise logic" in { + ir.CallFunction("SET_BIT", Seq(ir.Literal(42.toShort), ir.Literal(7.toShort), ir.Literal(0.toShort))) becomes + ir.BitwiseOr( + ir.BitwiseAnd( + ir.Literal(42.toShort), + ir.BitwiseXor(ir.Literal(-1), ir.ShiftLeft(ir.Literal(1), ir.Literal(7.toShort)))), + ir.ShiftRight(ir.Literal(0.toShort), ir.Literal(7.toShort))) + + ir.CallFunction("SET_BIT", Seq(ir.Literal(42.toShort), ir.Literal(7.toShort))) becomes + ir.BitwiseOr(ir.Literal(42.toShort), ir.ShiftLeft(ir.Literal(1), ir.Literal(7.toShort))) + } + } + + "DATEADD" should { + "transpile to DATE_ADD" in { + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("day"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.DateAdd( + simplyNamedColumn("col1"), + ir.Literal(42.toShort)) + + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("week"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.DateAdd( + simplyNamedColumn("col1"), + ir.Multiply(ir.Literal(42.toShort), ir.Literal(7))) + } + + "transpile to ADD_MONTHS" in { + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("Month"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.AddMonths( + simplyNamedColumn("col1"), + ir.Literal(42.toShort)) + + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("qq"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.AddMonths( + simplyNamedColumn("col1"), + ir.Multiply(ir.Literal(42.toShort), ir.Literal(3))) + } + 
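+    // Sub-day dateparts (hour, minute, second, millisecond, mcs, ns) are expected to map to interval
+    // arithmetic, i.e. ir.Add of the column and an ir.KnownInterval, rather than to DATE_ADD or ADD_MONTHS.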
+ "transpile to INTERVAL" in { + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("hour"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.Add( + simplyNamedColumn("col1"), + ir.KnownInterval(ir.Literal(42.toShort), ir.HOUR_INTERVAL)) + + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("minute"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.Add( + simplyNamedColumn("col1"), + ir.KnownInterval(ir.Literal(42.toShort), ir.MINUTE_INTERVAL)) + + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("second"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.Add( + simplyNamedColumn("col1"), + ir.KnownInterval(ir.Literal(42.toShort), ir.SECOND_INTERVAL)) + + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("millisecond"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.Add( + simplyNamedColumn("col1"), + ir.KnownInterval(ir.Literal(42.toShort), ir.MILLISECOND_INTERVAL)) + + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("mcs"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.Add( + simplyNamedColumn("col1"), + ir.KnownInterval(ir.Literal(42.toShort), ir.MICROSECOND_INTERVAL)) + + ir.CallFunction( + "DATEADD", + Seq(simplyNamedColumn("ns"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) becomes ir.Add( + simplyNamedColumn("col1"), + ir.KnownInterval(ir.Literal(42.toShort), ir.NANOSECOND_INTERVAL)) + } + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TopPercentToLimitSubqueryTest.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TopPercentToLimitSubqueryTest.scala new file mode 100644 index 0000000000..e221cd369f --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TopPercentToLimitSubqueryTest.scala @@ -0,0 +1,71 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.parsers.PlanComparison +import com.databricks.labs.remorph.intermediate._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TopPercentToLimitSubqueryTest extends AnyWordSpec with PlanComparison with Matchers with IRHelpers { + "PERCENT applies" in { + val out = + (new TopPercentToLimitSubquery).apply(TopPercent(Project(namedTable("Employees"), Seq(Star())), Literal(10))) + comparePlans( + out, + WithCTE( + Seq( + SubqueryAlias(Project(namedTable("Employees"), Seq(Star())), Id("_limited1")), + SubqueryAlias( + Project( + UnresolvedRelation(ruleText = "_limited1", message = "Unresolved relation _limited1"), + Seq(Alias(Count(Seq(Star())), Id("count")))), + Id("_counted1"))), + Limit( + Project( + UnresolvedRelation( + ruleText = "_limited1", + message = "Unresolved relation _limited1", + ruleName = "rule name undetermined", + tokenName = None), + Seq(Star())), + ScalarSubquery( + Project( + UnresolvedRelation( + ruleText = "_counted1", + message = "Unresolved relation _counted1", + ruleName = "N/A", + tokenName = Some("N/A")), + Seq(Cast(Multiply(Divide(Id("count"), Literal(10)), Literal(100)), LongType))))))) + } + + "PERCENT WITH TIES applies" in { + val out = (new TopPercentToLimitSubquery).apply( + Sort( + Project(TopPercent(namedTable("Employees"), Literal(10), with_ties = true), Seq(Star())), + Seq(SortOrder(UnresolvedAttribute("a"))), + is_global = false)) + comparePlans( + out, + WithCTE( + Seq( + SubqueryAlias(Project(namedTable("Employees"), Seq(Star())), Id("_limited1")), + SubqueryAlias( + Project( + UnresolvedRelation(ruleText = 
"_limited1", message = "Unresolved _limited1"), + Seq( + Star(), + Alias( + Window(NTile(Literal(100)), sort_order = Seq(SortOrder(UnresolvedAttribute("a")))), + Id("_percentile1")))), + Id("_with_percentile1"))), + Filter( + Project( + UnresolvedRelation(ruleText = "_with_percentile1", message = "Unresolved _with_percentile1"), + Seq(Star())), + LessThanOrEqual( + UnresolvedAttribute( + unparsed_identifier = "_percentile1", + ruleText = "_percentile1", + message = "Unresolved _percentile1"), + Divide(Literal(10), Literal(100)))))) + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TrapInsertDefaultActionTest.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TrapInsertDefaultActionTest.scala new file mode 100644 index 0000000000..dde7f70a32 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/tsql/rules/TrapInsertDefaultActionTest.scala @@ -0,0 +1,24 @@ +package com.databricks.labs.remorph.parsers.tsql.rules + +import com.databricks.labs.remorph.parsers.PlanComparison +import com.databricks.labs.remorph.intermediate._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +class TrapInsertDefaultActionTest extends AnyWordSpec with PlanComparison with Matchers with IRHelpers { + "TrapInsertDefaultsAction" should { + "throw an exception when the MERGE WHEN NOT MATCHED action is 'INSERT DEFAULT VALUES'" in { + val merge = MergeIntoTable( + namedTable("table"), + namedTable("table2"), + Noop, + Seq.empty, + Seq(InsertDefaultsAction(None)), + Seq.empty) + assertThrows[IllegalArgumentException] { + TrapInsertDefaultsAction(merge) + } + } + } + +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/preprocessors/jinja/JinjaProcessorTest.scala b/core/src/test/scala/com/databricks/labs/remorph/preprocessors/jinja/JinjaProcessorTest.scala new file mode 100644 index 0000000000..58eb509740 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/preprocessors/jinja/JinjaProcessorTest.scala @@ -0,0 +1,63 @@ +package com.databricks.labs.remorph.preprocessors.jinja + +import com.databricks.labs.remorph.transpilers.TSqlToDatabricksTranspiler +import com.databricks.labs.remorph.{OkResult, PartialResult, PreProcessing, TranspilerState} +import org.scalatest.wordspec.AnyWordSpec + +// Note that this test is as much for debugging purposes as anything else, but it does create the more complex +// cases of template use. +// Integration tests are really where it's at +class JinjaProcessorTest extends AnyWordSpec { + + "Preprocessor" should { + "pre statement block" in { + + val transpiler = new TSqlToDatabricksTranspiler + + // Note that template replacement means that token lines and offsets will be out of sync with the start point + // and we will need to insert positional tokens in a subsequent PR, so that the lexer can account for the missing + // text. Another option may be to pas the replacement _!Jinja9999 with spaces and newlines to match the length + // of the text they are replacing. 
+ val input = PreProcessing("""{%- set payment_methods = dbt_utils.get_column_values( + | table=ref('raw_payments'), + | column='payment_method' + |) -%} + | + |select + | order_id, + | {%- for payment_method in payment_methods %} + | sum(case when payment_method = '{{payment_method}}' then amount end) as {{payment_method}}_amount + | {%- if not loop.last %},{% endif -%} + | {% endfor %} + | from {{ ref('raw_payments') }} + | group by 1 + |""".stripMargin) + + // Note that we cannot format the output because the Scala based formatter we have does not handle DBT/Jinja + // templates and therefore breaks the output + val output = """{%- set payment_methods = dbt_utils.get_column_values( + | table=ref('raw_payments'), + | column='payment_method' + |) -%} + | + |SELECT order_id, + | {%- for payment_method in payment_methods %} + | SUM(CASE WHEN payment_method = '{{payment_method}}' THEN amount END) AS {{payment_method}}_amount + | {%- if not loop.last %},{% endif -%} + | {% endfor %} + | FROM {{ ref('raw_payments') }} + | GROUP BY 1;""".stripMargin + + val result = transpiler.transpile(input).runAndDiscardState(TranspilerState(input)) + + val processed = result match { + case OkResult(output) => + output + case PartialResult(output, error) => + output + case _ => "" + } + assert(output == processed) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/queries/ExampleDebuggerTest.scala b/core/src/test/scala/com/databricks/labs/remorph/queries/ExampleDebuggerTest.scala new file mode 100644 index 0000000000..49dea9e8f7 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/queries/ExampleDebuggerTest.scala @@ -0,0 +1,29 @@ +package com.databricks.labs.remorph.queries + +import com.databricks.labs.remorph.{OkResult, TransformationConstructors} +import com.databricks.labs.remorph.parsers.PlanParser +import com.databricks.labs.remorph.intermediate.NoopNode +import org.antlr.v4.runtime.ParserRuleContext +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.when +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar + +class ExampleDebuggerTest extends AnyWordSpec with Matchers with MockitoSugar with TransformationConstructors { + "ExampleDebugger" should { + "work" in { + val buf = new StringBuilder + val parser = mock[PlanParser[_]] + when(parser.parse).thenReturn(lift(OkResult(ParserRuleContext.EMPTY))) + when(parser.visit(any())).thenReturn(lift(OkResult(NoopNode))) + + val debugger = new ExampleDebugger(parser, x => buf.append(x), "snowflake") + val name = s"${NestedFiles.projectRoot}/tests/resources/functional/snowflake/nested_query_with_json_1.sql" + + debugger.debugExample(name) + + buf.toString() should equal("NoopNode$\n") + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/SetOperationBehaviors.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SetOperationBehaviors.scala new file mode 100644 index 0000000000..da5a44ffc7 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SetOperationBehaviors.scala @@ -0,0 +1,45 @@ +package com.databricks.labs.remorph.transpilers + +import org.scalatest.wordspec.AnyWordSpec + +trait SetOperationBehaviors { this: TranspilerTestCommon with AnyWordSpec => + + protected[this] final def correctlyTranspile(expectedTranspilation: (String, String)): Unit = { + val (originalTSql, expectedDatabricksSql) = expectedTranspilation + s"correctly transpile: 
$originalTSql" in { + originalTSql transpilesTo expectedDatabricksSql + } + } + + protected[this] def expectedSetOperationTranslations: Map[String, String] = Map( + "SELECT a, b FROM c UNION SELECT x, y FROM z" -> "(SELECT a, b FROM c) UNION (SELECT x, y FROM z);", + "SELECT a, b FROM c UNION ALL SELECT x, y FROM z" -> "(SELECT a, b FROM c) UNION ALL (SELECT x, y FROM z);", + "SELECT a, b FROM c EXCEPT SELECT x, y FROM z" -> "(SELECT a, b FROM c) EXCEPT (SELECT x, y FROM z);", + "SELECT a, b FROM c INTERSECT SELECT x, y FROM z" -> "(SELECT a, b FROM c) INTERSECT (SELECT x, y FROM z);", + "SELECT a, b FROM c UNION (SELECT x, y FROM z)" -> "(SELECT a, b FROM c) UNION (SELECT x, y FROM z);", + "(SELECT a, b FROM c) UNION SELECT x, y FROM z" -> "(SELECT a, b FROM c) UNION (SELECT x, y FROM z);", + "(SELECT a, b FROM c) UNION ALL SELECT x, y FROM z" -> "(SELECT a, b FROM c) UNION ALL (SELECT x, y FROM z);", + "(SELECT a, b FROM c)" -> "SELECT a, b FROM c;", + """SELECT a, b FROM c + |UNION + |SELECT d, e FROM f + |UNION ALL + |SELECT g, h FROM i + |INTERSECT + |SELECT j, k FROM l + |EXCEPT + |SELECT m, n FROM o""".stripMargin -> + """(((SELECT a, b FROM c) + | UNION + | (SELECT d, e FROM f)) + | UNION ALL + | ((SELECT g, h FROM i) + | INTERSECT + | (SELECT j, k FROM l))) + |EXCEPT + |(SELECT m, n FROM o);""".stripMargin) + + def setOperationsAreTranspiled(): Unit = { + expectedSetOperationTranslations.foreach(correctlyTranspile) + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspilerTest.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspilerTest.scala new file mode 100644 index 0000000000..27fed26378 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspilerTest.scala @@ -0,0 +1,492 @@ +package com.databricks.labs.remorph.transpilers + +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeToDatabricksTranspilerTest extends AnyWordSpec with TranspilerTestCommon with SetOperationBehaviors { + + protected val transpiler = new SnowflakeToDatabricksTranspiler + + "transpile TO_NUMBER and TO_DECIMAL" should { + "transpile TO_NUMBER" in { + "select TO_NUMBER(EXPR) from test_tbl;" transpilesTo + """SELECT CAST(EXPR AS DECIMAL(38, 0)) FROM test_tbl + | ;""".stripMargin + } + + "transpile TO_NUMBER with precision and scale" in { + "select TO_NUMBER(EXPR,38,0) from test_tbl;" transpilesTo + """SELECT CAST(EXPR AS DECIMAL(38, 0)) FROM test_tbl + | ;""".stripMargin + } + + "transpile TO_DECIMAL" in { + "select TO_DECIMAL(EXPR) from test_tbl;" transpilesTo + """SELECT CAST(EXPR AS DECIMAL(38, 0)) FROM test_tbl + | ;""".stripMargin + } + } + + "snowsql commands" should { + + "transpile BANG with semicolon" in { + "!set error_flag = true;" transpilesTo + """/* The following issues were detected: + | + | Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand + | !set error_flag = true; + | */""".stripMargin + } + "transpile BANG without semicolon" in { + "!print Include This Text" transpilesTo + """/* The following issues were detected: + | + | Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand + | !print Include This Text + | */""".stripMargin + } + "transpile BANG with options" in { + "!options catch=true" transpilesTo + """/* The following issues were detected: + | + | Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand + | !options catch=true + | */""".stripMargin + } + "transpile BANG with negative scenario unknown command" in { 
+ "!test unknown command".failsTranspilation + } + "transpile BANG with negative scenario unknown command2" in { + "!abc set=abc".failsTranspilation + } + } + + "Snowflake Alter commands" should { + + "ALTER TABLE t1 ADD COLUMN c1 INTEGER" in { + "ALTER TABLE t1 ADD COLUMN c1 INTEGER;" transpilesTo ( + s"""ALTER TABLE + | t1 + |ADD + | COLUMN c1 DECIMAL(38, 0);""".stripMargin + ) + } + + "ALTER TABLE t1 ADD COLUMN c1 INTEGER, c2 VARCHAR;" in { + "ALTER TABLE t1 ADD COLUMN c1 INTEGER, c2 VARCHAR;" transpilesTo + s"""ALTER TABLE + | t1 + |ADD + | COLUMN c1 DECIMAL(38, 0), + | c2 STRING;""".stripMargin + } + + "ALTER TABLE t1 DROP COLUMN c1;" in { + "ALTER TABLE t1 DROP COLUMN c1;" transpilesTo ( + s"""ALTER TABLE + | t1 DROP COLUMN c1;""".stripMargin + ) + } + + "ALTER TABLE t1 DROP COLUMN c1, c2;" in { + "ALTER TABLE t1 DROP COLUMN c1, c2;" transpilesTo + s"""ALTER TABLE + | t1 DROP COLUMN c1, + | c2;""".stripMargin + } + + "ALTER TABLE t1 RENAME COLUMN c1 to c2;" in { + "ALTER TABLE t1 RENAME COLUMN c1 to c2;" transpilesTo + s"""ALTER TABLE + | t1 RENAME COLUMN c1 to c2;""".stripMargin + } + + "ALTER TABLE s.t1 DROP CONSTRAINT pk" in { + "ALTER TABLE s.t1 DROP CONSTRAINT pk;" transpilesTo + s"""ALTER TABLE + | s.t1 DROP CONSTRAINT pk;""".stripMargin + } + } + + "Snowflake transpiler" should { + + "transpile queries" in { + + "SELECT * FROM t1 WHERE col1 != 100;" transpilesTo ( + s"""SELECT + | * + |FROM + | t1 + |WHERE + | col1 != 100;""".stripMargin + ) + + "SELECT * FROM t1;" transpilesTo + s"""SELECT + | * + |FROM + | t1;""".stripMargin + + "SELECT t1.* FROM t1 INNER JOIN t2 ON t2.c2 = t2.c1;" transpilesTo + s"""SELECT + | t1.* + |FROM + | t1 + | INNER JOIN t2 ON t2.c2 = t2.c1;""".stripMargin + + "SELECT t1.c2 FROM t1 LEFT JOIN t2 USING (c2);" transpilesTo + s"""SELECT + | t1.c2 + |FROM + | t1 + | LEFT JOIN t2 + |USING + | (c2);""".stripMargin + + "SELECT c1::DOUBLE FROM t1;" transpilesTo + s"""SELECT + | CAST(c1 AS DOUBLE) + |FROM + | t1;""".stripMargin + + "SELECT JSON_EXTRACT_PATH_TEXT(json_data, path_col) FROM demo1;" transpilesTo + """SELECT + | GET_JSON_OBJECT(json_data, CONCAT('$.', path_col)) + |FROM + | demo1;""".stripMargin + } + + "transpile select distinct query" in { + s"""SELECT DISTINCT c1, c2 FROM T1""".stripMargin transpilesTo + s"""SELECT + | DISTINCT c1, + | c2 + |FROM + | T1;""".stripMargin + } + + "transpile window functions" in { + s"""SELECT LAST_VALUE(c1) + |IGNORE NULLS OVER (PARTITION BY t1.c2 ORDER BY t1.c3 DESC + |RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS dc4 + |FROM t1;""".stripMargin transpilesTo + s"""SELECT + | LAST(c1) IGNORE NULLS OVER ( + | PARTITION BY + | t1.c2 + | ORDER BY + | t1.c3 DESC NULLS FIRST + | RANGE + | BETWEEN UNBOUNDED PRECEDING + | AND CURRENT ROW + | ) AS dc4 + |FROM + | t1;""".stripMargin + } + + "transpile MONTHS_BETWEEN function" in { + "SELECT MONTHS_BETWEEN('2021-02-01'::DATE, '2021-01-01'::DATE);" transpilesTo + """SELECT + | MONTHS_BETWEEN(CAST('2021-02-01' AS DATE), CAST('2021-01-01' AS DATE), TRUE) + | ;""".stripMargin + + """SELECT + | MONTHS_BETWEEN('2019-03-01 02:00:00'::TIMESTAMP, '2019-02-15 01:00:00'::TIMESTAMP) + | AS mb;""".stripMargin transpilesTo + """SELECT + | MONTHS_BETWEEN(CAST('2019-03-01 02:00:00' AS TIMESTAMP), CAST('2019-02-15 01:00:00' + | AS TIMESTAMP), TRUE) AS mb + | ;""".stripMargin + } + + "transpile ARRAY_REMOVE function" in { + "SELECT ARRAY_REMOVE([1, 2, 3], 1);" transpilesTo + "SELECT ARRAY_REMOVE(ARRAY(1, 2, 3), 1);" + + "SELECT ARRAY_REMOVE([2, 3, 4.11::DOUBLE, 4, NULL], 4);" 
transpilesTo + "SELECT ARRAY_REMOVE(ARRAY(2, 3, CAST(4.11 AS DOUBLE), 4, NULL), 4);" + + // TODO - Enable this test case once the VARIANT casting is implemented. + // In Snow, if the value to remove is a VARCHAR, + // it is required to cast the value to VARIANT. + // "SELECT ARRAY_REMOVE(['a', 'b', 'c'], 'a'::VARIANT);" transpilesTo + // "SELECT ARRAY_REMOVE(ARRAY('a', 'b', 'c'), 'a');" + } + "transpile ARRAY_SORT function" in { + "SELECT ARRAY_SORT([0, 2, 4, NULL, 5, NULL], TRUE, True);" transpilesTo + """SELECT + | SORT_ARRAY(ARRAY(0, 2, 4, NULL, 5, NULL)) + | ;""".stripMargin + + "SELECT ARRAY_SORT([0, 2, 4, NULL, 5, NULL], FALSE, False);" transpilesTo + """SELECT + | SORT_ARRAY(ARRAY(0, 2, 4, NULL, 5, NULL), false) + | ;""".stripMargin + + "SELECT ARRAY_SORT([0, 2, 4, NULL, 5, NULL], TRUE, FALSE);" transpilesTo + """SELECT + | ARRAY_SORT( + | ARRAY(0, 2, 4, NULL, 5, NULL), + | (left, right) -> CASE + | WHEN left IS NULL + | AND right IS NULL THEN 0 + | WHEN left IS NULL THEN 1 + | WHEN right IS NULL THEN -1 + | WHEN left < right THEN -1 + | WHEN left > right THEN 1 + | ELSE 0 + | END + | ) + | ;""".stripMargin + + "SELECT ARRAY_SORT([0, 2, 4, NULL, 5, NULL], False, true);" transpilesTo + """SELECT + | ARRAY_SORT( + | ARRAY(0, 2, 4, NULL, 5, NULL), + | (left, right) -> CASE + | WHEN left IS NULL + | AND right IS NULL THEN 0 + | WHEN left IS NULL THEN -1 + | WHEN right IS NULL THEN 1 + | WHEN left < right THEN 1 + | WHEN left > right THEN -1 + | ELSE 0 + | END + | ) + | ;""".stripMargin + + "SELECT ARRAY_SORT([0, 2, 4, NULL, 5, NULL], TRUE, 1 = 1);".failsTranspilation + "SELECT ARRAY_SORT([0, 2, 4, NULL, 5, NULL], 1 = 1, TRUE);".failsTranspilation + } + + "GROUP BY ALL" in { + "SELECT car_model, COUNT(DISTINCT city) FROM dealer GROUP BY ALL;" transpilesTo + "SELECT car_model, COUNT(DISTINCT city) FROM dealer GROUP BY ALL;" + } + + "transpile LCA replacing aliases" in { + "SELECT column_a AS alias_a FROM table_a WHERE alias_a = '123';" transpilesTo + "SELECT column_a AS alias_a FROM table_a WHERE column_a = '123';" + } + + "transpile LCA replacing aliased literals" in { + "SELECT '123' as alias_a FROM table_a where alias_a = '123';" transpilesTo + "SELECT '123' as alias_a FROM table_a where '123' = '123';" + } + + "transpile LCA with aliased table" in { + "SELECT t.col1, t.col2, t.col3 AS ca FROM table1 t WHERE ca in ('v1', 'v2');" transpilesTo + "SELECT t.col1, t.col2, t.col3 AS ca FROM table1 as t WHERE t.col3 in ('v1', 'v2');" + } + + "transpile LCA with partition" in { + "SELECT t.col1 AS ca, ROW_NUMBER() OVER (PARTITION by ca ORDER BY ca) FROM table1 t;" transpilesTo + "SELECT t.col1 AS ca, ROW_NUMBER() OVER (PARTITION by t.col1 ORDER BY t.col1 ASC NULLS LAST) FROM table1 AS t;" + } + + "transpile LCA with function" in { + "SELECT col1 AS ca FROM table1 where SUBSTR(ca, 1, 3) = '123';" transpilesTo + "SELECT col1 AS ca FROM table1 where SUBSTR(col1, 1, 3) = '123';" + } + } + + "Snowflake transpile function with optional brackets" should { + + "SELECT CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_TIME, LOCALTIME, LOCALTIMESTAMP FROM t1" in { + s"""SELECT CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_TIME, + |LOCALTIME, LOCALTIMESTAMP FROM t1""".stripMargin transpilesTo ( + s"""SELECT + | CURRENT_DATE(), + | CURRENT_TIMESTAMP(), + | DATE_FORMAT(CURRENT_TIMESTAMP(), 'HH:mm:ss'), + | DATE_FORMAT(CURRENT_TIMESTAMP(), 'HH:mm:ss'), + | CURRENT_TIMESTAMP() + |FROM + | t1;""".stripMargin + ) + } + + "SELECT CURRENT_TIMESTAMP(1) FROM t1 where dt < CURRENT_TIMESTAMP" in { + s"""SELECT 
CURRENT_TIMESTAMP(1) FROM t1 where dt < CURRENT_TIMESTAMP""".stripMargin transpilesTo ( + s"""SELECT + | DATE_FORMAT(CURRENT_TIMESTAMP(), 'yyyy-MM-dd HH:mm:ss.SSS') + |FROM + | t1 + |WHERE + | dt < CURRENT_TIMESTAMP();""".stripMargin + ) + } + + "SELECT CURRENT_TIME(1) FROM t1 where dt < CURRENT_TIMESTAMP()" in { + s"""SELECT CURRENT_TIME(1) FROM t1 where dt < CURRENT_TIMESTAMP()""".stripMargin transpilesTo ( + s"""SELECT + | DATE_FORMAT(CURRENT_TIMESTAMP(), 'HH:mm:ss') + |FROM + | t1 + |WHERE + | dt < CURRENT_TIMESTAMP();""".stripMargin + ) + } + + "SELECT LOCALTIME() FROM t1 where dt < LOCALTIMESTAMP" in { + s"""SELECT LOCALTIME() FROM t1 where dt < LOCALTIMESTAMP()""".stripMargin transpilesTo ( + s"""SELECT + | DATE_FORMAT(CURRENT_TIMESTAMP(), 'HH:mm:ss') + |FROM + | t1 + |WHERE + | dt < CURRENT_TIMESTAMP();""".stripMargin + ) + } + } + + "Snowflake Execute commands" should { + + "EXECUTE TASK task1;" in { + "EXECUTE TASK task1;" transpilesTo + """/* The following issues were detected: + | + | Execute Task is not yet supported + | EXECUTE TASK task1 + | */""".stripMargin + } + } + + "Snowflake MERGE commands" should { + + "MERGE;" in { + """MERGE INTO target_table AS t + |USING source_table AS s + |ON t.id = s.id + |WHEN MATCHED AND s.status = 'active' THEN + | UPDATE SET t.value = s.value AND status = 'active' + |WHEN MATCHED AND s.status = 'inactive' THEN + | DELETE + |WHEN NOT MATCHED THEN + | INSERT (id, value, status) VALUES (s.id, s.value, s.status);""".stripMargin transpilesTo + s"""MERGE INTO target_table AS t + |USING source_table AS s + |ON t.id = s.id + |WHEN MATCHED AND s.status = 'active' THEN + | UPDATE SET t.value = s.value AND status = 'active' + |WHEN MATCHED AND s.status = 'inactive' THEN + | DELETE + |WHEN NOT MATCHED THEN + | INSERT (id, value, status) VALUES (s.id, s.value, s.status);""".stripMargin + } + } + + override protected[this] final val expectedSetOperationTranslations: Map[String, String] = { + super.expectedSetOperationTranslations ++ Map( + "SELECT a, b FROM c MINUS SELECT x, y FROM z" -> "(SELECT a, b FROM c) EXCEPT (SELECT x, y FROM z);", + """SELECT a, b FROM c + |UNION + |SELECT d, e FROM f + |MINUS + |SELECT g, h FROM i + |INTERSECT + |SELECT j, k FROM l + |EXCEPT + |SELECT m, n FROM o""".stripMargin -> + """(((SELECT a, b FROM c) + | UNION + | (SELECT d, e FROM f)) + | EXCEPT + | ((SELECT g, h FROM i) + | INTERSECT + | (SELECT j, k FROM l))) + |EXCEPT + |(SELECT m, n FROM o);""".stripMargin) + } + + "Set operations" should { + behave like setOperationsAreTranspiled() + } + + "Common Table Expressions (CTEs)" should { + "support expressions" in { + """WITH + | a AS (1), + | b AS (2), + | t (d, e) AS (SELECT 4, 5), + | c AS (3) + |SELECT + | a + b, + | a * c, + | a * t.d + |FROM t;""".stripMargin transpilesTo + """WITH + | t (d, e) AS (SELECT 4, 5) + |SELECT + | 1 + 2, + | 1 * 3, + | 1 * t.d + |FROM + | t;""".stripMargin + } + "have lower precedence than set operations" in { + """WITH + | a AS (SELECT b, c, d FROM e), + | f AS (SELECT g, h, i FROM j) + |SELECT * FROM a + |UNION + |SELECT * FROM f;""".stripMargin transpilesTo + """WITH + | a AS (SELECT b, c, d FROM e), + | f AS (SELECT g, h, i FROM j) + |(SELECT * FROM a) + |UNION + |(SELECT * FROM f);""".stripMargin + } + "allow nested set operations" in { + """WITH + | a AS ( + | SELECT b, c, d from e + | UNION + | SELECT e, f, g from h), + | i AS (SELECT j, k, l from m) + |SELECT * FROM a + |UNION + |SELECT * FROM i;""".stripMargin transpilesTo + """WITH + | a AS ( + | (SELECT b, c, d from e) + | 
UNION + | (SELECT e, f, g from h) + | ), + | i AS (SELECT j, k, l from m) + |(SELECT * FROM a) + |UNION + |(SELECT * FROM i);""".stripMargin + } + } + + "Batch statements" should { + "survive invalid SQL" in { + """ + |CREATE TABLE t1 (x VARCHAR); + |SELECT x y z; + |SELECT 3 FROM t3; + |""".stripMargin transpilesTo (""" + |CREATE TABLE t1 (x STRING); + |/* The following issues were detected: + | + | Unparsed input - ErrorNode encountered + | Unparsable text: SELECTxyz + | */ + |/* The following issues were detected: + | + | Unparsed input - ErrorNode encountered + | Unparsable text: SELECT + | Unparsable text: x + | Unparsable text: y + | Unparsable text: z + | Unparsable text: parser recovered by ignoring: SELECTxyz; + | */ + | SELECT + | 3 + |FROM + | t3;""".stripMargin, false) + + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspilerTest.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspilerTest.scala new file mode 100644 index 0000000000..a219c1cee7 --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspilerTest.scala @@ -0,0 +1,37 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult} +import com.databricks.labs.remorph.generators.py.RuffFormatter +import org.scalatest.wordspec.AnyWordSpec + +class SnowflakeToPySparkTranspilerTest extends AnyWordSpec with TranspilerTestCommon { + protected val transpiler = new SnowflakeToPySparkTranspiler + private[this] val formatter = new RuffFormatter + override def format(input: String): String = formatter.format(input) match { + case OkResult(formatted) => formatted + case KoResult(_, error) => fail(error.msg) + case PartialResult(output, error) => fail(s"Partial result: $output, error: $error") + } + + "Snowflake SQL" should { + "transpile window functions" in { + s"""SELECT LAST_VALUE(c1) + |IGNORE NULLS OVER (PARTITION BY t1.c2 ORDER BY t1.c3 DESC + |RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS dc4 + |FROM t1;""".stripMargin transpilesTo + """import pyspark.sql.functions as F + |from pyspark.sql.window import Window + | + |spark.table("t1").select( + | F.last(F.col("c1")) + | .over( + | Window.partitionBy(F.col("t1.c2")) + | .orderBy(F.col("t1.c3").desc_nulls_first()) + | .rangeBetween(Window.unboundedPreceding, Window.currentRow) + | ) + | .alias("dc4") + |) + |""".stripMargin + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/TranspilerTestCommon.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/TranspilerTestCommon.scala new file mode 100644 index 0000000000..786e597dae --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/TranspilerTestCommon.scala @@ -0,0 +1,40 @@ +package com.databricks.labs.remorph.transpilers + +import com.databricks.labs.remorph._ +import com.databricks.labs.remorph.coverage.ErrorEncoders +import io.circe.syntax._ +import org.scalatest.Assertion +import org.scalatest.matchers.should.Matchers + +trait TranspilerTestCommon extends Matchers with Formatter with ErrorEncoders { + + protected def transpiler: Transpiler + + implicit class TranspilerTestOps(input: String) { + def transpilesTo(expectedOutput: String, failOnError: Boolean = true): Assertion = { + val formattedExpectedOutput = format(expectedOutput) + transpiler.transpile(PreProcessing(input)).runAndDiscardState(TranspilerState()) match { + case OkResult(output) 
=> + format(output) + val formattedOutput = format(output) + formattedOutput shouldBe formattedExpectedOutput + case PartialResult(output, err) => + if (failOnError) { + fail(err.asJson.noSpaces) + } else { + val formattedOutput = format(output) + formattedOutput shouldBe formattedExpectedOutput + } + case KoResult(_, err) => fail(err.asJson.noSpaces) + } + } + def failsTranspilation: Assertion = { + transpiler.transpile(PreProcessing(input)).runAndDiscardState(TranspilerState()) match { + case KoResult(_, _) => succeed + case PartialResult(_, _) => succeed + case x => + fail(s"query was expected to fail transpilation but didn't: $x") + } + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/TsqlToDatabricksTranspilerTest.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/TsqlToDatabricksTranspilerTest.scala new file mode 100644 index 0000000000..795dd3e9bf --- /dev/null +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/TsqlToDatabricksTranspilerTest.scala @@ -0,0 +1,23 @@ +package com.databricks.labs.remorph.transpilers + +import org.scalatest.wordspec.AnyWordSpec + +class TsqlToDatabricksTranspilerTest extends AnyWordSpec with TranspilerTestCommon with SetOperationBehaviors { + + protected final val transpiler = new TSqlToDatabricksTranspiler + + "The TSQL-to-Databricks transpiler" when { + "transpiling set operations" should { + behave like setOperationsAreTranspiled() + } + "mixing CTEs with set operations" should { + correctlyTranspile( + """WITH cte1 AS (SELECT a, b FROM c), + | cte2 AS (SELECT x, y FROM z) + |SELECT a, b FROM cte1 UNION SELECT x, y FROM cte2""".stripMargin -> + """WITH cte1 AS (SELECT a, b FROM c), + | cte2 AS (SELECT x, y FROM z) + |(SELECT a, b FROM cte1) UNION (SELECT x, y FROM cte2);""".stripMargin) + } + } +} diff --git a/core/src/test/scala/toolchain/ToolchainSpec.scala b/core/src/test/scala/toolchain/ToolchainSpec.scala new file mode 100644 index 0000000000..7135f14792 --- /dev/null +++ b/core/src/test/scala/toolchain/ToolchainSpec.scala @@ -0,0 +1,41 @@ +package toolchain.testsource + +import com.databricks.labs.remorph.transpilers.DirectorySource +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import java.nio.file.Paths + +class ToolchainSpec extends AnyWordSpec with Matchers { + private[this] val projectBaseDir = Paths.get(".").toAbsolutePath.normalize + private[this] val moduleBaseDir = + if (projectBaseDir.endsWith("core")) projectBaseDir else projectBaseDir.resolve("core") + private[this] val testSourceDir = moduleBaseDir.resolve("src/test/resources/toolchain/testsource").toString + + "Toolchain" should { + "traverse a directory of files" in { + val directorySourceWithoutFilter = new DirectorySource(testSourceDir) + val fileNames = directorySourceWithoutFilter.map(_.filename).toSeq + val expectedFileNames = Seq("test_1.sql", "test_2.sql", "test_3.sql", "not_sql.md") + fileNames should contain theSameElementsAs expectedFileNames + } + + "traverse and filter a directory of files" in { + val directorySource = + new DirectorySource(testSourceDir, Some(_.getFileName.toString.endsWith(".sql"))) + val fileNames = directorySource.map(_.filename).toSeq + val expectedFileNames = Seq("test_1.sql", "test_2.sql", "test_3.sql") + fileNames should contain theSameElementsAs expectedFileNames + } + + "retrieve source code successfully from files" in { + val directorySource = + new DirectorySource(testSourceDir, Some(_.getFileName.toString.endsWith(".sql"))) + val 
sources = directorySource.map(_.source).toSeq + sources should not be empty + sources.foreach { source => + source should not be empty + } + } + } +} diff --git a/docs/img/aggregates-reconcile-help.png b/docs/img/aggregates-reconcile-help.png new file mode 100644 index 0000000000..08ce52d3b9 Binary files /dev/null and b/docs/img/aggregates-reconcile-help.png differ diff --git a/docs/img/aggregates-reconcile-run.gif b/docs/img/aggregates-reconcile-run.gif new file mode 100644 index 0000000000..657a235584 Binary files /dev/null and b/docs/img/aggregates-reconcile-run.gif differ diff --git a/docs/img/check-python-version.gif b/docs/img/check-python-version.gif new file mode 100644 index 0000000000..6e43976d67 Binary files /dev/null and b/docs/img/check-python-version.gif differ diff --git a/docs/img/macos-databricks-cli-install.gif b/docs/img/macos-databricks-cli-install.gif new file mode 100644 index 0000000000..f71cf25ad6 Binary files /dev/null and b/docs/img/macos-databricks-cli-install.gif differ diff --git a/docs/img/recon-install.gif b/docs/img/recon-install.gif new file mode 100644 index 0000000000..5dfe14fa82 Binary files /dev/null and b/docs/img/recon-install.gif differ diff --git a/docs/img/recon-run.gif b/docs/img/recon-run.gif new file mode 100644 index 0000000000..df1b48e9a4 Binary files /dev/null and b/docs/img/recon-run.gif differ diff --git a/docs/img/reconcile-help.png b/docs/img/reconcile-help.png new file mode 100644 index 0000000000..fd93de6bed Binary files /dev/null and b/docs/img/reconcile-help.png differ diff --git a/docs/img/remorph-logo.svg b/docs/img/remorph-logo.svg new file mode 100644 index 0000000000..097a83586b --- /dev/null +++ b/docs/img/remorph-logo.svg @@ -0,0 +1 @@ + diff --git a/docs/img/remorph_intellij.gif b/docs/img/remorph_intellij.gif new file mode 100644 index 0000000000..fefaea2b31 Binary files /dev/null and b/docs/img/remorph_intellij.gif differ diff --git a/docs/img/transpile-help.png b/docs/img/transpile-help.png new file mode 100644 index 0000000000..a6847da837 Binary files /dev/null and b/docs/img/transpile-help.png differ diff --git a/docs/img/transpile-install.gif b/docs/img/transpile-install.gif new file mode 100644 index 0000000000..475a1445da Binary files /dev/null and b/docs/img/transpile-install.gif differ diff --git a/docs/img/transpile-run.gif b/docs/img/transpile-run.gif new file mode 100644 index 0000000000..e77f8f32c9 Binary files /dev/null and b/docs/img/transpile-run.gif differ diff --git a/docs/img/windows-databricks-cli-install.gif b/docs/img/windows-databricks-cli-install.gif new file mode 100644 index 0000000000..698c2aa502 Binary files /dev/null and b/docs/img/windows-databricks-cli-install.gif differ diff --git a/docs/recon_configurations/README.md b/docs/recon_configurations/README.md new file mode 100644 index 0000000000..0144774fad --- /dev/null +++ b/docs/recon_configurations/README.md @@ -0,0 +1,769 @@ +# Remorph Reconciliation + +Reconcile is an automated tool designed to streamline the reconciliation process between source data and target data +residing on Databricks. Currently, the platform exclusively offers support for Snowflake, Oracle and other Databricks +tables as the primary data source. This tool empowers users to efficiently identify discrepancies and variations in data +when comparing the source with the Databricks target. 
+ +* [Types of Report Supported](#types-of-report-supported) +* [Report Type-Flow Chart](#report-type-flow-chart) +* [Supported Source System](#supported-source-system) +* [TABLE Config JSON filename](#table-config-json-filename) +* [TABLE Config Elements](#table-config-elements) + * [aggregates](#aggregate) + * [jdbc_reader_options](#jdbc_reader_options) + * [column_mapping](#column_mapping) + * [transformations](#transformations) + * [column_thresholds](#column_thresholds) + * [table_thresholds](#table_thresholds) + * [filters](#filters) + * [Key Considerations](#key-considerations) +* [Key Considerations for Oracle JDBC Reader Options](#key-considerations-for-oracle-jdbc-reader-options) +* [Reconciliation Example](#reconciliation-example) +* [DataFlow Example](#dataflow-example) +* [Aggregates Reconcile](#remorph-aggregates-reconciliation) + * [Supported Aggregate Functions](#supported-aggregate-functions) + * [Flow Chart](#flow-chart) + * [Aggregate](#aggregate) + * [TABLE Config Examples](#table-config-examples) + * [Key Considerations](#key-considerations) + * [Aggregates Reconciliation Example](#aggregates-reconciliation-json-example) + * [DataFlow Example](#dataflow-example) + +## Types of Report Supported + +| report type | sample visualisation | description | key outputs captured in the recon metrics tables | +|-------------|------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| **schema** | [schema](report_types_visualisation.md#schema) | reconcile the schema of source and target.
- validate that the datatype is the same or compatible | - **schema_comparison**
- **schema_difference** | +| **row** | [row](report_types_visualisation.md#row) | reconcile the data only at row level (the hash value of each source row is matched against the hash value of the corresponding target row). Preferred when no join columns can be identified between source and target. | - **missing_in_src**(sample rows that are available in target but missing in source + sample rows in the target that don't match the source)
- **missing_in_tgt**(sample rows that are available in source but are missing in target + sample rows in the source that don't match the target)
**NOTE**: this report type does not differentiate between mismatched and missing rows. | +| **data** | [data](report_types_visualisation.md#data) | reconcile the data at row and column level; ```join_columns``` are used to identify mismatches at each row and column level | - **mismatch_data**(sample data with mismatches captured at each column and row level)
- **missing_in_src**(sample rows that are available in target but missing in source)
- **missing_in_tgt**(sample rows that are available in source but are missing in target)
- **threshold_mismatch**(the configured columns are reconciled based on a percentile, absolute threshold, or date boundary)
- **mismatch_columns**(consolidated list of columns that have mismatches in them)
| +| **all** | [all](report_types_visualisation.md#all) | this is a combination of data + schema | - **data + schema outputs** | + +[[↑ back to top](#remorph-reconciliation)] + +## Report Type-Flow Chart + +```mermaid +flowchart TD + REPORT_TYPE --> DATA + REPORT_TYPE --> SCHEMA + REPORT_TYPE --> ROW + REPORT_TYPE --> ALL +``` + +```mermaid +flowchart TD + SCHEMA --> SCHEMA_VALIDATION +``` + +```mermaid +flowchart TD + ROW --> MISSING_IN_SRC + ROW --> MISSING_IN_TGT +``` + +```mermaid +flowchart TD + DATA --> MISMATCH_ROWS + DATA --> MISSING_IN_SRC + DATA --> MISSING_IN_TGT +``` + +```mermaid +flowchart TD + ALL --> MISMATCH_ROWS + ALL --> MISSING_IN_SRC + ALL --> MISSING_IN_TGT + ALL --> SCHEMA_VALIDATION +``` + +[[↑ back to top](#remorph-reconciliation)] + +## Supported Source System + +| Source | Schema | Row | Data | All | +|------------|--------|-----|------|-----| +| Oracle | Yes | Yes | Yes | Yes | +| Snowflake | Yes | Yes | Yes | Yes | +| Databricks | Yes | Yes | Yes | Yes | + +[[↑ back to top](#remorph-reconciliation)] + +### TABLE Config Json filename: +The config file must be named as `recon_config___.json` and should be placed in the remorph root directory `.remorph` within the Databricks Workspace. + +> The filename pattern would remain the same for all the data_sources. + +Please find the `Table Recon` filename examples below for the `Snowflake`, `Oracle`, and `Databricks` source systems. + + + + + + + + + + + + + + + + + + + + + + +
Data SourceReconcile ConfigTable Recon filename
Snowflake +
+ database_config:
+  source_catalog: sample_data
+  source_schema: default
+  ...
+metadata_config:
+  ...
+data_source: snowflake
+report_type: all
+...
+             
+
recon_config_snowflake_sample_data_all.json
Oracle +
+ database_config:
+  source_schema: orc
+  ...
+metadata_config:
+  ...
+data_source: oracle
+report_type: data
+...
+             
+
recon_config_oracle_orc_data.json
Databricks (Hive MetaStore) +
+ database_config:
+  source_schema: hms
+  ...
+metadata_config:
+  ...
+data_source: databricks
+report_type: schema
+...
+             
+
recon_config_databricks_hms_schema.json
+ +> **Note:** the filename must be created in the same case as is defined. +> For example, if the source schema is defined as `ORC` in the `reconcile` config, the filename should be `recon_config_oracle_ORC_data.json`. + + +[[↑ back to top](#remorph-reconciliation)] + +### TABLE Config Elements: + + + + + + + + + + +
PythonJSON
+
+@dataclass
+class Table:
+    source_name: str
+    target_name: str
+    aggregates: list[Aggregate] | None = None
+    join_columns: list[str] | None = None
+    jdbc_reader_options: JdbcReaderOptions | None = None
+    select_columns: list[str] | None = None
+    drop_columns: list[str] | None = None
+    column_mapping: list[ColumnMapping] | None = None
+    transformations: list[Transformation] | None = None
+    column_thresholds: list[ColumnThresholds] | None = None
+    filters: Filters | None = None
+    table_thresholds: list[TableThresholds] | None = None
+
+
+
+{
+  "source_name": "<SOURCE_NAME>",
+  "target_name": "<TARGET_NAME>",
+  "aggregates": null,
+  "join_columns": ["<COLUMN_NAME_1>","<COLUMN_NAME_2>"],
+  "jdbc_reader_options": null,
+  "select_columns": null,
+  "drop_columns": null,
+  "column_mapping": null,
+  "transformation": null,
+  "column_thresholds": null,
+  "filters": null,
+  "table_thresholds": null
+}
+
+
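To see how these elements fit together, here is a small illustrative sketch (not taken from the Remorph codebase) that assembles one table entry as a plain Python dict following the JSON layout above, and writes a complete config file using the `recon_config_<data_source>_<source_catalog_or_schema>_<report_type>.json` naming convention described earlier. All catalog, schema, table, and column names are placeholders borrowed from the sample config later in these docs.

```python
import json

# Illustrative only: one table entry shaped like the JSON spec above.
table_entry = {
    "source_name": "product_prod",
    "target_name": "product",
    "aggregates": None,
    "join_columns": ["p_id"],
    "jdbc_reader_options": None,
    "select_columns": None,
    "drop_columns": ["comment"],
    "column_mapping": [{"source_name": "p_id", "target_name": "product_id"}],
    "transformations": None,
    "column_thresholds": None,
    "table_thresholds": None,
    "filters": None,
}

recon_config = {
    "source_catalog": "sample_data",
    "source_schema": "default",
    "target_catalog": "target_catalog",
    "target_schema": "target_schema",
    "tables": [table_entry],
}

# Filename convention: recon_config_<data_source>_<source_catalog_or_schema>_<report_type>.json
data_source, source_catalog, report_type = "snowflake", "sample_data", "all"
file_name = f"recon_config_{data_source}_{source_catalog}_{report_type}.json"

# The finished file then goes in the `.remorph` directory in the Databricks workspace, as described above.
with open(file_name, "w", encoding="utf-8") as fh:
    json.dump(recon_config, fh, indent=2)
```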
+ + + +| config_name | data_type | description | required/optional | example_value | +|---------------------|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------| +| source_name | string | name of the source table | required | product | +| target_name | string | name of the target table | required | product | +| aggregates | list[Aggregate] | list of aggregates, refer [Aggregate](#aggregate) for more information | optional(default=None) | "aggregates": [{"type": "MAX", "agg_columns": [""]}], | +| join_columns | list[string] | list of column names which act as the primary key to the table | optional(default=None) | ["product_id"] or ["product_id", "order_id"] | +| jdbc_reader_options | string | jdbc_reader_option, which helps to parallelise the data read from jdbc sources based on the given configuration.For more info [jdbc_reader_options](#jdbc_reader_options) | optional(default=None) | "jdbc_reader_options": {"number_partitions": 10,"partition_column": "s_suppkey","upper_bound": "10000000","lower_bound": "10","fetch_size":"100"} | +| select_columns | list[string] | list of columns to be considered for the reconciliation process | optional(default=None) | ["id", "name", "address"] | +| drop_columns | list[string] | list of columns to be eliminated from the reconciliation process | optional(default=None) | ["comment"] | +| column_mapping | list[ColumnMapping] | list of column_mapping that helps in resolving column name mismatch between src and tgt, e.g., "id" in src and "emp_id" in tgt.For more info [column_mapping](#column_mapping) | optional(default=None) | "column_mapping": [{"source_name": "id","target_name": "emp_id"}] | +| transformations | list[Transformations] | list of user-defined transformations that can be applied to src and tgt columns in case of any incompatibility data types or explicit transformation is applied during migration.For more info [transformations](#transformations) | optional(default=None) | "transformations": [{"column_name": "s_address","source": "trim(s_address)","target": "trim(s_address)"}] | +| column_thresholds | list[ColumnThresholds] | list of threshold conditions that can be applied on the columns to match the minor exceptions in data. It supports percentile, absolute, and date fields. For more info [column_thresholds](#column_thresholds) | optional(default=None) | "thresholds": [{"column_name": "sal", "lower_bound": "-5%", "upper_bound": "5%", "type": "int"}] | +| table_thresholds | list[TableThresholds] | list of table thresholds conditions that can be applied on the tables to match the minor exceptions in mismatch count. It supports percentile, absolute. For more info [table_thresholds](#table_thresholds) | optional(default=None) | "table_thresholds": [{"lower_bound": "0%", "upper_bound": "5%", "model": "mismatch"}] | +| filters | Filters | filter expr that can be used to filter the data on src and tgt based on respective expressions | optional(default=None) | "filters": {"source": "lower(dept_name)>’ it’”, "target": "lower(department_name)>’ it’”} | + + +### jdbc_reader_options + + + + + + + + + + +
PythonJSON
+
+@dataclass
+class JdbcReaderOptions:
+    number_partitions: int
+    partition_column: str
+    lower_bound: str
+    upper_bound: str
+    fetch_size: int = 100
+
+
+
+"jdbc_reader_options":{
+  "number_partitions": "<NUMBER_PARTITIONS>",
+  "partition_column": "<PARTITION_COLUMN>",
+  "lower_bound": "<LOWER_BOUND>",
+  "upper_bound": "<UPPER_BOUND>",
+  "fetch_size": "<FETCH_SIZE>"
+}
+
+
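For intuition, these settings correspond to the standard Spark JDBC partitioned-read options: Spark issues `number_partitions` parallel range queries over `partition_column` between `lower_bound` and `upper_bound`, fetching `fetch_size` rows per round trip. The sketch below only illustrates that mapping with plain PySpark; it is not Remorph's reader code, and the JDBC URL, table, and credentials are placeholders.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Illustrative only: the same knobs expressed as plain Spark JDBC options.
# The URL, table, and credentials are placeholders.
df = (
    spark.read.format("jdbc")
    .option("url", "jdbc:oracle:thin:@//<host>:1521/<service>")
    .option("dbtable", "supplier")
    .option("user", "<user>")
    .option("password", "<password>")
    .option("numPartitions", 10)             # number_partitions
    .option("partitionColumn", "s_suppkey")  # partition_column
    .option("lowerBound", "10")              # lower_bound
    .option("upperBound", "10000000")        # upper_bound
    .option("fetchsize", 100)                # fetch_size
    .load()
)
```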
+ +| field_name | data_type | description | required/optional | example_value | +|-------------------|-----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|---------------| +| number_partitions | string | the number of partitions for reading input data in parallel | required | "200" | +| partition_column | string | Int/date/timestamp parameter defining the column used for partitioning, typically the primary key of the source table. Note that this parameter accepts only one column, which is especially crucial when dealing with a composite primary key. In such cases, provide the column with higher cardinality. | required | "employee_id | +| upper_bound | string | integer or date or timestamp without time zone value as string), that should be set appropriately (usually the maximum value in case of non-skew data) so the data read from the source should be approximately equally distributed | required | "1" | +| lower_bound | string | integer or date or timestamp without time zone value as string), that should be set appropriately (usually the minimum value in case of non-skew data) so the data read from the source should be approximately equally distributed | required | "100000" | +| fetch_size | string | This parameter influences the number of rows fetched per round-trip between Spark and the JDBC database, optimising data retrieval performance. Adjusting this option significantly impacts the efficiency of data extraction, controlling the volume of data retrieved in each fetch operation. More details on configuring fetch size can be found [here](https://docs.databricks.com/en/connect/external-systems/jdbc.html#control-number-of-rows-fetched-per-query) | optional(default="100") | "10000" | + +#### Key Considerations for Oracle JDBC Reader Options: +For Oracle source, the following options are automatically set: + +- "oracle.jdbc.mapDateToTimestamp": "False", +- "sessionInitStatement": "BEGIN dbms_session.set_nls('nls_date_format', '''YYYY-MM-DD''');dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD HH24:MI:SS''');END;" + +While configuring Recon for Oracle source, the above options should be taken into consideration. + +### column_mapping + + + + + + + + + + +
PythonJSON
+
+
+@dataclass
+class ColumnMapping:
+    source_name: str
+    target_name: str
+
+
+
+"column_mapping":[
+  {
+    "source_name": "<SOURCE_COLUMN_NAME>",
+    "target_name": "<TARGET_COLUMN_NAME>"
+  }
+]
+
+
+ +| field_name | data_type | description | required/optional | example_value | +|-------------|-----------|--------------------|-------------------|-----------------| +| source_name | string | source column name | required | "dept_id" | +| target_name | string | target column name | required | "department_id" | + +### transformations + + + + + + + + + + +
PythonJSON
+
+
+@dataclass
+class Transformation:
+    column_name: str
+    source: str
+    target: str | None = None
+
+
+
+
+"transformations":[
+    {
+      "column_name": "<COLUMN_NAME>",
+      "source": "<TRANSFORMATION_EXPRESSION>",
+      "target": "<TRANSFORMATION_EXPRESSION>"
+    }
+]
+
+
+ + +| field_name | data_type | description | required/optional | example_value | +|-------------|-----------|------------------------------------------------------------|-------------------|----------------------------------| +| column_name | string | the column name on which the transformation to be applied | required | "s_address" | +| source | string | the transformation sql expr to be applied on source column | required | "trim(s_address)" or "s_address" | +| target | string | the transformation sql expr to be applied on source column | required | "trim(s_address)" or "s_address" | + + +> **Note:** Reconciliation also takes an udf in the transformation expr.Say for eg. we have a udf named sort_array_input() that takes an unsorted array as input and returns an array sorted.We can use that in transformation as below: + +``` +transformations=[Transformation(column_name)="array_col",source=sort_array_input(array_col),target=sort_array_input(array_col)] +``` +> **Note:** `NULL` values are defaulted to `_null_recon_` using the transformation expressions in these files: 1. [expression_generator.py](https://github.com/databrickslabs/remorph/tree/main/src/databricks/labs/remorph/reconcile/query_builder/expression_generator.py) 2. [sampling_query.py](https://github.com/databrickslabs/remorph/tree/main/src/databricks/labs/remorph/reconcile/query_builder/sampling_query.py). If User is looking for any specific behaviour, they can override these rules using [transformations](#transformations) accordingly. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
**Transformation Expressions**

| filename | function / variable | transformation_rule | description |
|----------|---------------------|---------------------|-------------|
| sampling_query.py | _get_join_clause | transform(coalesce, default="_null_recon_", is_string=True) | Applies the coalesce transformation function to String columns and defaults to `_null_recon_` if the column is NULL |
| expression_generator.py | DataType_transform_mapping | (coalesce, default='_null_recon_', is_string=True) | Default String column transformation rule for all dialects. Applies the coalesce transformation function and defaults to `_null_recon_` if the column is NULL |
| expression_generator.py | DataType_transform_mapping | "oracle": DataType...NCHAR: ..."NVL(TRIM(TO_CHAR..,'_null_recon_')" | Transformation rule for the Oracle dialect NCHAR datatype. Applies the TO_CHAR and TRIM transformation functions; if the column is NULL, defaults to `_null_recon_` |
| expression_generator.py | DataType_transform_mapping | "oracle": DataType...NVARCHAR: ..."NVL(TRIM(TO_CHAR..,'_null_recon_')" | Transformation rule for the Oracle dialect NVARCHAR datatype. Applies the TO_CHAR and TRIM transformation functions; if the column is NULL, defaults to `_null_recon_` |
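If the `_null_recon_` default is not the behaviour you want for a particular column, it can be overridden with an explicit entry in [transformations](#transformations). A purely illustrative example with a placeholder column name:

```python
# Illustrative only: overriding the default NULL handling for one column so that
# NULL values compare as 'N/A' on both sides instead of '_null_recon_'.
null_override = {
    "column_name": "nickname",
    "source": "coalesce(nickname, 'N/A')",  # source-dialect SQL expression
    "target": "coalesce(nickname, 'N/A')",  # Databricks SQL expression
}
```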
+ + +## column_thresholds + + + + + + + + + + +
PythonJSON
+
+
+@dataclass
+class ColumnThresholds:
+    column_name: str
+    lower_bound: str
+    upper_bound: str
+    type: str
+
+
+
+
+"column_thresholds":[
+  {
+    "column_name": "<COLUMN_NAME>",
+    "lower_bound": "<LOWER_BOUND>",
+    "upper_bound": "<UPPER_BOUND>",
+    "type": "<DATA_TYPE>"
+  }
+]
+
+
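To make the bound semantics concrete, the sketch below checks whether the drift between a source and a target value falls inside a percentage window such as `-5%` / `5%`. It is a simplified illustration of the idea only (it assumes the percentage is taken relative to the source value) and is not the reconciler's actual threshold logic.

```python
def within_percentage_threshold(source: float, target: float,
                                lower_bound_pct: float, upper_bound_pct: float) -> bool:
    """Illustrative check: is (target - source) within [lower%, upper%] of the source value?"""
    if source == 0:
        return target == 0
    pct_diff = (target - source) / abs(source) * 100.0
    return lower_bound_pct <= pct_diff <= upper_bound_pct

# A 3% drift on a column passes a -5% / 5% threshold; a 10% drift does not.
print(within_percentage_threshold(100.0, 103.0, -5.0, 5.0))  # True
print(within_percentage_threshold(100.0, 110.0, -5.0, 5.0))  # False
```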
+ + +| field_name | data_type | description | required/optional | example_value | +|-------------|-----------|-------------------------------------------------------------------------------------------------------------|-------------------|--------------------| +| column_name | string | the column that should be considered for column threshold reconciliation | required | "product_discount" | +| lower_bound | string | the lower bound of the difference between the source value and the target value | required | -5% | +| upper_bound | string | the upper bound of the difference between the source value and the target value | required | 5% | +| type | string | The user must specify the column type. Supports SQLGLOT DataType.NUMERIC_TYPES and DataType.TEMPORAL_TYPES. | required | int | + +### table_thresholds + + + + + + + + + + +
PythonJSON
+
+
+@dataclass
+class TableThresholds:
+    lower_bound: str
+    upper_bound: str
+    model: str
+
+
+
+
+"table_thresholds":[
+  {
+    "lower_bound": "<LOWER_BOUND>",
+    "upper_bound": "<UPPER_BOUND>",
+    "model": "<MODEL>"
+  }
+]
+
+
+ +* The threshold bounds for the table must be non-negative, with the lower bound not exceeding the upper bound. + +| field_name | data_type | description | required/optional | example_value | +|-------------|-----------|------------------------------------------------------------------------------------------------------|-------------------|---------------| +| lower_bound | string | the lower bound of the difference between the source mismatch and the target mismatch count | required | 0% | +| upper_bound | string | the upper bound of the difference between the source mismatch and the target mismatch count | required | 5% | +| model | string | The user must specify on which table model it should be applied; for now, we support only "mismatch" | required | int | + + +### filters + + + + + + + + + + +
PythonJSON
+
+@dataclass
+class Filters:
+    source: str | None = None
+    target: str | None = None
+
+
+
+"filters":{
+  "source": "<FILTER_EXPRESSION>",
+  "target": "<FILTER_EXPRESSION>"
+}
+
+
+ + +| field_name | data_type | description | required/optional | example_value | +|------------|-----------|---------------------------------------------------|------------------------|------------------------------| +| source | string | the sql expression to filter the data from source | optional(default=None) | "lower(dept_name)='finance'" | +| target | string | the sql expression to filter the data from target | optional(default=None) | "lower(dept_name)='finance'" | + +### Key Considerations: + +1. The column names are always converted to lowercase and considered for reconciliation. +2. Currently, it doesn't support case insensitivity and doesn't have collation support +3. Table Transformation internally considers the default value as the column value. It doesn't apply any default + transformations + if not provided. + ```eg:Transformation(column_name="address",source_name=None,target_name="trim(s_address)")``` + For the given example, + the source transformation is None, so the raw value in the source is considered for reconciliation. +4. If no user transformation is provided for a given column in the configuration by default, depending on the source + data + type, our reconciler will apply + default transformation on both source and target to get the matching hash value in source and target. Please find the + detailed default transformations here. +5. Always the column reference to be source column names in all the configs, except **Transformations** and **Filters** + as these are dialect-specific SQL expressions that are applied directly in the SQL. +6. **Transformations** and **Filters** are always should be in their respective dialect SQL expressions, and the + reconciler will not apply any logic + on top of this. + +[[↑ back to top](#remorph-reconciliation)] + +# Guidance for Oracle as a source + +## Driver + +### Option 1 + +* **Download `ojdbc8.jar` from Oracle:** + Visit the [official Oracle website](https://www.oracle.com/database/technologies/appdev/jdbc-downloads.html) to + acquire the `ojdbc8.jar` JAR file. This file is crucial for establishing connectivity between Databricks and Oracle + databases. + +* **Install the JAR file on Databricks:** + Upon completing the download, install the JAR file onto your Databricks cluster. Refer + to [this page](https://docs.databricks.com/en/libraries/cluster-libraries.html) + For comprehensive instructions on uploading a JAR file, Python egg, or Python wheel to your Databricks workspace. + +### Option 2 + +* **Install ojdbc8 library from Maven:** + Follow [this guide](https://docs.databricks.com/en/libraries/package-repositories.html#maven-or-spark-package) to + install the Maven library on a cluster. Refer + to [this document](https://mvnrepository.com/artifact/com.oracle.database.jdbc/ojdbc8) for obtaining the Maven + coordinates. + +This installation is a necessary step to enable seamless comparison between Oracle and Databricks, ensuring that the +required Oracle JDBC functionality is readily available within the Databricks environment. 
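With the driver installed, the two Oracle-specific settings listed under the JDBC reader options above behave like ordinary Spark JDBC options. The sketch below is illustrative only (Remorph applies these automatically for Oracle sources); the driver class name is simply the standard one shipped in `ojdbc8.jar`.

```python
# Illustrative only: the Oracle-specific options described above, expressed as
# plain Spark JDBC options. Remorph sets the last two automatically for Oracle sources.
oracle_reader_options = {
    "driver": "oracle.jdbc.driver.OracleDriver",  # standard class from ojdbc8.jar
    "oracle.jdbc.mapDateToTimestamp": "False",
    "sessionInitStatement": (
        "BEGIN "
        "dbms_session.set_nls('nls_date_format', '''YYYY-MM-DD''');"
        "dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD HH24:MI:SS''');"
        "END;"
    ),
}
```

In a manual read these would be merged with the usual `url`, `dbtable`, `user`, and `password` options, for example via `spark.read.format("jdbc").options(**oracle_reader_options)`.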
+ +[[↑ back to top](#remorph-reconciliation)] + +## Commonly Used Custom Transformations + +| source_type | data_type | source_transformation | target_transformation | source_value_example | target_value_example | comments | +|-------------|---------------|------------------------------------------------------------------------|-------------------------------------------------|-------------------------|-------------------------|---------------------------------------------------------------------------------------------| +| Oracle | number(10,5) | trim(to_char(coalesce(,0.0), ’99990.99999’)) | cast(coalesce(,0.0) as decimal(10,5)) | 1.00 | 1.00000 | this can be used for any precision and scale by adjusting accordingly in the transformation | +| Snowflake | array | array_to_string(array_compact(),’,’) | concat_ws(’,’, ) | [1,undefined,2] | [1,2] | in case of removing "undefined" during migration(converts sparse array to dense array) | +| Snowflake | array | array_to_string(array_sort(array_compact(), true, true),’,’) | concat_ws(’,’, ) | [2,undefined,1] | [1,2] | in case of removing "undefined" during migration and want to sort the array | +| Snowflake | timestamp_ntz | date_part(epoch_second,) | unix_timestamp() | 2020-01-01 00:00:00.000 | 2020-01-01 00:00:00.000 | convert timestamp_ntz to epoch for getting a match between Snowflake and data bricks | + +[[↑ back to top](#remorph-reconciliation)] + +## Reconciliation Example: +For more Reconciliation Config example, please refer to [sample config][link]. + +[link]: reconcile_config_samples.md + +[[↑ back to top](#remorph-reconciliation)] + +## DataFlow Example + +Report Types Data [Visualisation](report_types_visualisation.md) + +[[↑ back to top](#remorph-reconciliation)] + +------ + +## Remorph Aggregates Reconciliation + + +Aggregates Reconcile is an utility to streamline the reconciliation process, specific aggregate metric is compared +between source and target data residing on Databricks. + +### Summary + +| operation_name | sample visualisation | description | key outputs captured in the recon metrics tables | +|--------------------------|----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| **aggregates-reconcile** | [data](aggregates_reconcile_visualisation.md#data) | reconciles the data for each aggregate metric - ```join_columns``` are used to identify the mismatches at aggregated metric level | - **mismatch_data**(sample data with mismatches captured at aggregated metric level )
- **missing_in_src**(sample rows that are available in target but missing in source)
- **missing_in_tgt**(sample rows that are available in source but are missing in target)
| + + +## Supported Aggregate Functions + + +| Aggregate Functions | +|-----------------------------------------------------------------------------------------------------------------------------------------------------------| +| **min** | +| **max** | +| **count** | +| **sum** | +| **avg** | +| **mean** | +| **mode** | +| **stddev** | +| **variance** | +| **median** | + + + +[[back to aggregates-reconciliation](#remorph-aggregates-reconciliation)] + +[[↑ back to top](#remorph-reconciliation)] + +## Flow Chart + +```mermaid +flowchart TD + Aggregates-Reconcile --> MISMATCH_ROWS + Aggregates-Reconcile --> MISSING_IN_SRC + Aggregates-Reconcile --> MISSING_IN_TGT +``` + + +[[back to aggregates-reconciliation](#remorph-aggregates-reconciliation)] + +[[↑ back to top](#remorph-reconciliation)] + + +## aggregate + + + + + + + + + + +
PythonJSON
+
+@dataclass
+class Aggregate:
+    agg_columns: list[str]
+    type: str
+    group_by_columns: list[str] | None = None
+
+
+
+{
+  "type": "MIN",
+  "agg_columns": ["<COLUMN_NAME_3>"],
+  "group_by_columns": ["<GROUP_COLUMN_NAME>"]
+}
+
+
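Conceptually, each `Aggregate` entry becomes an aggregate query that is executed against both the source and the target, and the aggregated rows are then compared (on the `group_by_columns` when present, row-to-row otherwise). The sketch below only illustrates that idea and is not the reconciler's query builder; it uses the `SUM(population)` by `state` example from the visualisation page, and the table name is a placeholder.

```python
def aggregate_query(table: str, agg: dict) -> str:
    """Illustrative only: the kind of aggregate query an Aggregate entry implies."""
    agg_type = agg["type"].lower()
    agg_exprs = ", ".join(f"{agg_type}({col}) AS {agg_type}_{col}" for col in agg["agg_columns"])
    group_by = agg.get("group_by_columns") or []
    select_list = ", ".join(group_by + [agg_exprs]) if group_by else agg_exprs
    query = f"SELECT {select_list} FROM {table}"
    if group_by:
        query += " GROUP BY " + ", ".join(group_by)
    return query

# Example: SUM(population) grouped by state, run on both source and target tables.
example = {"type": "SUM", "agg_columns": ["population"], "group_by_columns": ["state"]}
print(aggregate_query("city_population", example))
# SELECT state, sum(population) AS sum_population FROM city_population GROUP BY state
```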
+ +| field_name | data_type | description | required/optional | example_value | +|------------------|--------------|-----------------------------------------------------------------------|------------------------|------------------------| +| type | string | [Supported Aggregate Functions](#supported-aggregate-functions) | required | MIN | +| agg_columns | list[string] | list of columns names on which aggregate function needs to be applied | required | ["product_discount"] | +| group_by_columns | list[string] | list of column names on which grouping needs to be applied | optional(default=None) | ["product_id"] or None | + + + +[[back to aggregates-reconciliation](#remorph-aggregates-reconciliation)] + +[[↑ back to top](#remorph-reconciliation)] + + +### TABLE Config Examples: +Please refer [TABLE Config Elements](#TABLE-Config-Elements) for Class and JSON configs. + + + + + + + + + + +
PythonJSON
+
+
+Table(
+    source_name="<SOURCE_NAME>",
+    target_name="<TARGET_NAME>",
+    join_columns=["<COLUMN_NAME_1>", "<COLUMN_NAME_2>"],
+    aggregates=[
+        Aggregate(
+            agg_columns=["<COLUMN_NAME_3>"],
+            type="MIN",
+            group_by_columns=["<GROUP_COLUMN_NAME>"]
+        ),
+        Aggregate(
+            agg_columns=["<COLUMN_NAME_4>"],
+            type="MAX"
+        )
+    ]
+)
+
+
+
+
+{
+  "source_name": "<SOURCE_NAME>",
+  "target_name": "<TARGET_NAME>",
+  "join_columns": ["<COLUMN_NAME_1>","<COLUMN_NAME_2>"],
+  "aggregates": [{
+                   "type": "MIN",
+                   "agg_columns": ["<COLUMN_NAME_3>"],
+                   "group_by_columns": ["<GROUP_COLUMN_NAME>"]
+                  },
+                  {
+                    "type": "MAX",
+                    "agg_columns": ["<COLUMN_NAME_4>"]
+                  }]
+}
+
+
+ + +## Key Considerations: + +1. The aggregate column names, group by columns and type are always converted to lowercase and considered for reconciliation. +2. Currently, it doesn't support aggregates on window function using the OVER clause. +3. It doesn't support case insensitivity and does not have collation support +4. The queries with “group by” column(s) are compared based on the same group by columns. +5. The queries without “group by” column(s) are compared row-to-row. +6. Existing features like `column_mapping`, `transformations`, `JDBCReaderOptions` and `filters` are leveraged for the aggregate metric reconciliation. +7. Existing `select_columns` and `drop_columns` are not considered for the aggregate metric reconciliation. +8. Even though the user provides the `select_columns` and `drop_columns`, those are not considered. +9. If Transformations are defined, those are applied to both the “aggregate columns” and “group by columns”. + +[[back to aggregates-reconciliation](#remorph-aggregates-reconciliation)] + +[[↑ back to top](#remorph-reconciliation)] + +## Aggregates Reconciliation JSON Example + +Please refer this [sample config][link] for detailed example config. + +[link]: reconcile_config_samples.md#Aggregates-Reconcile-Config + +[[back to aggregates-reconciliation](#remorph-aggregates-reconciliation)] + +[[↑ back to top](#remorph-reconciliation)] + +## DataFlow Example + +Aggregates Reconcile Data [Visualisation](aggregates_reconcile_visualisation.md) + +[[back to aggregates-reconciliation](#remorph-aggregates-reconciliation)] + +[[↑ back to top](#remorph-reconciliation)] diff --git a/docs/recon_configurations/aggregates_reconcile_visualisation.md b/docs/recon_configurations/aggregates_reconcile_visualisation.md new file mode 100644 index 0000000000..bf5f3dbf5b --- /dev/null +++ b/docs/recon_configurations/aggregates_reconcile_visualisation.md @@ -0,0 +1,112 @@ +## data +### with group by +```mermaid +flowchart TB + subgraph source + direction TB + A["id: 1
city: New York
population: 100
state: NY"] + B["id: 2
city: Yonkers
population: 10
state: NY"] + C["id: 3
city: Los Angeles
population: 300
state: CA"] + D["id: 4
city: San Francisco
population: 30
state: CA"] + E["id: 6
city: Washington
population: 600
state: DC"] + end + + subgraph target + direction TB + F["id: 1
city: New York
population: 100
state: NY"] + G["id: 2
city: Yonkers
population: 10
state: NY"] + H["id: 3
city: Los Angeles
population: 300
state: CA"] + I["id: 5
city: San Diego
population: 40
state: CA"] + J["id: 7
city: Phoenix
population: 500
state: AZ"] + end + + subgraph source-aggregated + direction TB + K["sum(population): 110
state: NY"] + L["sum(population): 330
state: CA"] + M["sum(population): 600
state: DC"] + end + + subgraph target-aggregated + direction TB + N["sum(population): 110
state: NY"] + O["sum(population): 340
state: CA"] + P["sum(population): 500
state: AZ"] + end + + subgraph missing_in_src + direction TB + Q["sum(population): 500
state: AZ"] + end + + subgraph missing_in_tgt + direction TB + R["sum(population): 600
state: DC"] + end + + subgraph mismatch + direction TB + S["state: CA
source_sum(population): 330
target_sum(population): 340
sum(population)_match: false"] + end + + subgraph aggregates-reconcile + direction TB + T["aggregate: SUM as type
population as agg-columns
state as group_by_columns"] + end + + source --> source-aggregated + target --> target-aggregated + source-aggregated --> aggregates-reconcile + target-aggregated --> aggregates-reconcile + aggregates-reconcile --> missing_in_src + aggregates-reconcile --> missing_in_tgt + aggregates-reconcile --> mismatch +``` + + +### without group by +```mermaid +flowchart TB + subgraph source + direction TB + A["id: 1
city: New York
population: 100
state: NY"] + D["id: 4
city: San Francisco
population: 30
state: CA"] + E["id: 6
city: Washington
population: 600
state: DC"] + end + + subgraph target + direction TB + F["id: 1
city: New York
population: 100
state: NY"] + I["id: 5
city: San Diego
population: 40
state: CA"] + J["id: 7
city: Phoenix
population: 500
state: AZ"] + end + + subgraph source-aggregated + direction TB + K["min(population): 30"] + end + + subgraph target-aggregated + direction TB + O["min(population): 40"] + end + + + subgraph mismatch + direction TB + S["source_min(population): 30
target_min(population): 40
min(population)_match: false"] + end + + subgraph aggregates-reconcile + direction TB + T["aggregate: MIN as type
population as agg-columns"] + end + + source --> source-aggregated + target --> target-aggregated + source-aggregated --> aggregates-reconcile + target-aggregated --> aggregates-reconcile + aggregates-reconcile --> mismatch +``` + + diff --git a/docs/recon_configurations/reconcile_config_samples.md b/docs/recon_configurations/reconcile_config_samples.md new file mode 100644 index 0000000000..0801e09291 --- /dev/null +++ b/docs/recon_configurations/reconcile_config_samples.md @@ -0,0 +1,181 @@ +# Reconcile Config + +Consider the below tables that we want to reconcile: + +| category | catalog | schema | table_name | schema | primary_key | +|----------|----------------|---------------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------|-------------| +| source | source_catalog | source_schema | product_prod | p_id INT,
p_name STRING,
price NUMBER,
discount DECIMAL(5,3),
offer DOUBLE,
creation_date DATE,
comment STRING
| p_id | +| target | target_catalog | target_schema | product | product_id INT,
product_name STRING,
price NUMBER,
discount DECIMAL(5,3),
offer DOUBLE,
creation_date DATE,
comment STRING
| product_id | + +## Run with Drop,Join,Transformation,ColumnThresholds,Filter,JDBC ReaderOptions configs + + +```json + { + "source_catalog": "source_catalog", + "source_schema": "source_schema", + "target_catalog": "target_catalog", + "target_schema": "target_schema", + "tables": [ + { + "source_name": "product_prod", + "target_name": "product", + "jdbc_reader_options": { + "number_partitions": 10, + "partition_column": "p_id", + "lower_bound": "0", + "upper_bound": "10000000" + }, + "join_columns": [ + "p_id" + ], + "drop_columns": [ + "comment" + ], + "column_mapping": [ + { + "source_name": "p_id", + "target_name": "product_id" + }, + { + "source_name": "p_name", + "target_name": "product_name" + } + ], + "transformations": [ + { + "column_name": "creation_date", + "source": "creation_date", + "target": "to_date(creation_date,'yyyy-mm-dd')" + } + ], + "column_thresholds": [ + { + "column_name": "price", + "upper_bound": "-50", + "lower_bound": "50", + "type": "float" + } + ], + "table_thresholds": [ + { + "lower_bound": "0%", + "upper_bound": "5%", + "model": "mismatch" + } + ], + "filters": { + "source": "p_id > 0", + "target": "product_id > 0" + } + } + ] +} + +``` + +--- + +## Aggregates Reconcile Config + +### Aggregates-Reconcile run with Join, Column Mappings, Transformation, Filter and JDBC ReaderOptions configs + +> **Note:** Even though the user provides the `select_columns` and `drop_columns`, those are not considered. + + +```json + { + "source_catalog": "source_catalog", + "source_schema": "source_schema", + "target_catalog": "target_catalog", + "target_schema": "target_schema", + "tables": [ + { + "aggregates": [{ + "type": "MIN", + "agg_columns": ["discount"], + "group_by_columns": ["p_id"] + }, + { + "type": "AVG", + "agg_columns": ["discount"], + "group_by_columns": ["p_id"] + }, + { + "type": "MAX", + "agg_columns": ["p_id"], + "group_by_columns": ["creation_date"] + }, + { + "type": "MAX", + "agg_columns": ["p_name"] + }, + { + "type": "SUM", + "agg_columns": ["p_id"] + }, + { + "type": "MAX", + "agg_columns": ["creation_date"] + }, + { + "type": "MAX", + "agg_columns": ["p_id"], + "group_by_columns": ["creation_date"] + } + ], + "source_name": "product_prod", + "target_name": "product", + "jdbc_reader_options": { + "number_partitions": 10, + "partition_column": "p_id", + "lower_bound": "0", + "upper_bound": "10000000" + }, + "join_columns": [ + "p_id" + ], + "drop_columns": [ + "comment" + ], + "column_mapping": [ + { + "source_name": "p_id", + "target_name": "product_id" + }, + { + "source_name": "p_name", + "target_name": "product_name" + } + ], + "transformations": [ + { + "column_name": "creation_date", + "source": "creation_date", + "target": "to_date(creation_date,'yyyy-mm-dd')" + } + ], + "column_thresholds": [ + { + "column_name": "price", + "upper_bound": "-50", + "lower_bound": "50", + "type": "float" + } + ], + "table_thresholds": [ + { + "lower_bound": "0%", + "upper_bound": "5%", + "model": "mismatch" + } + ], + "filters": { + "source": "p_id > 0", + "target": "product_id > 0" + } + } + ] +} + +``` diff --git a/docs/recon_configurations/report_types_visualisation.md b/docs/recon_configurations/report_types_visualisation.md new file mode 100644 index 0000000000..ba84c7f9fd --- /dev/null +++ b/docs/recon_configurations/report_types_visualisation.md @@ -0,0 +1,171 @@ +## data +```mermaid +flowchart TB + subgraph source + direction TB + A["id: 1
city: New York"] + B["id: 2
city: Los Angeles"] + C["id: 3
city: San Francisco"] + end + + subgraph target + direction TB + D["id: 1
city: New York"] + E["id: 2
city: Brooklyn"] + F["id: 4
city: Chicago"] + end + + subgraph missing_in_src + direction TB + G["id: 4
city: Chicago"] + end + + subgraph missing_in_tgt + direction TB + H["id: 3
city: San Francisco"] + end + + subgraph mismatch + direction TB + I["id: 2
city_base: Los Angeles
city_compare: Brooklyn
city_match: false"] + end + + subgraph reconcile + direction TB + J["report type: data or all(with id as join columns)"] + end + + source --> reconcile + target --> reconcile + reconcile --> missing_in_src + reconcile --> missing_in_tgt + reconcile --> mismatch +``` + +## row +```mermaid +flowchart TB + subgraph source + direction TB + A["id: 1
city: New York"] + B["id: 2
city: Los Angeles"] + C["id: 3
city: San Francisco"] + end + + subgraph target + direction TB + D["id: 1
city: New York"] + E["id: 2
city: Brooklyn"] + F["id: 4
city: Chicago"] + end + + subgraph missing_in_src + direction TB + G["id: 2
city: Brooklyn"] + H["id: 4
city: Chicago"] + end + + subgraph missing_in_tgt + direction TB + I["id: 2
city: Los Angeles"] + J["id: 3
city: San Francisco"] + end + + subgraph reconcile + direction TB + K["report type: row(with no join column)"] + end + + source --> reconcile + target --> reconcile + reconcile --> missing_in_src + reconcile --> missing_in_tgt +``` + +## schema +```mermaid +flowchart TB + subgraph source + direction TB + A["column_name: id
data_type: number"] + B["column_name: name
city: varchar"] + C["column_name: salary
city: double"] + end + + subgraph target + direction TB + D["column_name: id
data_type: number"] + E["column_name: employee_name
city: string"] + F["column_name: salary
city: double"] + end + + subgraph reconcile + direction TB + G["report type: schema"] + end + + subgraph schema_reconcile_output + direction TB + H["source_column_name: id
databricks_column_name: id
source_datatype: int
databricks_datatype: int
is_valid: true"] + I["source_column_name: name
databricks_column_name: employee_name
source_datatype: varchar
databricks_datatype: string
is_valid: true"] + J["source_column_name: salary
databricks_column_name: salary
source_datatype: double
databricks_datatype: double
is_valid: true"] + end + + source --> reconcile + target --> reconcile + reconcile --> schema_reconcile_output +``` + +## all +```mermaid +flowchart TB + subgraph source + direction TB + A["id: 1
city: New York"] + B["id: 2
city: Los Angeles"] + C["id: 3
city: San Francisco"] + end + + subgraph target + direction TB + D["id: 1
city: New York"] + E["id: 2
city: Brooklyn"] + F["id: 4
city: Chicago"] + end + + subgraph missing_in_src + direction TB + G["id: 4
city: Chicago"] + end + + subgraph missing_in_tgt + direction TB + H["id: 3
city: San Francisco"] + end + + subgraph mismatch + direction TB + I["id: 2
city_base: Los Angeles
city_compare: Brooklyn
city_match: false"] + end + + subgraph schema_reconcile_output + direction TB + J["source_column_name: id
databricks_column_name: id
source_datatype: integer
databricks_datatype: integer
is_valid: true"] + K["source_column_name: city
databricks_column_name: city
source_datatype: varchar
databricks_datatype: string
is_valid: true"] + end + + subgraph reconcile + direction TB + L["report type: all(with id as join columns)"] + end + + source --> reconcile + target --> reconcile + reconcile --> missing_in_src + reconcile --> missing_in_tgt + reconcile --> mismatch + reconcile --> schema_reconcile_output +``` + + + diff --git a/labs.yml b/labs.yml new file mode 100644 index 0000000000..8799806cb9 --- /dev/null +++ b/labs.yml @@ -0,0 +1,96 @@ +--- +name: remorph +description: Code Transpiler and Data Reconciliation tool for Accelerating Data onboarding to Databricks from EDW, CDW and other ETL sources. +install: + min_runtime_version: 13.3 + require_running_cluster: false + require_databricks_connect: true + script: src/databricks/labs/remorph/install.py +uninstall: + script: src/databricks/labs/remorph/uninstall.py +entrypoint: src/databricks/labs/remorph/cli.py +min_python: 3.10 +commands: + - name: transpile + description: Transpile SQL script to Databricks SQL + flags: + - name: source-dialect + description: Input SQL Dialect Type Accepted Values [snowflake, tsql] + - name: input-source + description: Input SQL Folder or File + - name: output-folder + default: None + description: Output Location For Storing Transpiled Cod + - name: skip-validation + default: true + description: Validate Transpiled Code, default True validation skipped, False validate + - name: catalog-name + default: None + description: Catalog Name Applicable only when Validation Mode is DATABRICKS + - name: schema-name + default: None + description: Schema Name Applicable only when Validation Mode is DATABRICKS + - name: mode + default: current + description: Run in Current or Experimental Mode, Accepted Values [experimental, current], Default current, experimental mode will execute including any Private Preview features + + table_template: |- + total_files_processed\ttotal_queries_processed\tno_of_sql_failed_while_parsing\tno_of_sql_failed_while_validating\terror_log_file + {{range .}}{{.total_files_processed}}\t{{.total_queries_processed}}\t{{.no_of_sql_failed_while_parsing}}\t{{.no_of_sql_failed_while_validating}}\t{{.error_log_file}} + {{end}} + - name: reconcile + description: Reconcile is an utility to streamline the reconciliation process between source data and target data residing on Databricks. + - name: aggregates-reconcile + description: Aggregates Reconcile is an utility to streamline the reconciliation process, specific aggregate metric is compared between source and target data residing on Databricks. 
+ - name: generate-lineage + description: Utility to generate a lineage of the SQL files + flags: + - name: source-dialect + description: Input SQL Dialect Type Accepted Values [snowflake, tsql] + - name: input-source + description: Input SQL Folder or File + - name: output-folder + description: Directory to store the generated lineage file + - name: configure-secrets + description: Utility to setup Scope and Secrets on Databricks Workspace + - name: debug-script + description: "[INTERNAL] Debug Script" + flags: + - name: name + description: Filename to debug + - name: dialect + description: sql dialect + - name: debug-me + description: "[INTERNAL] Debug SDK connectivity" + - name: debug-coverage + description: "[INTERNAL] Run coverage tests" + flags: + - name: dialect + description: sql dialect + - name: src + description: The parent directory under which test queries are laid out + - name: dst + description: The directory under which the report files will be written + - name: extractor + description: The strategy for extracting queries from the test files. Valid strategies are "full" (when files contain only one input query) and "comment" (when files contain an input query and the corresponding translation, separated by a comment stating the dialect of each query). + - name: debug-estimate + description: "[INTERNAL] estimate migration effort" + flags: + - name: dialect + description: sql dialect + - name: source-queries + description: The folder with queries. Otherwise will attempt to fetch query history for a dialect + - name: console-output + default: true + description: Output results to a folder + - name: dst + description: The directory for report + - name: debug-bundle + description: "[INTERNAL] Generate bundle for the translated queries" + flags: + - name: dialect + description: sql dialect + - name: source-queries + description: The folder with queries. Otherwise will attempt to fetch query history for a dialect + - name: dst + description: The directory for generated files diff --git a/linter/pom.xml b/linter/pom.xml new file mode 100644 index 0000000000..a85c3754c6 --- /dev/null +++ b/linter/pom.xml @@ -0,0 +1,134 @@ + + 4.0.0 + + com.databricks.labs + remorph + 0.2.0-SNAPSHOT + + remorph-linter + jar + + 5.10.0 + 1.8 + 1.8 + UTF-8 + 2.0.9 + 4.13.2 + + + + -- databricks sql: + + + 0.10.1 + 0.7.0 + 4.0.2 + 3.3.0-SNAP4 + 3.4.1 + + + + org.antlr + antlr4-runtime + ${antlr.version} + + + com.lihaoyi + os-lib_${scala.binary.version} + ${os-lib.version} + + + com.lihaoyi + mainargs_${scala.binary.version} + ${mainargs.version} + + + com.lihaoyi + ujson_${scala.binary.version} + ${ujson.version} + + + org.scalatest + scalatest_${scala.binary.version} + ${scalatest.version} + test + + + org.scala-lang + scala-library + ${scala.version} + + + com.databricks + databricks-sdk-java + 0.37.0 + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + com.databricks.labs.remorph.linter.Main + + -i + ${sourceDir} + -o + ${outputPath} + + + + + org.scalatest + scalatest-maven-plugin + 2.2.0 + + ${project.build.directory}/surefire-reports + . 
+ true + + + + test + + test + + + + + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + false + true + src/main/antlr4 + true + + **/ANTLRv4Lexer.g4 + **/ANTLRv4Parser.g4 + + src/main/antlr4/library + + + + + diff --git a/linter/src/main/antlr4/com/databricks/labs/remorph/linter/ANTLRv4Lexer.g4 b/linter/src/main/antlr4/com/databricks/labs/remorph/linter/ANTLRv4Lexer.g4 new file mode 100644 index 0000000000..cfe41d4790 --- /dev/null +++ b/linter/src/main/antlr4/com/databricks/labs/remorph/linter/ANTLRv4Lexer.g4 @@ -0,0 +1,437 @@ +/* + * [The "BSD license"] + * Copyright (c) 2012-2015 Terence Parr + * Copyright (c) 2012-2015 Sam Harwell + * Copyright (c) 2015 Gerald Rosenberg + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the author may not be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. + * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +/** + * A grammar for ANTLR v4 implemented using v4 syntax + * + * Modified 2015.06.16 gbr + * -- update for compatibility with Antlr v4.5 + */ + +// $antlr-format alignTrailingComments on, columnLimit 130, minEmptyLines 1, maxEmptyLinesToKeep 1, reflowComments off +// $antlr-format useTab off, allowShortRulesOnASingleLine off, allowShortBlocksOnASingleLine on, alignSemicolons hanging +// $antlr-format alignColons hanging + +// ====================================================== +// Lexer specification +// ====================================================== + +lexer grammar ANTLRv4Lexer; + +options { + superClass = LexerAdaptor; +} + +import LexBasic; + +// Standard set of fragments +tokens { + TOKEN_REF, + RULE_REF, + LEXER_CHAR_SET +} + +channels { + OFF_CHANNEL, + COMMENT +} + +// ------------------------- +// Comments +DOC_COMMENT + : DocComment -> channel (COMMENT) + ; + +BLOCK_COMMENT + : BlockComment -> channel (COMMENT) + ; + +LINE_COMMENT + : LineComment -> channel (COMMENT) + ; + +// ------------------------- +// Integer + +INT + : DecimalNumeral + ; + +// ------------------------- +// Literal string +// +// ANTLR makes no distinction between a single character literal and a +// multi-character string. 
All literals are single quote delimited and +// may contain unicode escape sequences of the form \uxxxx, where x +// is a valid hexadecimal number (per Unicode standard). +STRING_LITERAL + : SQuoteLiteral + ; + +UNTERMINATED_STRING_LITERAL + : USQuoteLiteral + ; + +// ------------------------- +// Arguments +// +// Certain argument lists, such as those specifying call parameters +// to a rule invocation, or input parameters to a rule specification +// are contained within square brackets. +BEGIN_ARGUMENT + : LBrack { this.handleBeginArgument(); } + ; + +// ------------------------- +// Target Language Actions +BEGIN_ACTION + : LBrace -> pushMode (TargetLanguageAction) + ; + +// ------------------------- +// Keywords +// +// 'options', 'tokens', and 'channels' are considered keywords +// but only when followed by '{', and considered as a single token. +// Otherwise, the symbols are tokenized as RULE_REF and allowed as +// an identifier in a labeledElement. +OPTIONS + : 'options' WSNLCHARS* '{' + ; + +TOKENS + : 'tokens' WSNLCHARS* '{' + ; + +CHANNELS + : 'channels' WSNLCHARS* '{' + ; + +fragment WSNLCHARS + : ' ' + | '\t' + | '\f' + | '\n' + | '\r' + ; + +IMPORT + : 'import' + ; + +FRAGMENT + : 'fragment' + ; + +LEXER + : 'lexer' + ; + +PARSER + : 'parser' + ; + +GRAMMAR + : 'grammar' + ; + +PROTECTED + : 'protected' + ; + +PUBLIC + : 'public' + ; + +PRIVATE + : 'private' + ; + +RETURNS + : 'returns' + ; + +LOCALS + : 'locals' + ; + +THROWS + : 'throws' + ; + +CATCH + : 'catch' + ; + +FINALLY + : 'finally' + ; + +MODE + : 'mode' + ; + +// ------------------------- +// Punctuation + +COLON + : Colon + ; + +COLONCOLON + : DColon + ; + +COMMA + : Comma + ; + +SEMI + : Semi + ; + +LPAREN + : LParen + ; + +RPAREN + : RParen + ; + +LBRACE + : LBrace + ; + +RBRACE + : RBrace + ; + +RARROW + : RArrow + ; + +LT + : Lt + ; + +GT + : Gt + ; + +ASSIGN + : Equal + ; + +QUESTION + : Question + ; + +STAR + : Star + ; + +PLUS_ASSIGN + : PlusAssign + ; + +PLUS + : Plus + ; + +OR + : Pipe + ; + +DOLLAR + : Dollar + ; + +RANGE + : Range + ; + +DOT + : Dot + ; + +AT + : At + ; + +POUND + : Pound + ; + +NOT + : Tilde + ; + +// ------------------------- +// Identifiers - allows unicode rule/token names + +ID + : Id + ; + +// ------------------------- +// Whitespace + +WS + : Ws+ -> channel (OFF_CHANNEL) + ; + +// ------------------------- +// Illegal Characters +// +// This is an illegal character trap which is always the last rule in the +// lexer specification. It matches a single character of any value and being +// the last rule in the file will match when no other rule knows what to do +// about the character. It is reported as an error but is not passed on to the +// parser. This means that the parser to deal with the gramamr file anyway +// but we will not try to analyse or code generate from a file with lexical +// errors. + +// Comment this rule out to allow the error to be propagated to the parser +ERRCHAR + : . -> channel (HIDDEN) + ; + +// ====================================================== +// Lexer modes +// ------------------------- +// Arguments +mode Argument; + +// E.g., [int x, List a[]] +NESTED_ARGUMENT + : LBrack -> type (ARGUMENT_CONTENT), pushMode (Argument) + ; + +ARGUMENT_ESCAPE + : EscAny -> type (ARGUMENT_CONTENT) + ; + +ARGUMENT_STRING_LITERAL + : DQuoteLiteral -> type (ARGUMENT_CONTENT) + ; + +ARGUMENT_CHAR_LITERAL + : SQuoteLiteral -> type (ARGUMENT_CONTENT) + ; + +END_ARGUMENT + : RBrack { this.handleEndArgument(); } + ; + +// added this to return non-EOF token type here. 
EOF does something weird +UNTERMINATED_ARGUMENT + : EOF -> popMode + ; + +ARGUMENT_CONTENT + : . + ; + +// TODO: This grammar and the one used in the Intellij Antlr4 plugin differ +// for "actions". This needs to be resolved at some point. +// The Intellij Antlr4 grammar is here: +// https://github.com/antlr/intellij-plugin-v4/blob/1f36fde17f7fa63cb18d7eeb9cb213815ac658fb/src/main/antlr/org/antlr/intellij/plugin/parser/ANTLRv4Lexer.g4#L587 + +// ------------------------- +// Target Language Actions +// +// Many language targets use {} as block delimiters and so we +// must recursively match {} delimited blocks to balance the +// braces. Additionally, we must make some assumptions about +// literal string representation in the target language. We assume +// that they are delimited by ' or " and so consume these +// in their own alts so as not to inadvertantly match {}. +mode TargetLanguageAction; + +NESTED_ACTION + : LBrace -> type (ACTION_CONTENT), pushMode (TargetLanguageAction) + ; + +ACTION_ESCAPE + : EscAny -> type (ACTION_CONTENT) + ; + +ACTION_STRING_LITERAL + : DQuoteLiteral -> type (ACTION_CONTENT) + ; + +ACTION_CHAR_LITERAL + : SQuoteLiteral -> type (ACTION_CONTENT) + ; + +ACTION_DOC_COMMENT + : DocComment -> type (ACTION_CONTENT) + ; + +ACTION_BLOCK_COMMENT + : BlockComment -> type (ACTION_CONTENT) + ; + +ACTION_LINE_COMMENT + : LineComment -> type (ACTION_CONTENT) + ; + +END_ACTION + : RBrace { this.handleEndAction(); } + ; + +UNTERMINATED_ACTION + : EOF -> popMode + ; + +ACTION_CONTENT + : . + ; + +// ------------------------- +mode LexerCharSet; + +LEXER_CHAR_SET_BODY + : (~ [\]\\] | EscAny)+ -> more + ; + +LEXER_CHAR_SET + : RBrack -> popMode + ; + +UNTERMINATED_CHAR_SET + : EOF -> popMode + ; + +// ------------------------------------------------------------------------------ +// Grammar specific Keywords, Punctuation, etc. +fragment Id + : NameStartChar NameChar* + ; diff --git a/linter/src/main/antlr4/com/databricks/labs/remorph/linter/ANTLRv4Parser.g4 b/linter/src/main/antlr4/com/databricks/labs/remorph/linter/ANTLRv4Parser.g4 new file mode 100644 index 0000000000..90b4d951b5 --- /dev/null +++ b/linter/src/main/antlr4/com/databricks/labs/remorph/linter/ANTLRv4Parser.g4 @@ -0,0 +1,408 @@ +/* + * [The "BSD license"] + * Copyright (c) 2012-2014 Terence Parr + * Copyright (c) 2012-2014 Sam Harwell + * Copyright (c) 2015 Gerald Rosenberg + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the author may not be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
+ * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* A grammar for ANTLR v4 written in ANTLR v4. + * + * Modified 2015.06.16 gbr + * -- update for compatibility with Antlr v4.5 + * -- add mode for channels + * -- moved members to com.databricks.labs.remporph.antlrlinter.LexerAdaptor + * -- move fragments to imports + */ + +// $antlr-format alignTrailingComments on, columnLimit 130, minEmptyLines 1, maxEmptyLinesToKeep 1, reflowComments off +// $antlr-format useTab off, allowShortRulesOnASingleLine off, allowShortBlocksOnASingleLine on, alignSemicolons hanging +// $antlr-format alignColons hanging + +parser grammar ANTLRv4Parser; + +options { + tokenVocab = ANTLRv4Lexer; +} + +// The main entry point for parsing a v4 grammar. +grammarSpec + : grammarDecl prequelConstruct* rules modeSpec* EOF + ; + +grammarDecl + : grammarType identifier SEMI + ; + +grammarType + : LEXER GRAMMAR + | PARSER GRAMMAR + | GRAMMAR + ; + +// This is the list of all constructs that can be declared before +// the set of rules that compose the grammar, and is invoked 0..n +// times by the grammarPrequel rule. + +prequelConstruct + : optionsSpec + | delegateGrammars + | tokensSpec + | channelsSpec + | action_ + ; + +// ------------ +// Options - things that affect analysis and/or code generation + +optionsSpec + : OPTIONS (option SEMI)* RBRACE + ; + +option + : identifier ASSIGN optionValue + ; + +optionValue + : identifier (DOT identifier)* + | STRING_LITERAL + | actionBlock + | INT + ; + +// ------------ +// Delegates + +delegateGrammars + : IMPORT delegateGrammar (COMMA delegateGrammar)* SEMI + ; + +delegateGrammar + : identifier ASSIGN identifier + | identifier + ; + +// ------------ +// Tokens & Channels + +tokensSpec + : TOKENS idList? RBRACE + ; + +channelsSpec + : CHANNELS idList? RBRACE + ; + +idList + : identifier (COMMA identifier)* COMMA? + ; + +// Match stuff like @parser::members {int i;} + +action_ + : AT (actionScopeName COLONCOLON)? identifier actionBlock + ; + +// Scope names could collide with keywords; allow them as ids for action scopes + +actionScopeName + : identifier + | LEXER + | PARSER + ; + +actionBlock + : BEGIN_ACTION ACTION_CONTENT* END_ACTION + ; + +argActionBlock + : BEGIN_ARGUMENT ARGUMENT_CONTENT* END_ARGUMENT + ; + +modeSpec + : MODE identifier SEMI lexerRuleSpec* + ; + +rules + : ruleSpec* + ; + +ruleSpec + : parserRuleSpec + | lexerRuleSpec + ; + +parserRuleSpec + : ruleModifiers? RULE_REF argActionBlock? ruleReturns? throwsSpec? localsSpec? rulePrequel* COLON ruleBlock SEMI + exceptionGroup + ; + +exceptionGroup + : exceptionHandler* finallyClause? 
+ ; + +exceptionHandler + : CATCH argActionBlock actionBlock + ; + +finallyClause + : FINALLY actionBlock + ; + +rulePrequel + : optionsSpec + | ruleAction + ; + +ruleReturns + : RETURNS argActionBlock + ; + +// -------------- +// Exception spec +throwsSpec + : THROWS identifier (COMMA identifier)* + ; + +localsSpec + : LOCALS argActionBlock + ; + +/** Match stuff like @init {int i;} */ +ruleAction + : AT identifier actionBlock + ; + +ruleModifiers + : ruleModifier+ + ; + +// An individual access modifier for a rule. The 'fragment' modifier +// is an internal indication for lexer rules that they do not match +// from the input but are like subroutines for other lexer rules to +// reuse for certain lexical patterns. The other modifiers are passed +// to the code generation templates and may be ignored by the template +// if they are of no use in that language. + +ruleModifier + : PUBLIC + | PRIVATE + | PROTECTED + | FRAGMENT + ; + +ruleBlock + : ruleAltList + ; + +ruleAltList + : labeledAlt (OR labeledAlt)* + ; + +labeledAlt + : alternative (POUND identifier)? + ; + +// -------------------- +// Lexer rules + +lexerRuleSpec + : FRAGMENT? TOKEN_REF optionsSpec? COLON lexerRuleBlock SEMI + ; + +lexerRuleBlock + : lexerAltList + ; + +lexerAltList + : lexerAlt (OR lexerAlt)* + ; + +lexerAlt + : lexerElements lexerCommands? + | + // explicitly allow empty alts + ; + +lexerElements + : lexerElement+ + | + ; + +lexerElement + : lexerAtom ebnfSuffix? + | lexerBlock ebnfSuffix? + | actionBlock QUESTION? + ; + +// but preds can be anywhere + +lexerBlock + : LPAREN lexerAltList RPAREN + ; + +// E.g., channel(HIDDEN), skip, more, mode(INSIDE), push(INSIDE), pop + +lexerCommands + : RARROW lexerCommand (COMMA lexerCommand)* + ; + +lexerCommand + : lexerCommandName LPAREN lexerCommandExpr RPAREN + | lexerCommandName + ; + +lexerCommandName + : identifier + | MODE + ; + +lexerCommandExpr + : identifier + | INT + ; + +// -------------------- +// Rule Alts + +altList + : alternative (OR alternative)* + ; + +alternative + : elementOptions? element+ + | + // explicitly allow empty alts + ; + +element + : labeledElement (ebnfSuffix |) + | atom (ebnfSuffix |) + | ebnf + | actionBlock (QUESTION predicateOptions?)? + ; + +predicateOptions + : LT predicateOption (COMMA predicateOption)* GT + ; + +predicateOption + : elementOption + | identifier ASSIGN actionBlock + ; + +labeledElement + : identifier (ASSIGN | PLUS_ASSIGN) (atom | block) + ; + +// -------------------- +// EBNF and blocks + +ebnf + : block blockSuffix? + ; + +blockSuffix + : ebnfSuffix + ; + +ebnfSuffix + : QUESTION QUESTION? + | STAR QUESTION? + | PLUS QUESTION? + ; + +lexerAtom + : characterRange + | terminalDef + | notSet + | LEXER_CHAR_SET + | DOT elementOptions? + ; + +atom + : terminalDef + | ruleref + | notSet + | DOT elementOptions? + ; + +// -------------------- +// Inverted element set +notSet + : NOT setElement + | NOT blockSet + ; + +blockSet + : LPAREN setElement (OR setElement)* RPAREN + ; + +setElement + : TOKEN_REF elementOptions? + | STRING_LITERAL elementOptions? + | characterRange + | LEXER_CHAR_SET + ; + +// ------------- +// Grammar Block +block + : LPAREN (optionsSpec? ruleAction* COLON)? altList RPAREN + ; + +// ---------------- +// Parser rule ref +ruleref + : RULE_REF argActionBlock? elementOptions? + ; + +// --------------- +// Character Range +characterRange + : STRING_LITERAL RANGE STRING_LITERAL + ; + +terminalDef + : TOKEN_REF elementOptions? + | STRING_LITERAL elementOptions? 
+ ; + +// Terminals may be adorned with certain options when +// reference in the grammar: TOK<,,,> +elementOptions + : LT elementOption (COMMA elementOption)* GT + ; + +elementOption + : identifier + | identifier ASSIGN (identifier | STRING_LITERAL) + ; + +identifier + : RULE_REF + | TOKEN_REF + ; diff --git a/linter/src/main/antlr4/library/LexBasic.g4 b/linter/src/main/antlr4/library/LexBasic.g4 new file mode 100644 index 0000000000..1939630ece --- /dev/null +++ b/linter/src/main/antlr4/library/LexBasic.g4 @@ -0,0 +1,286 @@ +/* + * [The "BSD license"] + * Copyright (c) 2014-2015 Gerald Rosenberg + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the author may not be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. + * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +/** + * A generally reusable set of fragments for import in to Lexer grammars. + * + * Modified 2015.06.16 gbr - + * -- generalized for inclusion into the ANTLRv4 grammar distribution + * + */ + +// $antlr-format alignTrailingComments on, columnLimit 130, minEmptyLines 1, maxEmptyLinesToKeep 1, reflowComments off +// $antlr-format useTab off, allowShortRulesOnASingleLine off, allowShortBlocksOnASingleLine on, alignSemicolons hanging +// $antlr-format alignColons hanging + +lexer grammar LexBasic; + +// ====================================================== +// Lexer fragments +// +// ----------------------------------- +// Whitespace & Comments + +fragment Ws + : Hws + | Vws + ; + +fragment Hws + : [ \t] + ; + +fragment Vws + : [\r\n\f] + ; + +fragment BlockComment + : '/*' .*? ('*/' | EOF) + ; + +fragment DocComment + : '/**' .*? ('*/' | EOF) + ; + +fragment LineComment + : '//' ~ [\r\n]* + ; + +// ----------------------------------- +// Escapes +// Any kind of escaped character that we can embed within ANTLR literal strings. + +fragment EscSeq + : Esc ([btnfr"'\\] | UnicodeEsc | . | EOF) + ; + +fragment EscAny + : Esc . + ; + +fragment UnicodeEsc + : 'u' (HexDigit (HexDigit (HexDigit HexDigit?)?)?)? 
+ ; + +// ----------------------------------- +// Numerals + +fragment DecimalNumeral + : '0' + | [1-9] DecDigit* + ; + +// ----------------------------------- +// Digits + +fragment HexDigit + : [0-9a-fA-F] + ; + +fragment DecDigit + : [0-9] + ; + +// ----------------------------------- +// Literals + +fragment BoolLiteral + : 'true' + | 'false' + ; + +fragment CharLiteral + : SQuote (EscSeq | ~ ['\r\n\\]) SQuote + ; + +fragment SQuoteLiteral + : SQuote (EscSeq | ~ ['\r\n\\])* SQuote + ; + +fragment DQuoteLiteral + : DQuote (EscSeq | ~ ["\r\n\\])* DQuote + ; + +fragment USQuoteLiteral + : SQuote (EscSeq | ~ ['\r\n\\])* + ; + +// ----------------------------------- +// Character ranges + +fragment NameChar + : NameStartChar + | '0' .. '9' + | Underscore + | '\u00B7' + | '\u0300' .. '\u036F' + | '\u203F' .. '\u2040' + ; + +fragment NameStartChar + : 'A' .. 'Z' + | 'a' .. 'z' + | '\u00C0' .. '\u00D6' + | '\u00D8' .. '\u00F6' + | '\u00F8' .. '\u02FF' + | '\u0370' .. '\u037D' + | '\u037F' .. '\u1FFF' + | '\u200C' .. '\u200D' + | '\u2070' .. '\u218F' + | '\u2C00' .. '\u2FEF' + | '\u3001' .. '\uD7FF' + | '\uF900' .. '\uFDCF' + | '\uFDF0' .. '\uFFFD' + // ignores | ['\u10000-'\uEFFFF] + ; + +// ----------------------------------- +// Types + +fragment Int + : 'int' + ; + +// ----------------------------------- +// Symbols + +fragment Esc + : '\\' + ; + +fragment Colon + : ':' + ; + +fragment DColon + : '::' + ; + +fragment SQuote + : '\'' + ; + +fragment DQuote + : '"' + ; + +fragment LParen + : '(' + ; + +fragment RParen + : ')' + ; + +fragment LBrace + : '{' + ; + +fragment RBrace + : '}' + ; + +fragment LBrack + : '[' + ; + +fragment RBrack + : ']' + ; + +fragment RArrow + : '->' + ; + +fragment Lt + : '<' + ; + +fragment Gt + : '>' + ; + +fragment Equal + : '=' + ; + +fragment Question + : '?' + ; + +fragment Star + : '*' + ; + +fragment Plus + : '+' + ; + +fragment PlusAssign + : '+=' + ; + +fragment Underscore + : '_' + ; + +fragment Pipe + : '|' + ; + +fragment Dollar + : '$' + ; + +fragment Comma + : ',' + ; + +fragment Semi + : ';' + ; + +fragment Dot + : '.' + ; + +fragment Range + : '..' + ; + +fragment At + : '@' + ; + +fragment Pound + : '#' + ; + +fragment Tilde + : '~' + ; diff --git a/linter/src/main/java/com/databricks/labs/remorph/linter/LexerAdaptor.java b/linter/src/main/java/com/databricks/labs/remorph/linter/LexerAdaptor.java new file mode 100644 index 0000000000..1ba9b54cb4 --- /dev/null +++ b/linter/src/main/java/com/databricks/labs/remorph/linter/LexerAdaptor.java @@ -0,0 +1,151 @@ +package com.databricks.labs.remorph.linter; +/* + [The "BSD licence"] + Copyright (c) 2005-2007 Terence Parr + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
+ IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +import org.antlr.v4.runtime.CharStream; +import org.antlr.v4.runtime.Lexer; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.misc.Interval; + +public abstract class LexerAdaptor extends Lexer { + + /** + * Generic type for OPTIONS, TOKENS and CHANNELS + */ + private static final int PREQUEL_CONSTRUCT = -10; + private static final int OPTIONS_CONSTRUCT = -11; + + public LexerAdaptor(CharStream input) { + super(input); + } + + /** + * Track whether we are inside of a rule and whether it is lexical parser. _currentRuleType==Token.INVALID_TYPE + * means that we are outside of a rule. At the first sign of a rule name reference and _currentRuleType==invalid, we + * can assume that we are starting a parser rule. Similarly, seeing a token reference when not already in rule means + * starting a token rule. The terminating ';' of a rule, flips this back to invalid type. + * + * This is not perfect logic but works. For example, "grammar T;" means that we start and stop a lexical rule for + * the "T;". Dangerous but works. + * + * The whole point of this state information is to distinguish between [..arg actions..] and [charsets]. Char sets + * can only occur in lexical rules and arg actions cannot occur. + */ + private int _currentRuleType = Token.INVALID_TYPE; + + private boolean insideOptionsBlock = false; + + public int getCurrentRuleType() { + return _currentRuleType; + } + + public void setCurrentRuleType(int ruleType) { + this._currentRuleType = ruleType; + } + + protected void handleBeginArgument() { + if (inLexerRule()) { + pushMode(ANTLRv4Lexer.LexerCharSet); + more(); + } else { + pushMode(ANTLRv4Lexer.Argument); + } + } + + protected void handleEndArgument() { + popMode(); + if (_modeStack.size() > 0) { + setType(ANTLRv4Lexer.ARGUMENT_CONTENT); + } + } + + protected void handleEndAction() { + int oldMode = _mode; + int newMode = popMode(); + boolean isActionWithinAction = _modeStack.size() > 0 + && newMode == ANTLRv4Lexer.TargetLanguageAction + && oldMode == newMode; + + if (isActionWithinAction) { + setType(ANTLRv4Lexer.ACTION_CONTENT); + } + } + + @Override + public Token emit() { + if ((_type == ANTLRv4Lexer.OPTIONS || _type == ANTLRv4Lexer.TOKENS || _type == ANTLRv4Lexer.CHANNELS) + && getCurrentRuleType() == Token.INVALID_TYPE) { // enter prequel construct ending with an RBRACE + setCurrentRuleType(PREQUEL_CONSTRUCT); + } else if (_type == ANTLRv4Lexer.OPTIONS && getCurrentRuleType() == ANTLRv4Lexer.TOKEN_REF) + { + setCurrentRuleType(OPTIONS_CONSTRUCT); + } else if (_type == ANTLRv4Lexer.RBRACE && getCurrentRuleType() == PREQUEL_CONSTRUCT) { // exit prequel construct + setCurrentRuleType(Token.INVALID_TYPE); + } else if (_type == ANTLRv4Lexer.RBRACE && getCurrentRuleType() == OPTIONS_CONSTRUCT) + { // exit options + setCurrentRuleType(ANTLRv4Lexer.TOKEN_REF); + } else if (_type == ANTLRv4Lexer.AT && getCurrentRuleType() == Token.INVALID_TYPE) { // enter action + setCurrentRuleType(ANTLRv4Lexer.AT); + } else if (_type == ANTLRv4Lexer.SEMI && getCurrentRuleType() == 
OPTIONS_CONSTRUCT) + { // ';' in options { .... }. Don't change anything. + } else if (_type == ANTLRv4Lexer.END_ACTION && getCurrentRuleType() == ANTLRv4Lexer.AT) { // exit action + setCurrentRuleType(Token.INVALID_TYPE); + } else if (_type == ANTLRv4Lexer.ID) { + String firstChar = _input.getText(Interval.of(_tokenStartCharIndex, _tokenStartCharIndex)); + if (Character.isUpperCase(firstChar.charAt(0))) { + _type = ANTLRv4Lexer.TOKEN_REF; + } else { + _type = ANTLRv4Lexer.RULE_REF; + } + + if (getCurrentRuleType() == Token.INVALID_TYPE) { // if outside of rule def + setCurrentRuleType(_type); // set to inside lexer or parser rule + } + } else if (_type == ANTLRv4Lexer.SEMI) { // exit rule def + setCurrentRuleType(Token.INVALID_TYPE); + } + + return super.emit(); + } + + private boolean inLexerRule() { + return getCurrentRuleType() == ANTLRv4Lexer.TOKEN_REF; + } + + @SuppressWarnings("unused") + private boolean inParserRule() { // not used, but added for clarity + return getCurrentRuleType() == ANTLRv4Lexer.RULE_REF; + } + + @Override + public void reset() { + setCurrentRuleType(Token.INVALID_TYPE); + insideOptionsBlock = false; + super.reset(); + } +} diff --git a/linter/src/main/scala/com/databricks/labs/remorph/linter/GrammarSource.scala b/linter/src/main/scala/com/databricks/labs/remorph/linter/GrammarSource.scala new file mode 100644 index 0000000000..b2e2995cbd --- /dev/null +++ b/linter/src/main/scala/com/databricks/labs/remorph/linter/GrammarSource.scala @@ -0,0 +1,26 @@ +package com.databricks.labs.remorph.linter + +import java.io.File +import java.nio.file.{Files, Path} +import scala.collection.JavaConverters._ + +case class GrammarLint(grammarName: String, inputFile: File) + +trait GrammarLintSource { + def listGrammars: Seq[GrammarLint] +} + +class NestedFiles(root: Path) extends GrammarLintSource { + def listGrammars: Seq[GrammarLint] = { + val files = + Files + .walk(root) + .iterator() + .asScala + .filter(f => Files.isRegularFile(f)) + .toSeq + + val grammarFiles = files.filter(_.getFileName.toString.endsWith(".g4")) + grammarFiles.map(p => GrammarLint(root.relativize(p).toString, p.toFile)) + } +} diff --git a/linter/src/main/scala/com/databricks/labs/remorph/linter/Main.scala b/linter/src/main/scala/com/databricks/labs/remorph/linter/Main.scala new file mode 100644 index 0000000000..25d33de3b2 --- /dev/null +++ b/linter/src/main/scala/com/databricks/labs/remorph/linter/Main.scala @@ -0,0 +1,116 @@ +package com.databricks.labs.remorph.linter + +import mainargs.{ParserForMethods, TokensReader, arg, main} +import org.antlr.v4.runtime.tree.ParseTreeWalker +import org.antlr.v4.runtime.{CharStreams, CommonTokenStream} + +import java.time.Instant + +object Main { + + implicit object PathRead extends TokensReader.Simple[os.Path] { + def shortName: String = "path" + def read(strs: Seq[String]): Either[String, os.Path] = Right(os.Path(strs.head, os.pwd)) + } + + private def getCurrentCommitHash: Option[String] = { + val gitRevParse = os.proc("/usr/bin/git", "rev-parse", "--short", "HEAD").call(os.pwd) + if (gitRevParse.exitCode == 0) { + Some(gitRevParse.out.trim()) + } else { + None + } + } + + private def timeToEpochNanos(instant: Instant) = { + val epoch = Instant.ofEpochMilli(0) + java.time.Duration.between(epoch, instant).toNanos + } + + @main + def run( + @arg(short = 'i', doc = "Source path of g4 grammar files") + sourceDir: os.Path, + @arg(short = 'o', doc = "Report output path") + outputPath: os.Path, + @arg(short = 'c', doc = "Write errors to console") + toConsole: 
Boolean): Unit = { + try { + + val now = Instant.now + val project = "remorph-core" + val commitHash = getCurrentCommitHash + val grammarSource = new NestedFiles(sourceDir.toNIO) + + val outputFilePath = outputPath / s"${project}-lint-grammar-${timeToEpochNanos(now)}.jsonl" + os.makeDir.all(outputPath) + + var exitCode = 0 + + grammarSource.listGrammars.foreach { grammar => + val ruleTracker = new RuleTracker + val orphanedRule = new OrphanedRule(ruleTracker) + val inputStream = CharStreams.fromPath(grammar.inputFile.toPath) + val lexer = new ANTLRv4Lexer(inputStream) + val tokens = new CommonTokenStream(lexer) + val parser = new ANTLRv4Parser(tokens) + val tree = parser.grammarSpec() + val walker = new ParseTreeWalker() + walker.walk(orphanedRule, tree) + + val header = ReportEntryHeader( + project = project, + commit_hash = commitHash, + version = "latest", + timestamp = now.toString, + file = os.Path(grammar.inputFile).relativeTo(sourceDir).toString) + val summary = ruleTracker.reconcileRules() + val reportEntryJson = ReportEntry(header, summary).asJson + os.write.append(outputFilePath, ujson.write(reportEntryJson, indent = -1) + "\n") + + if (summary.hasIssues) { + exitCode = 1 + } + + // If a local build, and there were problems, then we may wish to print the report + // to the console so developers see it + // scalastyle:off + if (toConsole) { + val sourceLines = os.read(os.Path(grammar.inputFile)).linesIterator.toList + + if (summary.hasIssues) { + println("\nIssues found in grammar: " + grammar.inputFile) + if (summary.orphanedRuleDef.nonEmpty) { + println("Orphaned rules (rules defined but never used):") + summary.orphanedRuleDef.foreach { rule => + println(s" ${rule.ruleName} defined at line ${rule.lineNo}:") + sourceLines.slice(rule.lineNo - 1, rule.lineNo + 1).foreach { line => + println(" " + line) + } + println() + } + } + if (summary.undefinedRules.nonEmpty) { + println("Undefined rules (rules referenced but never defined):") + summary.undefinedRules.foreach { rule => + println(s"\n ${rule.ruleName} used at line ${rule.lineNo}, col ${rule.charStart}") + println(" " + sourceLines(rule.lineNo - 1)) + println(" " * (4 + rule.charStart) + "^" * (rule.charEnd - rule.charStart)) + } + } + } + } + // scalastyle:on + } + if (exitCode != 0) { + sys.exit(exitCode) + } + } catch { + case e: Exception => + e.printStackTrace() + sys.exit(1) + } + } + + def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args) +} diff --git a/linter/src/main/scala/com/databricks/labs/remorph/linter/OrphanedRule.scala b/linter/src/main/scala/com/databricks/labs/remorph/linter/OrphanedRule.scala new file mode 100644 index 0000000000..0b7e690524 --- /dev/null +++ b/linter/src/main/scala/com/databricks/labs/remorph/linter/OrphanedRule.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.remorph.linter + +import org.antlr.v4.runtime.tree.Trees +import org.antlr.v4.runtime.ParserRuleContext + +import scala.jdk.CollectionConverters._ + +class OrphanedRule(ruleTracker: RuleTracker) extends ANTLRv4ParserBaseListener { + + /** + * Checks if a rule or any of its children contains EOF as a terminal node. Rules ending in EOF are entry point rules + * called externally and may not be referenced in the grammar, only defined. They are not reported as orphaned. 
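In other words, only rules that are neither referenced anywhere nor marked external end up reported as orphans, while references that never resolve to a definition are reported as undefined. Below is a minimal usage sketch of that classification, assuming the RuleTracker, RuleDefinition and RuleReference types introduced later in this diff; the grammar in the comments and the line/column numbers are hypothetical.

```scala
import com.databricks.labs.remorph.linter._

// Hypothetical grammar for illustration:
//   startRule  : statement EOF ;   // entry point, ends in EOF
//   statement  : expression ;      // references `expression`, which is never defined
//   unusedRule : 'x' ;             // defined but never referenced
object RuleTrackerSketch extends App {
  val tracker = new RuleTracker

  // Definitions recorded by enterParserRuleSpec (isExternal = containsEOF(ctx)).
  tracker.addRuleDef(RuleDefinition(lineNo = 1, ruleName = "startRule", isExternal = true))
  tracker.addRuleDef(RuleDefinition(lineNo = 2, ruleName = "statement"))
  tracker.addRuleDef(RuleDefinition(lineNo = 3, ruleName = "unusedRule"))

  // References recorded by enterRuleref.
  tracker.addRuleRef(RuleReference(lineNo = 1, charStart = 14, charEnd = 23, ruleName = "statement"))
  tracker.addRuleRef(RuleReference(lineNo = 2, charStart = 14, charEnd = 24, ruleName = "expression"))

  val summary = tracker.reconcileRules()
  // startRule is external (ends in EOF) and statement is referenced, so neither is an orphan.
  assert(summary.orphanedRuleDef.map(_.ruleName) == List("unusedRule"))
  assert(summary.undefinedRules.map(_.ruleName) == List("expression"))
  assert(summary.hasIssues)
}
```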
+ * + * @param ctx + * the parser context to search within + * @return + */ + private def containsEOF(ctx: ParserRuleContext): Boolean = { + val x = Trees.findAllNodes(ctx, ANTLRv4Parser.TOKEN_REF, true).asScala + x.foreach { node => + if (node.getText == "EOF") { + return true + } + } + false + } + + /** + * Records that a rule has been defined in the parser and whether it contains EOF + */ + override def enterParserRuleSpec(ctx: ANTLRv4Parser.ParserRuleSpecContext): Unit = { + val ruleSymbol = ctx.RULE_REF().getSymbol + ruleTracker.addRuleDef(RuleDefinition(ruleSymbol.getLine, ruleSymbol.getText, containsEOF(ctx))) + } + + /** + * Records that a rule has been referenced in the parser + */ + override def enterRuleref(ctx: ANTLRv4Parser.RulerefContext): Unit = { + val ruleReference = + new RuleReference( + ctx.start.getLine, + ctx.start.getCharPositionInLine, + ctx.stop.getCharPositionInLine + ctx.getText.length, + ctx.getText) + ruleTracker.addRuleRef(ruleReference) + } +} diff --git a/linter/src/main/scala/com/databricks/labs/remorph/linter/ReportEntry.scala b/linter/src/main/scala/com/databricks/labs/remorph/linter/ReportEntry.scala new file mode 100644 index 0000000000..18e4296633 --- /dev/null +++ b/linter/src/main/scala/com/databricks/labs/remorph/linter/ReportEntry.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.remorph.linter + +case class ReportEntryHeader( + project: String, + commit_hash: Option[String], + version: String, + timestamp: String, + file: String) + +case class ReportEntry(header: ReportEntryHeader, report: OrphanedRuleSummary) { + + def asJson: ujson.Value.Value = { + ujson.Obj( + "project" -> ujson.Str(header.project), + "commit_hash" -> header.commit_hash.map(ujson.Str).getOrElse(ujson.Null), + "version" -> ujson.Str(header.version), + "timestamp" -> ujson.Str(header.timestamp), + "file" -> ujson.Str(header.file), + "orphans" -> report.toJSON) + } +} diff --git a/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleDefinition.scala b/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleDefinition.scala new file mode 100644 index 0000000000..2d0469e2c6 --- /dev/null +++ b/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleDefinition.scala @@ -0,0 +1,3 @@ +package com.databricks.labs.remorph.linter + +case class RuleDefinition(lineNo: Int, ruleName: String, isExternal: Boolean = false) diff --git a/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleReference.scala b/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleReference.scala new file mode 100644 index 0000000000..9c7424be06 --- /dev/null +++ b/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleReference.scala @@ -0,0 +1,3 @@ +package com.databricks.labs.remorph.linter + +case class RuleReference(lineNo: Int, charStart: Int, charEnd: Int, ruleName: String) {} diff --git a/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleTracker.scala b/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleTracker.scala new file mode 100644 index 0000000000..c36eb1ef54 --- /dev/null +++ b/linter/src/main/scala/com/databricks/labs/remorph/linter/RuleTracker.scala @@ -0,0 +1,110 @@ +package com.databricks.labs.remorph.linter + +import ujson._ + +class RuleTracker { + + private[this] var ruleDefMap: Map[String, RuleDefinition] = Map() + private[this] var ruleRefMap: Map[String, List[RuleReference]] = Map() + private[this] var orphanedRuleDefs: List[RuleDefinition] = List() + private[this] var undefinedRules: List[RuleReference] = List() + + // 
Definition handling + def addRuleDef(rule: RuleDefinition): Unit = { + ruleDefMap += (rule.ruleName -> rule) + } + + def getRuleDef(ruleName: String): RuleDefinition = { + ruleDefMap(ruleName) + } + + def getRuleMap: Map[String, RuleDefinition] = { + ruleDefMap + } + + def getRuleDefCount: Int = { + ruleDefMap.size + } + + def getRuleDefNames: List[String] = { + ruleDefMap.keys.toList + } + + def getRuleDefinitions: List[RuleDefinition] = { + ruleDefMap.values.toList + } + + def getRuleDefinition(ruleName: String): RuleDefinition = { + ruleDefMap(ruleName) + } + + /** + * How many times the rule definition is referenced in the grammar + * @param ruleName + * the name of the rule + * @return + */ + def getRuleDefRefCount(ruleName: String): Int = { + ruleRefMap.get(ruleName).map(_.size).getOrElse(0) + } + + def getRuleDefLineNo(ruleName: String): Int = { + ruleDefMap.get(ruleName).map(_.lineNo).getOrElse(-1) + } + + // Reference handling + + def addRuleRef(ruleRef: RuleReference): Unit = { + val name = ruleRef.ruleName + ruleRefMap += (name -> (ruleRef :: ruleRefMap.getOrElse(name, Nil))) + } + + def getRuleRefs(ruleName: String): List[RuleReference] = { + ruleRefMap(ruleName) + } + + def getRuleRefCount(ruleName: String): Int = { + ruleRefMap.get(ruleName).map(_.size).getOrElse(0) + } + + def getRuleRefsByCondition(condition: RuleReference => Boolean): List[RuleReference] = { + ruleRefMap.values.flatten.filter(condition).toList + } + + // Reconciliation + // This is where we reconcile the rule definitions with the references to discover + // undefined and unreferenced rules + + def reconcileRules(): OrphanedRuleSummary = { + orphanedRuleDefs = ruleDefMap.values.filterNot(rule => ruleRefMap.contains(rule.ruleName) || rule.isExternal).toList + undefinedRules = ruleRefMap.values.flatten.filterNot(ref => ruleDefMap.contains(ref.ruleName)).toList + OrphanedRuleSummary(orphanedRuleDefs, undefinedRules) + } + + def getOrphanedRuleDefs: List[RuleDefinition] = orphanedRuleDefs + def getUndefinedRules: List[RuleReference] = undefinedRules +} + +case class OrphanedRuleSummary(orphanedRuleDef: List[RuleDefinition], undefinedRules: List[RuleReference]) { + def toJSON: Obj = { + val orphanedRuleDefJson = orphanedRuleDef.map { rule => + Obj("lineNo" -> rule.lineNo, "ruleName" -> rule.ruleName) + } + + val undefinedRulesJson = undefinedRules.map { rule => + Obj( + "lineNo" -> rule.lineNo, + "charStart" -> rule.charStart, + "charEnd" -> rule.charEnd, + "ruleName" -> rule.ruleName) + } + + Obj( + "orphanCount" -> orphanedRuleDef.size, + "undefinedRuleCount" -> undefinedRules.size, + "orphanedRuleDef" -> orphanedRuleDefJson, + "undefinedRules" -> undefinedRulesJson) + } + + def hasIssues: Boolean = orphanedRuleDef.nonEmpty || undefinedRules.nonEmpty +} diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000000..07c6a2f53e --- /dev/null +++ b/pom.xml @@ -0,0 +1,317 @@ + + + 4.0.0 + com.databricks.labs + remorph + 0.2.0-SNAPSHOT + pom + Databricks Labs Remorph + Remorph stands as a comprehensive toolkit meticulously crafted to + facilitate seamless migrations to Databricks. This suite of tools is dedicated + to simplifying and optimizing the entire migration process, offering two distinctive + functionalities – Transpile and Reconcile. Whether you are navigating code translation + or resolving potential conflicts, Remorph ensures a smooth journey for any migration + project. 
With Remorph as your trusted ally, the migration experience becomes not only + efficient but also well-managed, setting the stage for a successful transition to + the Databricks platform + https://github.com/databrickslabs/remorph + + + Databricks License + https://github.com/databrickslabs/remorph/blob/main/LICENSE + + + + linter + core + py + + + scm:git:https://github.com/databrickslabs/remorph.git + scm:git:https://github.com/databrickslabs/remorph.git + v${project.version} + https://github.com/databrickslabs/remorph/tree/v${project.version} + + + GitHub Issues + https://github.com/databrickslabs/remorph/issues + + + GitHub Actions + https://github.com/databrickslabs/remorph/blob/main/.github/workflows/push.yml + + + 1.8 + 1.8 + 1.8 + UTF-8 + 2.12 + + 2.12.20 + 2.0.6 + 3.3.0 + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.7.1 + + false + + jar-with-dependencies + + + + com.databricks.labs.remorph.Main + + + + + + make-assembly + + single + + package + + + + + net.alchim31.maven + scala-maven-plugin + 4.9.2 + + ${scala.version} + + -unchecked + -deprecation + -explaintypes + -feature + -language:existentials + -language:implicitConversions + -language:reflectiveCalls + -Xfatal-warnings + -Xlint + + + + + + compile + testCompile + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + 3.6.0 + + + add-source + + add-source + + generate-sources + + + src/main/scala + + + + + add-test-source + + add-test-source + + generate-sources + + + src/test/scala + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + + 8 + ${java.version} + ${java.version} + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.1.2 + + true + + + + org.scoverage + scoverage-maven-plugin + ${scoverage.plugin.version} + + ${scala.version} + true + .*Main + true + + + + com.diffplug.spotless + spotless-maven-plugin + 2.43.0 + + + + + + + + + + + + pom.xml + + + false + false + + + true + true + true + + + + + + + + + format + + + + org.antipathy + mvn-scalafmt_2.12 + 1.1.1640084764.9f463a9 + + .scalafmt.conf + + + + validate + + format + + + + + + + + + release + + + + org.apache.maven.plugins + maven-source-plugin + 3.3.1 + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.6.3 + + + + inheritDoc + m + Overrides: + false + + + + + + attach-javadocs + + jar + + + + + + org.apache.maven.plugins + maven-gpg-plugin + 3.0.1 + + + sign-artifacts + + sign + + verify + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + 1.6.13 + true + + ossrh + https://oss.sonatype.org/ + true + + + + + + + diff --git a/py/pom.xml b/py/pom.xml new file mode 100644 index 0000000000..65235c8542 --- /dev/null +++ b/py/pom.xml @@ -0,0 +1,16 @@ + + + 4.0.0 + + com.databricks.labs + remorph + 0.2.0-SNAPSHOT + ../pom.xml + + remorph-py + + ../src + ../tests + + diff --git a/pyproject.toml b/pyproject.toml index 28c4ac8682..7891b25898 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,148 +1,750 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - [project] name = "databricks-labs-remorph" -dynamic = ["version"] -description = '' +description = 'SQL code converter and data reconcilation tool for accelerating data onboarding to Databricks from EDW, CDW and other ETL sources.' 
+license-files = { paths = ["LICENSE", "NOTICE"] } +keywords = ["Databricks"] readme = "README.md" requires-python = ">=3.10" -license = "MIT" -keywords = [] -authors = [ -] +dynamic = ["version"] classifiers = [ - "Development Status :: 5 - Alpha", + "Development Status :: 3 - Alpha", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", ] -dependencies = [] +dependencies = [ + "databricks-sdk~=0.29.0", + "sqlglot==25.33.0", + "databricks-labs-blueprint[yaml]>=0.2.3", + "databricks-labs-lsql>=0.7.5,<0.14.0", # TODO: Limit the LSQL version until dependencies are correct. + "cryptography>=41.0.3", +] [project.urls] Documentation = "https://github.com/databrickslabs/remorph" Issues = "https://github.com/databrickslabs/remorph/issues" Source = "https://github.com/databrickslabs/remorph" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +sources = ["src"] +include = ["src"] + [tool.hatch.version] -path = "src/remorph/__about__.py" +path = "src/databricks/labs/remorph/__about__.py" [tool.hatch.envs.default] +python="3.10" + +# store virtual env as the child of this folder. Helps VSCode to run better +path = ".venv" + dependencies = [ + "pylint~=3.2.2", + "pylint-pytest==2.0.0a0", "coverage[toml]>=6.5", "pytest", + "pytest-cov>=4.0.0,<5.0.0", + "black>=23.1.0", + "ruff>=0.0.243", + "databricks-connect==15.1", + "types-pyYAML", + "types-pytz", + "databricks-labs-pylint~=0.4.0", + "mypy~=1.10.0", + "numpy==1.26.4", ] + +[project.entry-points.databricks] +reconcile = "databricks.labs.remorph.reconcile.execute:main" + [tool.hatch.envs.default.scripts] -test = "pytest {args:tests}" -test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] +test = "pytest --cov src --cov-report=xml tests/unit" +coverage = "pytest --cov src tests/unit --cov-report=html" +integration = "pytest --cov src tests/integration --durations 20" +fmt = ["black .", + "ruff check . 
--fix", + "mypy --disable-error-code 'annotation-unchecked' .", + "pylint --output-format=colorized -j 0 src tests"] +verify = ["black --check .", + "ruff check .", + "mypy --disable-error-code 'annotation-unchecked' .", + "pylint --output-format=colorized -j 0 src tests"] -[tool.hatch.envs.lint] -detached = true +[tool.hatch.envs.sqlglot-latest] +python="3.10" dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] -[tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/databricks tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", + "databricks-labs-remorph", + "sqlglot", ] + +[tool.pytest.ini_options] +addopts = "-s -p no:warnings -vv --cache-clear" +cache_dir = ".venv/pytest-cache" + [tool.black] target-version = ["py310"] line-length = 120 skip-string-normalization = true [tool.ruff] +cache-dir = ".venv/ruff-cache" target-version = "py310" line-length = 120 -select = [ - "A", - "ARG", - "B", - "C", - "DTZ", - "E", - "EM", - "F", - "FBT", - "I", - "ICN", - "ISC", - "N", - "PLC", - "PLE", - "PLR", - "PLW", - "Q", - "RUF", - "S", - "T", - "TID", - "UP", - "W", - "YTT", -] -ignore = [ + +lint.ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` "FBT003", - # Ignore checks for possible passwords - "S105", "S106", "S107", + # Ignore checks for possible passwords and SQL statement construction + "S105", "S106", "S107", "S603", "S608", + # Allow print statements + "T201", + # Allow asserts + "S101", + # Allow standard random generators + "S311", # Ignore complexity "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + # Ignore Exception must not use a string literal, assign to variable first + "EM101", + "PLR2004", + "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` ] -unfixable = [ - # Don't touch unused imports - "F401", +extend-exclude = [ + "notebooks/*.py" ] -[tool.ruff.isort] -known-first-party = ["databricks"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.per-file-ignores] -# Tests can use magic values, assertions, and relative imports -"tests/**/*" = ["PLR2004", "S101", "TID252"] +[tool.ruff.lint.per-file-ignores] + +"tests/**/*" = [ + "PLR2004", "S101", "TID252", # tests can use magic values, assertions, and relative imports + "ARG001" # tests may not use the provided fixtures +] [tool.coverage.run] -source_pkgs = ["remorph", "tests"] branch = true parallel = true -omit = [ - "src/remorph/__about__.py", -] - -[tool.coverage.paths] -remorph = ["src/remorph", "*/remorph/src/remorph"] -tests = ["tests", "*/remorph/tests"] [tool.coverage.report] +omit = ["src/databricks/labs/remorph/coverage/*", + "src/databricks/labs/remorph/helpers/execution_time.py", + "__about__.py"] exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", + "def main()", +] + +[tool.pylint.main] +# PyLint configuration is adapted from Google Python Style Guide with modifications. +# Sources https://google.github.io/styleguide/pylintrc +# License: https://github.com/google/styleguide/blob/gh-pages/LICENSE + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint in +# a server-like mode. +# clear-cache-post-run = + +# Always return a 0 (non-error) status code, even if lint errors are found. 
This +# is primarily useful in continuous integration scripts. +# exit-zero = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +# extension-pkg-allow-list = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +# extension-pkg-whitelist = + +# Specify a score threshold under which the program will exit with error. +fail-under = 10.0 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +# from-stdin = + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, it +# can't be used as an escape character. +# ignore-paths = + +# Files or directories matching the regular expression patterns are skipped. The +# regex matches against base names, not paths. The default value ignores Emacs +# file locks +ignore-patterns = ["^\\.#"] + +# List of module names for which member attributes should not be checked (useful +# for modules/projects where namespaces are manipulated during runtime and thus +# existing member attributes cannot be deduced by static analysis). It supports +# qualified module names, as well as Unix pattern matching. +# ignored-modules = + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +# init-hook = + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs = 0 + +# Control the amount of potential inferred values when inferring a single object. +# This can help the performance when dealing with large functions or complex, +# nested conditions. +limit-inference-results = 100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins = [ + "pylint_pytest", + "pylint.extensions.bad_builtin", + "pylint.extensions.broad_try_clause", + "pylint.extensions.check_elif", + "pylint.extensions.code_style", + "pylint.extensions.confusing_elif", + "pylint.extensions.comparison_placement", + "pylint.extensions.consider_refactoring_into_while_condition", + "pylint.extensions.dict_init_mutate", + "pylint.extensions.docparams", + "pylint.extensions.dunder", + "pylint.extensions.for_any_all", + "pylint.extensions.mccabe", + "pylint.extensions.overlapping_exceptions", + "pylint.extensions.private_import", + "pylint.extensions.redefined_variable_type", + "pylint.extensions.set_membership", + "pylint.extensions.typing", ] + +# Pickle collected data for later comparisons. +persistent = true + +# Minimum Python version to use for version dependent checks. Will default to the +# version used to run pylint. +py-version = "3.10" + +# Discover python modules and packages in the file system subtree. +# recursive = + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. 
+# source-roots = + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode = true + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +# unsafe-load-any-extension = + +[tool.pylint.basic] +# Naming style matching correct argument names. +argument-naming-style = "snake_case" + +# Regular expression matching correct argument names. Overrides argument-naming- +# style. If left empty, argument names will be checked with the set naming style. +argument-rgx = "[a-z_][a-z0-9_]{2,30}$" + +# Naming style matching correct attribute names. +attr-naming-style = "snake_case" + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +attr-rgx = "[a-z_][a-z0-9_]{2,}$" + +# Bad variable names which should always be refused, separated by a comma. +bad-names = ["foo", "bar", "baz", "toto", "tutu", "tata"] + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +# bad-names-rgxs = + +# Naming style matching correct class attribute names. +class-attribute-naming-style = "any" + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +class-attribute-rgx = "([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$" + +# Naming style matching correct class constant names. +class-const-naming-style = "UPPER_CASE" + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +# class-const-rgx = + +# Naming style matching correct class names. +class-naming-style = "PascalCase" + +# Regular expression matching correct class names. Overrides class-naming-style. +# If left empty, class names will be checked with the set naming style. +class-rgx = "[A-Z_][a-zA-Z0-9]+$" + +# Naming style matching correct constant names. +const-naming-style = "UPPER_CASE" + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming style. +const-rgx = "(([A-Z_][A-Z0-9_]*)|(__.*__))$" + +# Minimum line length for functions/classes that require docstrings, shorter ones +# are exempt. +docstring-min-length = -1 + +# Naming style matching correct function names. +function-naming-style = "snake_case" + +# Regular expression matching correct function names. Overrides function-naming- +# style. If left empty, function names will be checked with the set naming style. +function-rgx = "[a-z_][a-z0-9_]{2,}$" + +# Good variable names which should always be accepted, separated by a comma. +good-names = [ + "f", # use for file handles + "i", "j", "k", # use for loops + "df", # use for pyspark.sql.DataFrame + "ex", "e", # use for exceptions + "fn", "cb", # use for callbacks + "_", # use for ignores + "a", # use for databricks.sdk.AccountClient + "w", "ws" # use for databricks.sdk.WorkspaceClient +] + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +# good-names-rgxs = + +# Include a hint for the correct naming format with invalid-name. 
+# include-naming-hint = + +# Naming style matching correct inline iteration names. +inlinevar-naming-style = "any" + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +inlinevar-rgx = "[A-Za-z_][A-Za-z0-9_]*$" + +# Naming style matching correct method names. +method-naming-style = "snake_case" + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +method-rgx = "[a-z_][a-z0-9_]{2,}$" + +# Naming style matching correct module names. +module-naming-style = "snake_case" + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +module-rgx = "(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$" + +# Colon-delimited sets of names that determine each other's naming style when the +# name regexes allow several styles. +# name-group = + +# Regular expression which should only match function or class names that do not +# require a docstring. +no-docstring-rgx = "__.*__" + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. These +# decorators are taken in consideration only for invalid-name. +property-classes = ["abc.abstractproperty"] + +# Regular expression matching correct type alias names. If left empty, type alias +# names will be checked with the set naming style. +# typealias-rgx = + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +# typevar-rgx = + +# Naming style matching correct variable names. +variable-naming-style = "snake_case" + +# Regular expression matching correct variable names. Overrides variable-naming- +# style. If left empty, variable names will be checked with the set naming style. +variable-rgx = "[a-z_][a-z0-9_]{2,30}$" + +[tool.pylint.broad_try_clause] +# Maximum number of statements allowed in a try clause +max-try-statements = 7 + +[tool.pylint.classes] +# Warn about protected attribute access inside special methods +# check-protected-access-in-special-methods = + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods = ["__init__", "__new__", "setUp", "__post_init__"] + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected = ["_asdict", "_fields", "_replace", "_source", "_make"] + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg = ["cls"] + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg = ["mcs"] + +[tool.pylint.deprecated_builtins] +# List of builtins function names that should not be used, separated by a comma +bad-functions = ["map", "input"] + +[tool.pylint.design] +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +# exclude-too-few-public-methods = + +# List of qualified class names to ignore when counting class parents (see R0901) +# ignored-parents = + +# Maximum number of arguments for function / method. +max-args = 9 + +# Maximum number of attributes for a class (see R0902). +max-attributes = 13 + +# Maximum number of boolean expressions in an if statement (see R0916). 
+max-bool-expr = 5 + +# Maximum number of branch for function / method body. +max-branches = 20 + +# Maximum number of locals for function / method body. +max-locals = 19 + +# Maximum number of parents for a class (see R0901). +max-parents = 7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods = 20 + +# Maximum number of return / yield for function / method body. +max-returns = 11 + +# Maximum number of statements in function / method body. +max-statements = 50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods = 2 + +[tool.pylint.exceptions] +# Exceptions that will emit a warning when caught. +overgeneral-exceptions = ["builtins.Exception"] + +[tool.pylint.format] +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +# expected-line-ending-format = + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines = "^\\s*(# )??$" + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren = 4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string = " " + +# Maximum number of characters on a single line. +max-line-length = 100 + +# Maximum number of lines in a module. +max-module-lines = 2000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +# single-line-class-stmt = + +# Allow the body of an if to be on the same line as the test if there is no else. +# single-line-if-stmt = + +[tool.pylint.imports] +# List of modules that can be imported at any level, not just the top level one. +# allow-any-import-level = + +# Allow explicit reexports by alias from a package __init__. +# allow-reexport-from-package = + +# Allow wildcard imports from modules that define __all__. +# allow-wildcard-with-all = + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec"] + +# Output a graph (.gv or any supported image format) of external dependencies to +# the given file (report RP0402 must not be disabled). +# ext-import-graph = + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be disabled). +# import-graph = + +# Output a graph (.gv or any supported image format) of internal dependencies to +# the given file (report RP0402 must not be disabled). +# int-import-graph = + +# Force import order to recognize a module as part of the standard compatibility +# libraries. +# known-standard-library = + +# Force import order to recognize a module as part of a third party library. +known-third-party = ["enchant"] + +# Couples of modules and preferred modules, separated by a comma. +# preferred-modules = + +[tool.pylint.logging] +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style = "new" + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules = ["logging"] + +[tool.pylint."messages control"] +# Only show warnings with the listed confidence levels. Leave empty to show all. +# Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence = ["HIGH", "CONTROL_FLOW", "INFERENCE", "INFERENCE_FAILURE", "UNDEFINED"] + +# Disable the message, report, category or checker with the given id(s). 
You can +# either give multiple identifiers separated by comma (,) or put this option +# multiple times (only on the command line, not in the configuration file where +# it should appear only once). You can also use "--disable=all" to disable +# everything first and then re-enable specific checks. For example, if you want +# to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable = [ + "prefer-typing-namedtuple", + "attribute-defined-outside-init", + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "too-few-public-methods", + "line-too-long", + "trailing-whitespace", + "missing-final-newline", + "trailing-newlines", + "unnecessary-semicolon", + "mixed-line-endings", + "unexpected-line-ending-format", + "fixme", + "consider-using-assignment-expr", + "logging-fstring-interpolation", + "consider-using-any-or-all" +] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where it +# should appear only once). See also the "--disable" option for examples. +enable = ["useless-suppression", "use-symbolic-message-instead"] + +[tool.pylint.method_args] +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods = ["requests.api.delete", "requests.api.get", "requests.api.head", "requests.api.options", "requests.api.patch", "requests.api.post", "requests.api.put", "requests.api.request"] + +[tool.pylint.miscellaneous] +# List of note tags to take in consideration, separated by a comma. +notes = ["FIXME", "XXX", "TODO"] + +# Regular expression of note tags to take in consideration. +# notes-rgx = + +[tool.pylint.parameter_documentation] +# Whether to accept totally missing parameter documentation in the docstring of a +# function that has parameters. +accept-no-param-doc = true + +# Whether to accept totally missing raises documentation in the docstring of a +# function that raises an exception. +accept-no-raise-doc = true + +# Whether to accept totally missing return documentation in the docstring of a +# function that returns a statement. +accept-no-return-doc = true + +# Whether to accept totally missing yields documentation in the docstring of a +# generator. +accept-no-yields-doc = true + +# If the docstring type cannot be guessed the specified docstring type will be +# used. +default-docstring-type = "default" + +[tool.pylint.refactoring] +# Maximum number of nested blocks for function / method body +max-nested-blocks = 5 + +# Complete name of functions that never returns. When checking for inconsistent- +# return-statements if a never returning function is called then it will be +# considered as an explicit return statement and no message will be printed. +never-returning-functions = ["sys.exit", "argparse.parse_error"] + +[tool.pylint.reports] +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each category, +# as well as 'statement' which is the total number of statements analyzed. 
This +# score is used by the global evaluation report (RP0004). +evaluation = "max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))" + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +# msg-template = + +# Set the output format. Available formats are: text, parseable, colorized, json2 +# (improved json format), json (old json format) and msvs (visual studio). You +# can also give a reporter class, e.g. mypackage.mymodule.MyReporterClass. +# output-format = + +# Tells whether to display a full report or only the messages. +# reports = + +# Activate the evaluation score. +score = true + +[tool.pylint.similarities] +# Comments are removed from the similarity computation +ignore-comments = true + +# Docstrings are removed from the similarity computation +ignore-docstrings = true + +# Imports are removed from the similarity computation +ignore-imports = true + +# Signatures are removed from the similarity computation +ignore-signatures = true + +# Minimum lines number of a similarity. +min-similarity-lines = 6 + +[tool.pylint.spelling] +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions = 2 + +# Spelling dictionary name. No available dictionaries : You need to install both +# the python package and the system dependency for enchant to work. +# spelling-dict = + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,mypy:,pragma:,# noinspection" + +# List of comma separated words that should not be checked. +# spelling-ignore-words = + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file = ".pyenchant_pylint_custom_dict.txt" + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +# spelling-store-unknown-words = + +[tool.pylint.typecheck] +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators = ["contextlib.contextmanager"] + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members = "REQUEST,acl_users,aq_parent,argparse.Namespace" + +# Tells whether missing members accessed in mixin class should be ignored. A +# class is considered mixin if its name matches the mixin-class-rgx option. +# Tells whether to warn about missing members when the owner of the attribute is +# inferred to be None. +ignore-none = true + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference can +# return multiple potential results while evaluating a Python object, but some +# branches might not be evaluated, which results in partial inference. In that +# case, it might be useful to still emit no-member and other checks for the rest +# of the inferred objects. +ignore-on-opaque-inference = true + +# List of symbolic message names to ignore for Mixin members. 
+ignored-checks-for-mixins = ["no-member", "not-async-context-manager", "not-context-manager", "attribute-defined-outside-init"] + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes = ["SQLObject", "optparse.Values", "thread._local", "_thread._local"] + +# Show a hint with possible names when a member name was not found. The aspect of +# finding the hint is based on edit distance. +missing-member-hint = true + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance = 1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices = 1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx = ".*MixIn" + +# List of decorators that change the signature of a decorated function. +# signature-mutators = + +[tool.pylint.variables] +# List of additional names supposed to be defined in builtins. Remember that you +# should avoid defining new builtins when possible. +# additional-builtins = + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables = true + +# List of names allowed to shadow builtins +# allowed-redefined-builtins = + +# List of strings which can identify a callback function by name. A callback name +# must start or end with one of those strings. +callbacks = ["cb_", "_cb"] + +# A regular expression matching the name of dummy variables (i.e. expected to not +# be used). +dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_" + +# Argument names that match this expression will be ignored. +ignored-argument-names = "_.*|^ignored_|^unused_" + +# Tells whether we should check for unused import in __init__ files. +# init-import = + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules = ["six.moves", "past.builtins", "future.builtins", "builtins", "io"] diff --git a/scalastyle-config.xml b/scalastyle-config.xml new file mode 100644 index 0000000000..88c1177580 --- /dev/null +++ b/scalastyle-config.xml @@ -0,0 +1,441 @@ + + + + Scalastyle standard configuration + + + + + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, + TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + + + + + ^FunSuite[A-Za-z]*$ + + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + + ^println$ + + + + + + + spark(.sqlContext)?.sparkContext.hadoopConfiguration + + + + + + + @VisibleForTesting + + + + + + + Runtime\.getRuntime\.addShutdownHook + + + + + + + mutable\.SynchronizedBuffer + + + + + + + Class\.forName + + + + + + + Await\.result + + + + + + + Await\.ready + + + + + + + (\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\))) + + + + + + + + throw new \w+Error\( + + + + + + + + JavaConversions + + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + + + + org\.apache\.commons\.lang\. 
+ + Use Commons Lang 3 classes (package org.apache.commons.lang3.*) instead + of Commons Lang 2 (package org.apache.commons.lang.*) + + + + + + extractOpt + + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + is slower. + + + + + + COMMA + + + + + + + \)\{ + + + + + + + case[^\n>]*=>\s*\{ + + Omit braces in case clauses. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + + + 30 + + + + + + + 10 + + + + + + + 50 + + + + + + + + + + + + + + + -1,0,1,2,3 + + + + diff --git a/src/databricks/labs/remorph/__about__.py b/src/databricks/labs/remorph/__about__.py index f102a9cadf..ec84298734 100644 --- a/src/databricks/labs/remorph/__about__.py +++ b/src/databricks/labs/remorph/__about__.py @@ -1 +1,2 @@ -__version__ = "0.0.1" +# DO NOT MODIFY THIS FILE +__version__ = "0.9.0" diff --git a/src/databricks/labs/remorph/__init__.py b/src/databricks/labs/remorph/__init__.py index e69de29bb2..349653e0fb 100644 --- a/src/databricks/labs/remorph/__init__.py +++ b/src/databricks/labs/remorph/__init__.py @@ -0,0 +1,11 @@ +from databricks.sdk.core import with_user_agent_extra, with_product +from databricks.labs.blueprint.logger import install_logger +from databricks.labs.remorph.__about__ import __version__ + +install_logger() + +# Add remorph/ for projects depending on remorph as a library +with_user_agent_extra("remorph", __version__) + +# Add remorph/ for re-packaging of remorph, where product name is omitted +with_product("remorph", __version__) diff --git a/src/databricks/labs/remorph/cli.py b/src/databricks/labs/remorph/cli.py new file mode 100644 index 0000000000..7986c57e28 --- /dev/null +++ b/src/databricks/labs/remorph/cli.py @@ -0,0 +1,170 @@ +import json +import os + +from databricks.labs.blueprint.cli import App +from databricks.labs.blueprint.entrypoint import get_logger +from databricks.labs.remorph.config import SQLGLOT_DIALECTS, TranspileConfig +from databricks.labs.remorph.contexts.application import ApplicationContext +from databricks.labs.remorph.helpers.recon_config_utils import ReconConfigPrompts +from databricks.labs.remorph.reconcile.runner import ReconcileRunner +from databricks.labs.remorph.lineage import lineage_generator +from databricks.labs.remorph.transpiler.execute import transpile as do_transpile +from databricks.labs.remorph.reconcile.execute import RECONCILE_OPERATION_NAME, AGG_RECONCILE_OPERATION_NAME +from databricks.labs.remorph.jvmproxy import proxy_command + +from databricks.sdk import WorkspaceClient + +remorph = App(__file__) +logger = get_logger(__file__) + +DIALECTS = {name for name, dialect in SQLGLOT_DIALECTS.items()} + + +def raise_validation_exception(msg: str) -> Exception: + raise ValueError(msg) + + +proxy_command(remorph, "debug-script") +proxy_command(remorph, "debug-me") +proxy_command(remorph, "debug-coverage") +proxy_command(remorph, "debug-estimate") +proxy_command(remorph, "debug-bundle") + + +@remorph.command +def transpile( + w: WorkspaceClient, + source_dialect: str, + input_source: str, + output_folder: str | None, + skip_validation: str, + catalog_name: str, + schema_name: str, + mode: str, +): + """Transpiles source dialect to databricks dialect""" + ctx = ApplicationContext(w) + logger.info(f"User: {ctx.current_user}") + default_config = ctx.transpile_config + if not default_config: + raise SystemExit("Installed transpile config not found. 
Please install Remorph transpile first.") + _override_workspace_client_config(ctx, default_config.sdk_config) + mode = mode if mode else "current" # not checking for default config as it will always be current + # TODO get rid of the sqlglot dependency + if source_dialect.lower() not in SQLGLOT_DIALECTS: + raise_validation_exception( + f"Error: Invalid value for '--source-dialect': '{source_dialect}' is not one of {DIALECTS}." + ) + if not input_source or not os.path.exists(input_source): + raise_validation_exception(f"Error: Invalid value for '--input-source': Path '{input_source}' does not exist.") + if not output_folder and default_config.output_folder: + output_folder = str(default_config.output_folder) + if skip_validation.lower() not in {"true", "false"}: + raise_validation_exception( + f"Error: Invalid value for '--skip-validation': '{skip_validation}' is not one of 'true', 'false'." + ) + if mode.lower() not in {"current", "experimental"}: + raise_validation_exception( + f"Error: Invalid value for '--mode': '{mode}' " f"is not one of 'current', 'experimental'." + ) + + sdk_config = default_config.sdk_config if default_config.sdk_config else None + catalog_name = catalog_name if catalog_name else default_config.catalog_name + schema_name = schema_name if schema_name else default_config.schema_name + + config = TranspileConfig( + source_dialect=source_dialect.lower(), + input_source=input_source, + output_folder=output_folder, + skip_validation=skip_validation.lower() == "true", # convert to bool + catalog_name=catalog_name, + schema_name=schema_name, + mode=mode, + sdk_config=sdk_config, + ) + + status = do_transpile(ctx.workspace_client, config) + + print(json.dumps(status)) + + +def _override_workspace_client_config(ctx: ApplicationContext, overrides: dict[str, str] | None): + """ + Override the Workspace client's SDK config with the user provided SDK config. + Users can provide the cluster_id and warehouse_id during the installation. + This will update the default config object in-place. + """ + if not overrides: + return + + warehouse_id = overrides.get("warehouse_id") + if warehouse_id: + ctx.connect_config.warehouse_id = warehouse_id + + cluster_id = overrides.get("cluster_id") + if cluster_id: + ctx.connect_config.cluster_id = cluster_id + + +@remorph.command +def reconcile(w: WorkspaceClient): + """[EXPERIMENTAL] Reconciles source to Databricks datasets""" + ctx = ApplicationContext(w) + logger.info(f"User: {ctx.current_user}") + recon_runner = ReconcileRunner( + ctx.workspace_client, + ctx.installation, + ctx.install_state, + ctx.prompts, + ) + recon_runner.run(operation_name=RECONCILE_OPERATION_NAME) + + +@remorph.command +def aggregates_reconcile(w: WorkspaceClient): + """[EXPERIMENTAL] Reconciles Aggregated source to Databricks datasets""" + ctx = ApplicationContext(w) + logger.info(f"User: {ctx.current_user}") + recon_runner = ReconcileRunner( + ctx.workspace_client, + ctx.installation, + ctx.install_state, + ctx.prompts, + ) + + recon_runner.run(operation_name=AGG_RECONCILE_OPERATION_NAME) + + +@remorph.command +def generate_lineage(w: WorkspaceClient, source_dialect: str, input_source: str, output_folder: str): + """[Experimental] Generates a lineage of source SQL files or folder""" + ctx = ApplicationContext(w) + logger.info(f"User: {ctx.current_user}") + if source_dialect.lower() not in SQLGLOT_DIALECTS: + raise_validation_exception( + f"Error: Invalid value for '--source-dialect': '{source_dialect}' is not one of {DIALECTS}." 
+ ) + if not input_source or not os.path.exists(input_source): + raise_validation_exception(f"Error: Invalid value for '--input-source': Path '{input_source}' does not exist.") + if not os.path.exists(output_folder) or output_folder in {None, ""}: + raise_validation_exception( + f"Error: Invalid value for '--output-folder': Path '{output_folder}' does not exist." + ) + + lineage_generator(source_dialect, input_source, output_folder) + + +@remorph.command +def configure_secrets(w: WorkspaceClient): + """Setup reconciliation connection profile details as Secrets on Databricks Workspace""" + recon_conf = ReconConfigPrompts(w) + + # Prompt for source + source = recon_conf.prompt_source() + + logger.info(f"Setting up Scope, Secrets for `{source}` reconciliation") + recon_conf.prompt_and_save_connection_details() + + +if __name__ == "__main__": + remorph() diff --git a/src/databricks/labs/remorph/config.py b/src/databricks/labs/remorph/config.py new file mode 100644 index 0000000000..03c8d590a3 --- /dev/null +++ b/src/databricks/labs/remorph/config.py @@ -0,0 +1,131 @@ +import logging +from dataclasses import dataclass + +from sqlglot.dialects.dialect import Dialect, Dialects, DialectType + +from databricks.labs.remorph.transpiler.transpile_status import ParserError +from databricks.labs.remorph.reconcile.recon_config import Table +from databricks.labs.remorph.transpiler.sqlglot.generator import databricks +from databricks.labs.remorph.transpiler.sqlglot.parsers import oracle, presto, snowflake, bigquery + +logger = logging.getLogger(__name__) + +SQLGLOT_DIALECTS: dict[str, DialectType] = { + "athena": Dialects.ATHENA, + "bigquery": bigquery.BigQuery, + "databricks": databricks.Databricks, + "mysql": Dialects.MYSQL, + "netezza": Dialects.POSTGRES, + "oracle": oracle.Oracle, + "postgresql": Dialects.POSTGRES, + "presto": presto.Presto, + "redshift": Dialects.REDSHIFT, + "snowflake": snowflake.Snowflake, + "sqlite": Dialects.SQLITE, + "teradata": Dialects.TERADATA, + "trino": Dialects.TRINO, + "tsql": Dialects.TSQL, + "vertica": Dialects.POSTGRES, +} + + +def get_dialect(engine: str) -> Dialect: + return Dialect.get_or_raise(SQLGLOT_DIALECTS.get(engine)) + + +def get_key_from_dialect(input_dialect: Dialect) -> str: + return [source_key for source_key, dialect in SQLGLOT_DIALECTS.items() if dialect == input_dialect][0] + + +@dataclass +class TranspileConfig: + __file__ = "config.yml" + __version__ = 1 + + source_dialect: str + input_source: str | None = None + output_folder: str | None = None + sdk_config: dict[str, str] | None = None + skip_validation: bool = False + catalog_name: str = "remorph" + schema_name: str = "transpiler" + mode: str = "current" + + def get_read_dialect(self): + return get_dialect(self.source_dialect) + + def get_write_dialect(self): + if self.mode == "experimental": + return get_dialect("experimental") + return get_dialect("databricks") + + +@dataclass +class TableRecon: + __file__ = "recon_config.yml" + __version__ = 1 + + source_schema: str + target_catalog: str + target_schema: str + tables: list[Table] + source_catalog: str | None = None + + def __post_init__(self): + self.source_schema = self.source_schema.lower() + self.target_schema = self.target_schema.lower() + self.target_catalog = self.target_catalog.lower() + self.source_catalog = self.source_catalog.lower() if self.source_catalog else self.source_catalog + + +@dataclass +class DatabaseConfig: + source_schema: str + target_catalog: str + target_schema: str + source_catalog: str | None = None + + +@dataclass 
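+# Result of transpiling one input source: the generated SQL statements plus any parser errors collected along the way.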
+class TranspilationResult: + transpiled_sql: list[str] + parse_error_list: list[ParserError] + + +@dataclass +class ValidationResult: + validated_sql: str + exception_msg: str | None + + +@dataclass +class ReconcileTablesConfig: + filter_type: str # all/include/exclude + tables_list: list[str] # [*, table1, table2] + + +@dataclass +class ReconcileMetadataConfig: + catalog: str = "remorph" + schema: str = "reconcile" + volume: str = "reconcile_volume" + + +@dataclass +class ReconcileConfig: + __file__ = "reconcile.yml" + __version__ = 1 + + data_source: str + report_type: str + secret_scope: str + database_config: DatabaseConfig + metadata_config: ReconcileMetadataConfig + job_id: str | None = None + tables: ReconcileTablesConfig | None = None + + +@dataclass +class RemorphConfigs: + transpile: TranspileConfig | None = None + reconcile: ReconcileConfig | None = None diff --git a/src/databricks/labs/remorph/contexts/__init__.py b/src/databricks/labs/remorph/contexts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/contexts/application.py b/src/databricks/labs/remorph/contexts/application.py new file mode 100644 index 0000000000..e6ab58df3b --- /dev/null +++ b/src/databricks/labs/remorph/contexts/application.py @@ -0,0 +1,133 @@ +import logging +from functools import cached_property + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.upgrades import Upgrades +from databricks.labs.blueprint.tui import Prompts +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.lsql.backends import SqlBackend +from databricks.sdk import WorkspaceClient +from databricks.sdk.config import Config +from databricks.sdk.errors import NotFound +from databricks.sdk.service.iam import User + +from databricks.labs.remorph.config import TranspileConfig, ReconcileConfig, RemorphConfigs +from databricks.labs.remorph.deployment.configurator import ResourceConfigurator +from databricks.labs.remorph.deployment.dashboard import DashboardDeployment +from databricks.labs.remorph.deployment.installation import WorkspaceInstallation +from databricks.labs.remorph.deployment.recon import TableDeployment, JobDeployment, ReconDeployment +from databricks.labs.remorph.helpers import db_sql +from databricks.labs.remorph.helpers.metastore import CatalogOperations + +logger = logging.getLogger(__name__) + + +class ApplicationContext: + def __init__(self, ws: WorkspaceClient): + self._ws = ws + + def replace(self, **kwargs): + """Replace cached properties for unit testing purposes.""" + for key, value in kwargs.items(): + self.__dict__[key] = value + return self + + @cached_property + def workspace_client(self) -> WorkspaceClient: + return self._ws + + @cached_property + def current_user(self) -> User: + return self.workspace_client.current_user.me() + + @cached_property + def product_info(self) -> ProductInfo: + return ProductInfo.from_class(RemorphConfigs) + + @cached_property + def installation(self) -> Installation: + return Installation.assume_user_home(self.workspace_client, self.product_info.product_name()) + + @cached_property + def transpile_config(self) -> TranspileConfig | None: + try: + return self.installation.load(TranspileConfig) + except NotFound as err: + logger.debug(f"Couldn't find existing `transpile` installation: {err}") + return None + + @cached_property + def recon_config(self) -> ReconcileConfig | None: + try: + return 
self.installation.load(ReconcileConfig) + except NotFound as err: + logger.debug(f"Couldn't find existing `reconcile` installation: {err}") + return None + + @cached_property + def remorph_config(self) -> RemorphConfigs: + return RemorphConfigs(transpile=self.transpile_config, reconcile=self.recon_config) + + @cached_property + def connect_config(self) -> Config: + return self.workspace_client.config + + @cached_property + def install_state(self) -> InstallState: + return InstallState.from_installation(self.installation) + + @cached_property + def sql_backend(self) -> SqlBackend: + return db_sql.get_sql_backend(self.workspace_client) + + @cached_property + def catalog_operations(self) -> CatalogOperations: + return CatalogOperations(self.workspace_client) + + @cached_property + def prompts(self) -> Prompts: + return Prompts() + + @cached_property + def resource_configurator(self) -> ResourceConfigurator: + return ResourceConfigurator(self.workspace_client, self.prompts, self.catalog_operations) + + @cached_property + def table_deployment(self) -> TableDeployment: + return TableDeployment(self.sql_backend) + + @cached_property + def job_deployment(self) -> JobDeployment: + return JobDeployment(self.workspace_client, self.installation, self.install_state, self.product_info) + + @cached_property + def dashboard_deployment(self) -> DashboardDeployment: + return DashboardDeployment(self.workspace_client, self.installation, self.install_state) + + @cached_property + def recon_deployment(self) -> ReconDeployment: + return ReconDeployment( + self.workspace_client, + self.installation, + self.install_state, + self.product_info, + self.table_deployment, + self.job_deployment, + self.dashboard_deployment, + ) + + @cached_property + def workspace_installation(self) -> WorkspaceInstallation: + return WorkspaceInstallation( + self.workspace_client, + self.prompts, + self.installation, + self.recon_deployment, + self.product_info, + self.upgrades, + ) + + @cached_property + def upgrades(self): + return Upgrades(self.product_info, self.installation) diff --git a/src/databricks/labs/remorph/coverage/__init__.py b/src/databricks/labs/remorph/coverage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/coverage/commons.py b/src/databricks/labs/remorph/coverage/commons.py new file mode 100644 index 0000000000..8b8043adfc --- /dev/null +++ b/src/databricks/labs/remorph/coverage/commons.py @@ -0,0 +1,223 @@ +# pylint: disable=all +import collections +import dataclasses +import json +import logging +import os +import subprocess +import time +from collections.abc import Generator +from datetime import datetime, timezone +from pathlib import Path +from typing import TextIO, List + +import sqlglot +from sqlglot.expressions import Expression +from sqlglot.dialects.dialect import Dialect +from sqlglot.dialects.databricks import Databricks +from sqlglot.errors import ErrorLevel + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ReportEntry: + project: str + commit_hash: str | None + version: str + timestamp: str + source_dialect: str + target_dialect: str + file: str + parsed: int = 0 # 1 for success, 0 for failure + statements: int = 0 # number of statements parsed + transpiled: int = 0 # 1 for success, 0 for failure + transpiled_statements: int = 0 # number of statements transpiled + failures: List[dict] = dataclasses.field(default_factory=lambda: []) + + +def sqlglot_run_coverage(dialect, subfolder): + input_dir = 
get_env_var("INPUT_DIR_PARENT", required=True) + output_dir = get_env_var("OUTPUT_DIR", required=True) + sqlglot_version = sqlglot.__version__ + SQLGLOT_COMMIT_HASH = "" # C0103 pylint + + if not input_dir: + raise ValueError("Environment variable `INPUT_DIR_PARENT` is required") + if not output_dir: + raise ValueError("Environment variable `OUTPUT_DIR` is required") + + collect_transpilation_stats( + "SQLGlot", + SQLGLOT_COMMIT_HASH, + sqlglot_version, + dialect, + Databricks, + Path(input_dir) / subfolder, + Path(output_dir), + ) + + +def local_report(output_dir: Path): + all = collections.defaultdict(list) + for file in output_dir.rglob("*.json"): + with file.open("r", encoding="utf8") as f: + for line in f: + raw = json.loads(line) + entry = ReportEntry(**raw) + all[(entry.project, entry.source_dialect)].append(entry) + for (project, dialect), entries in sorted(all.items()): + total = len(entries) + parsed = sum(entry.parsed for entry in entries) + transpiled = sum(entry.transpiled for entry in entries) + parse_ratio = parsed / total + transpile_ratio = transpiled / total + print( + f"{project} -> {dialect}: {parse_ratio:.2%} parsed ({parsed}/{total}), " + f"{transpile_ratio:.2%} transpiled ({transpiled}/{total})" + ) + + +def get_supported_sql_files(input_dir: Path) -> Generator[Path, None, None]: + yield from filter(lambda item: item.is_file() and item.suffix.lower() in [".sql", ".ddl"], input_dir.rglob("*")) + + +def write_json_line(file: TextIO, content: ReportEntry): + json.dump(dataclasses.asdict(content), file) + file.write("\n") + + +def get_env_var(env_var: str, *, required: bool = False) -> str | None: + """ + Get the value of an environment variable. + + :param env_var: The name of the environment variable to get the value of. + :param required: Indicates if the environment variable is required and raises a ValueError if it's not set. + :return: Returns the environment variable's value, or None if it's not set and not required. + """ + value = os.getenv(env_var) + if value is None and required: + message = f"Environment variable {env_var} is not set" + raise ValueError(message) + return value + + +def get_current_commit_hash() -> str | None: + try: + return ( + subprocess.check_output( + ["/usr/bin/git", "rev-parse", "--short", "HEAD"], + cwd=Path(__file__).resolve().parent, + ) + .decode("ascii") + .strip() + ) + except (subprocess.CalledProcessError, FileNotFoundError) as e: + logger.warning(f"Could not get the current commit hash. 
{e!s}") + return None + + +def get_current_time_utc() -> datetime: + return datetime.now(timezone.utc) + + +def parse_sql(sql: str, dialect: type[Dialect]) -> list[Expression]: + return [ + expression for expression in sqlglot.parse(sql, read=dialect, error_level=ErrorLevel.IMMEDIATE) if expression + ] + + +def generate_sql(expressions: list[Expression], dialect: type[Dialect]) -> list[str]: + generator_dialect = Dialect.get_or_raise(dialect) + return [generator_dialect.generate(expression, copy=False) for expression in expressions if expression] + + +def _ensure_valid_io_paths(input_dir: Path, result_dir: Path): + if not input_dir.exists() or not input_dir.is_dir(): + message = f"The input path {input_dir} doesn't exist or is not a directory" + raise NotADirectoryError(message) + + if not result_dir.exists(): + logger.info(f"Creating the output directory {result_dir}") + result_dir.mkdir(parents=True) + elif not result_dir.is_dir(): + message = f"The output path {result_dir} exists but is not a directory" + raise NotADirectoryError(message) + + +def _get_report_file_path( + project: str, + source_dialect: type[Dialect], + target_dialect: type[Dialect], + result_dir: Path, +) -> Path: + source_dialect_name = source_dialect.__name__ + target_dialect_name = target_dialect.__name__ + current_time_ns = time.time_ns() + return result_dir / f"{project}_{source_dialect_name}_{target_dialect_name}_{current_time_ns}.json".lower() + + +def _prepare_report_entry( + project: str, + commit_hash: str, + version: str, + source_dialect: type[Dialect], + target_dialect: type[Dialect], + file_path: str, + sql: str, +) -> ReportEntry: + report_entry = ReportEntry( + project=project, + commit_hash=commit_hash, + version=version, + timestamp=get_current_time_utc().isoformat(), + source_dialect=source_dialect.__name__, + target_dialect=target_dialect.__name__, + file=file_path, + ) + try: + expressions = parse_sql(sql, source_dialect) + report_entry.parsed = 1 + report_entry.statements = len(expressions) + except Exception as pe: + report_entry.failures.append({'error_code': type(pe).__name__, 'error_message': repr(pe)}) + return report_entry + + try: + generated_sqls = generate_sql(expressions, target_dialect) + report_entry.transpiled = 1 + report_entry.transpiled_statements = len([sql for sql in generated_sqls if sql.strip()]) + except Exception as te: + report_entry.failures.append({'error_code': type(te).__name__, 'error_message': repr(te)}) + + return report_entry + + +def collect_transpilation_stats( + project: str, + commit_hash: str, + version: str, + source_dialect: type[Dialect], + target_dialect: type[Dialect], + input_dir: Path, + result_dir: Path, +): + _ensure_valid_io_paths(input_dir, result_dir) + report_file_path = _get_report_file_path(project, source_dialect, target_dialect, result_dir) + + with report_file_path.open("w", encoding="utf8") as report_file: + for input_file in get_supported_sql_files(input_dir): + with input_file.open("r", encoding="utf-8-sig") as file: + sql = file.read() + + file_path = str(input_file.absolute().relative_to(input_dir.parent.absolute())) + report_entry = _prepare_report_entry( + project, + commit_hash, + version, + source_dialect, + target_dialect, + file_path, + sql, + ) + write_json_line(report_file, report_entry) diff --git a/src/databricks/labs/remorph/coverage/local_report.py b/src/databricks/labs/remorph/coverage/local_report.py new file mode 100644 index 0000000000..a2b83c581d --- /dev/null +++ b/src/databricks/labs/remorph/coverage/local_report.py @@ 
-0,0 +1,9 @@ +from pathlib import Path + +from databricks.labs.remorph.coverage import commons + +if __name__ == "__main__": + output_dir = commons.get_env_var("OUTPUT_DIR", required=True) + if not output_dir: + raise ValueError("Environment variable `OUTPUT_DIR` is required") + commons.local_report(Path(output_dir)) diff --git a/src/databricks/labs/remorph/coverage/remorph_snow_transpilation_coverage.py b/src/databricks/labs/remorph/coverage/remorph_snow_transpilation_coverage.py new file mode 100644 index 0000000000..ef1b56676a --- /dev/null +++ b/src/databricks/labs/remorph/coverage/remorph_snow_transpilation_coverage.py @@ -0,0 +1,29 @@ +from pathlib import Path + +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.remorph.coverage import commons +from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks +from databricks.labs.remorph.transpiler.sqlglot.parsers.snowflake import Snowflake + +if __name__ == "__main__": + input_dir = commons.get_env_var("INPUT_DIR_PARENT", required=True) + output_dir = commons.get_env_var("OUTPUT_DIR", required=True) + + REMORPH_COMMIT_HASH = commons.get_current_commit_hash() or "" # C0103 pylint + product_info = ProductInfo(__file__) + remorph_version = product_info.unreleased_version() + + if not input_dir: + raise ValueError("Environment variable `INPUT_DIR_PARENT` is required") + if not output_dir: + raise ValueError("Environment variable `OUTPUT_DIR` is required") + + commons.collect_transpilation_stats( + "Remorph", + REMORPH_COMMIT_HASH, + remorph_version, + Snowflake, + Databricks, + Path(input_dir) / 'snowflake', + Path(output_dir), + ) diff --git a/src/databricks/labs/remorph/coverage/sqlglot_snow_transpilation_coverage.py b/src/databricks/labs/remorph/coverage/sqlglot_snow_transpilation_coverage.py new file mode 100644 index 0000000000..978167d48f --- /dev/null +++ b/src/databricks/labs/remorph/coverage/sqlglot_snow_transpilation_coverage.py @@ -0,0 +1,5 @@ +from sqlglot.dialects.snowflake import Snowflake +from databricks.labs.remorph.coverage.commons import sqlglot_run_coverage + +if __name__ == "__main__": + sqlglot_run_coverage(Snowflake, "snowflake") diff --git a/src/databricks/labs/remorph/coverage/sqlglot_tsql_transpilation_coverage.py b/src/databricks/labs/remorph/coverage/sqlglot_tsql_transpilation_coverage.py new file mode 100644 index 0000000000..67f92630fa --- /dev/null +++ b/src/databricks/labs/remorph/coverage/sqlglot_tsql_transpilation_coverage.py @@ -0,0 +1,5 @@ +from sqlglot.dialects.tsql import TSQL +from databricks.labs.remorph.coverage.commons import sqlglot_run_coverage + +if __name__ == "__main__": + sqlglot_run_coverage(TSQL, "tsql") diff --git a/src/databricks/labs/remorph/deployment/__init__.py b/src/databricks/labs/remorph/deployment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/deployment/configurator.py b/src/databricks/labs/remorph/deployment/configurator.py new file mode 100644 index 0000000000..048f4748e5 --- /dev/null +++ b/src/databricks/labs/remorph/deployment/configurator.py @@ -0,0 +1,199 @@ +import logging +import time + +from databricks.labs.blueprint.tui import Prompts +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.catalog import Privilege, SecurableType +from databricks.sdk.service.sql import ( + CreateWarehouseRequestWarehouseType, + EndpointInfoWarehouseType, + SpotInstancePolicy, +) + +from databricks.labs.remorph.helpers.metastore import CatalogOperations + +logger 
= logging.getLogger(__name__) + + +class ResourceConfigurator: + """ + Handles the setup of common Databricks resources like + catalogs, schemas, volumes, and warehouses used across remorph modules. + """ + + def __init__(self, ws: WorkspaceClient, prompts: Prompts, catalog_ops: CatalogOperations): + self._ws = ws + self._user = ws.current_user.me() + self._prompts = prompts + self._catalog_ops = catalog_ops + + def prompt_for_catalog_setup( + self, + ) -> str: + catalog_name = self._prompts.question("Enter catalog name", default="remorph") + catalog = self._catalog_ops.get_catalog(catalog_name) + if catalog: + logger.info(f"Found existing catalog `{catalog_name}`") + return catalog_name + if self._prompts.confirm(f"Catalog `{catalog_name}` doesn't exist. Create it?"): + result = self._catalog_ops.create_catalog(catalog_name) + assert result.name is not None + return result.name + raise SystemExit("Cannot continue installation, without a valid catalog, Aborting the installation.") + + def prompt_for_schema_setup( + self, + catalog: str, + default_schema_name: str, + ) -> str: + schema_name = self._prompts.question("Enter schema name", default=default_schema_name) + schema = self._catalog_ops.get_schema(catalog, schema_name) + if schema: + logger.info(f"Found existing schema `{schema_name}` in catalog `{catalog}`") + return schema_name + if self._prompts.confirm(f"Schema `{schema_name}` doesn't exist in catalog `{catalog}`. Create it?"): + result = self._catalog_ops.create_schema(schema_name, catalog) + assert result.name is not None + return result.name + raise SystemExit("Cannot continue installation, without a valid schema. Aborting the installation.") + + def prompt_for_volume_setup( + self, + catalog: str, + schema: str, + default_volume_name: str, + ) -> str: + volume_name = self._prompts.question("Enter volume name", default=default_volume_name) + volume = self._catalog_ops.get_volume(catalog, schema, volume_name) + if volume: + logger.info(f"Found existing volume `{volume_name}` in catalog `{catalog}` and schema `{schema}`") + return volume_name + if self._prompts.confirm( + f"Volume `{volume_name}` doesn't exist in catalog `{catalog}` and schema `{schema}`. Create it?" + ): + result = self._catalog_ops.create_volume(catalog, schema, volume_name) + assert result.name is not None + return result.name + raise SystemExit("Cannot continue installation, without a valid volume. Aborting the installation.") + + def prompt_for_warehouse_setup(self, warehouse_name_prefix: str) -> str: + def warehouse_type(_): + return _.warehouse_type.value if not _.enable_serverless_compute else "SERVERLESS" + + pro_warehouses = {"[Create new PRO SQL warehouse]": "create_new"} | { + f"{_.name} ({_.id}, {warehouse_type(_)}, {_.state.value})": _.id + for _ in self._ws.warehouses.list() + if _.warehouse_type == EndpointInfoWarehouseType.PRO + } + warehouse_id = self._prompts.choice_from_dict( + "Select PRO or SERVERLESS SQL warehouse", + pro_warehouses, + ) + if warehouse_id == "create_new": + new_warehouse = self._ws.warehouses.create( + name=f"{warehouse_name_prefix} {time.time_ns()}", + spot_instance_policy=SpotInstancePolicy.COST_OPTIMIZED, + warehouse_type=CreateWarehouseRequestWarehouseType.PRO, + cluster_size="Small", + max_num_clusters=1, + ) + warehouse_id = new_warehouse.id + return warehouse_id + + def has_necessary_catalog_access( + self, catalog_name: str, user_name: str, privilege_sets: tuple[set[Privilege], ...] 
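+        # Each inner set of `privilege_sets` is an alternative privilege combination; access is treated as
+        # sufficient when any one set is fully granted (missing alternatives are joined with OR in `_get_missing_permissions`).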
+ ): + catalog = self._catalog_ops.get_catalog(catalog_name) + assert catalog, f"Catalog not found {catalog_name}" + if self._catalog_ops.has_catalog_access(catalog, user_name, privilege_sets): + return True + missing_permissions = self._get_missing_permissions( + user_name, SecurableType.CATALOG, catalog.name, privilege_sets + ) + logger.error( + f"User `{user_name}` doesn't have required privileges :: \n`{missing_permissions}`\n to access catalog `{catalog_name}` " + ) + return False + + def has_necessary_schema_access( + self, catalog_name: str, schema_name: str, user_name: str, privilege_sets: tuple[set[Privilege], ...] + ): + schema = self._catalog_ops.get_schema(catalog_name, schema_name) + assert schema, f"Schema not found {catalog_name}.{schema_name}" + if self._catalog_ops.has_schema_access(schema, user_name, privilege_sets): + return True + missing_permissions = self._get_missing_permissions( + user_name, SecurableType.SCHEMA, schema.full_name, privilege_sets + ) + logger.error( + f"User `{user_name}` doesn't have required privileges :: \n`{missing_permissions}`\n to access schema `{schema.full_name}` " + ) + return False + + def has_necessary_volume_access( + self, + catalog_name: str, + schema_name: str, + volume_name: str, + user_name: str, + privilege_sets: tuple[set[Privilege], ...], + ): + volume = self._catalog_ops.get_volume(catalog_name, schema_name, volume_name) + assert volume, f"Volume not found {catalog_name}.{schema_name}.{volume_name}" + if self._catalog_ops.has_volume_access(volume, user_name, privilege_sets): + return True + missing_permissions = self._get_missing_permissions( + user_name, SecurableType.VOLUME, volume.full_name, privilege_sets + ) + logger.error( + f"User `{user_name}` doesn't have required privileges :: \n`{missing_permissions}`\n to access volume `{volume.full_name}` " + ) + return False + + def _get_missing_permissions( + self, + user_name: str, + securable_type: SecurableType, + resource_name: str | None, + privilege_sets: tuple[set[Privilege], ...], + ): + assert resource_name, f"Catalog Resource name must be provided {resource_name}" + missing_permissions_list = [] + for privilege_set in privilege_sets: + permissions = self._catalog_ops.has_privileges(user_name, securable_type, resource_name, privilege_set) + if not permissions: + missing_privileges = ", ".join([privilege.name for privilege in privilege_set]) + missing_permissions_list.append(f" * '{missing_privileges}' ") + + return " OR \n".join(missing_permissions_list) + + def has_necessary_access(self, catalog_name: str, schema_name: str, volume_name: str | None): + catalog_required_privileges: tuple[set[Privilege], ...] = ( + {Privilege.ALL_PRIVILEGES}, + {Privilege.USE_CATALOG}, + ) + schema_required_privileges: tuple[set[Privilege], ...] = ( + {Privilege.ALL_PRIVILEGES}, + {Privilege.USE_SCHEMA, Privilege.MODIFY, Privilege.SELECT, Privilege.CREATE_VOLUME}, + {Privilege.USE_SCHEMA, Privilege.MODIFY, Privilege.SELECT}, + ) + volume_required_privileges: tuple[set[Privilege], ...] 
= ( + {Privilege.ALL_PRIVILEGES}, + {Privilege.READ_VOLUME, Privilege.WRITE_VOLUME}, + ) + + user_name = self._user.user_name + assert user_name is not None + + catalog_access = self.has_necessary_catalog_access(catalog_name, user_name, catalog_required_privileges) + schema_access = self.has_necessary_schema_access( + catalog_name, schema_name, user_name, schema_required_privileges + ) + required_access = catalog_access and schema_access + if volume_name: + volume_access = self.has_necessary_volume_access( + catalog_name, schema_name, volume_name, user_name, volume_required_privileges + ) + required_access = required_access and volume_access + if not required_access: + raise SystemExit("Cannot continue installation, without necessary access. Aborting the installation.") diff --git a/src/databricks/labs/remorph/deployment/dashboard.py b/src/databricks/labs/remorph/deployment/dashboard.py new file mode 100644 index 0000000000..5b76f09b8d --- /dev/null +++ b/src/databricks/labs/remorph/deployment/dashboard.py @@ -0,0 +1,140 @@ +import logging +from datetime import timedelta +from pathlib import Path + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.lsql.dashboards import DashboardMetadata, Dashboards +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import ( + InvalidParameterValue, + NotFound, + DeadlineExceeded, + InternalError, + ResourceAlreadyExists, +) +from databricks.sdk.retries import retried +from databricks.sdk.service.dashboards import LifecycleState, Dashboard + +from databricks.labs.remorph.config import ReconcileConfig, ReconcileMetadataConfig + +logger = logging.getLogger(__name__) + + +class DashboardDeployment: + + def __init__( + self, + ws: WorkspaceClient, + installation: Installation, + install_state: InstallState, + ): + self._ws = ws + self._installation = installation + self._install_state = install_state + + def deploy( + self, + folder: Path, + config: ReconcileConfig, + ): + """ + Create dashboards from Dashboard metadata files. + The given folder is expected to contain subfolders each containing metadata for individual dashboards. + + :param folder: Path to the base folder. + :param config: Configuration for reconciliation. 
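+
+        Illustrative layout (subfolder names are examples only, not the shipped dashboards):
+            <folder>/
+                dashboard_a/   # metadata read via DashboardMetadata.from_path
+                dashboard_b/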
+ """ + logger.info(f"Deploying dashboards from base folder {folder}") + parent_path = f"{self._installation.install_folder()}/dashboards" + try: + self._ws.workspace.mkdirs(parent_path) + except ResourceAlreadyExists: + logger.info(f"Dashboard parent path already exists: {parent_path}") + + valid_dashboard_refs = set() + for dashboard_folder in folder.iterdir(): + if not dashboard_folder.is_dir(): + continue + valid_dashboard_refs.add(self._dashboard_reference(dashboard_folder)) + dashboard = self._update_or_create_dashboard(dashboard_folder, parent_path, config.metadata_config) + logger.info( + f"Dashboard deployed with URL: {self._ws.config.host}/sql/dashboardsv3/{dashboard.dashboard_id}" + ) + self._install_state.save() + + self._remove_deprecated_dashboards(valid_dashboard_refs) + + def _dashboard_reference(self, folder: Path) -> str: + return f"{folder.stem}".lower() + + # InternalError and DeadlineExceeded are retried because of Lakeview internal issues + # These issues have been reported to and are resolved by the Lakeview team + # Keeping the retry for resilience + @retried(on=[InternalError, DeadlineExceeded], timeout=timedelta(minutes=3)) + def _update_or_create_dashboard( + self, + folder: Path, + ws_parent_path: str, + config: ReconcileMetadataConfig, + ) -> Dashboard: + logging.info(f"Reading dashboard folder {folder}") + metadata = DashboardMetadata.from_path(folder).replace_database( + catalog=config.catalog, + catalog_to_replace="remorph", + database=config.schema, + database_to_replace="reconcile", + ) + + metadata.display_name = self._name_with_prefix(metadata.display_name) + reference = self._dashboard_reference(folder) + dashboard_id = self._install_state.dashboards.get(reference) + if dashboard_id is not None: + try: + dashboard_id = self._handle_existing_dashboard(dashboard_id, metadata.display_name) + except (NotFound, InvalidParameterValue): + logger.info(f"Recovering invalid dashboard: {metadata.display_name} ({dashboard_id})") + try: + dashboard_path = f"{ws_parent_path}/{metadata.display_name}.lvdash.json" + self._ws.workspace.delete(dashboard_path) # Cannot recreate dashboard if file still exists + logger.debug( + f"Deleted dangling dashboard {metadata.display_name} ({dashboard_id}): {dashboard_path}" + ) + except NotFound: + pass + dashboard_id = None # Recreate the dashboard if it's reference is corrupted (manually) + + dashboard = Dashboards(self._ws).create_dashboard( + metadata, + dashboard_id=dashboard_id, + parent_path=ws_parent_path, + warehouse_id=self._ws.config.warehouse_id, + publish=True, + ) + assert dashboard.dashboard_id is not None + self._install_state.dashboards[reference] = dashboard.dashboard_id + return dashboard + + def _name_with_prefix(self, name: str) -> str: + prefix = self._installation.product() + return f"[{prefix.upper()}] {name}" + + def _handle_existing_dashboard(self, dashboard_id: str, display_name: str) -> str | None: + dashboard = self._ws.lakeview.get(dashboard_id) + if dashboard.lifecycle_state is None: + raise NotFound(f"Dashboard life cycle state: {display_name} ({dashboard_id})") + if dashboard.lifecycle_state == LifecycleState.TRASHED: + logger.info(f"Recreating trashed dashboard: {display_name} ({dashboard_id})") + return None # Recreate the dashboard if it is trashed (manually) + return dashboard_id # Update the existing dashboard + + def _remove_deprecated_dashboards(self, valid_dashboard_refs: set[str]): + for ref, dashboard_id in self._install_state.dashboards.items(): + if ref not in valid_dashboard_refs: + try: 
+ logger.info(f"Removing dashboard_id={dashboard_id}, as it is no longer needed.") + del self._install_state.dashboards[ref] + self._ws.lakeview.trash(dashboard_id) + except (InvalidParameterValue, NotFound): + logger.warning(f"Dashboard `{dashboard_id}` doesn't exist anymore for some reason.") + continue diff --git a/src/databricks/labs/remorph/deployment/installation.py b/src/databricks/labs/remorph/deployment/installation.py new file mode 100644 index 0000000000..f7a1daadd9 --- /dev/null +++ b/src/databricks/labs/remorph/deployment/installation.py @@ -0,0 +1,125 @@ +import logging +from ast import literal_eval +from pathlib import Path + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.tui import Prompts +from databricks.labs.blueprint.upgrades import Upgrades +from databricks.labs.blueprint.wheels import ProductInfo, Version +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound +from databricks.sdk.mixins.compute import SemVer +from databricks.sdk.errors.platform import InvalidParameterValue, ResourceDoesNotExist + +from databricks.labs.remorph.config import RemorphConfigs +from databricks.labs.remorph.deployment.recon import ReconDeployment + +logger = logging.getLogger("databricks.labs.remorph.install") + + +class WorkspaceInstallation: + def __init__( + self, + ws: WorkspaceClient, + prompts: Prompts, + installation: Installation, + recon_deployment: ReconDeployment, + product_info: ProductInfo, + upgrades: Upgrades, + ): + self._ws = ws + self._prompts = prompts + self._installation = installation + self._recon_deployment = recon_deployment + self._product_info = product_info + self._upgrades = upgrades + + def _get_local_version_file_path(self): + user_home = f"{Path(__file__).home()}" + return Path(f"{user_home}/.databricks/labs/{self._product_info.product_name()}/state/version.json") + + def _get_local_version_file(self, file_path: Path): + data = None + with file_path.open("r") as f: + data = literal_eval(f.read()) + assert data, "Unable to read local version file." + local_installed_version = data["version"] + try: + SemVer.parse(local_installed_version) + except ValueError: + logger.warning(f"{local_installed_version} is not a valid version.") + local_installed_version = "v0.3.0" + local_installed_date = data["date"] + logger.debug(f"Found local installation version: {local_installed_version} {local_installed_date}") + return Version( + version=local_installed_version, + date=local_installed_date, + wheel=f"databricks_labs_remorph-{local_installed_version}-py3-none-any.whl", + ) + + def _get_ws_version(self): + try: + return self._installation.load(Version) + except ResourceDoesNotExist as err: + logger.warning(f"Unable to get Workspace Version due to: {err}") + return None + + def _apply_upgrades(self): + """ + * If remote version doesn't exist and local version exists: + Upload Version file to workspace to handle previous installations. + * If remote version or local_version exists, then only apply upgrades. + * No need to apply upgrades for fresh installation. 
+ """ + ws_version = self._get_ws_version() + local_version_path = self._get_local_version_file_path() + local_version = local_version_path.exists() + if not ws_version and local_version: + self._installation.save(self._get_local_version_file(local_version_path)) + + if ws_version or local_version: + try: + self._upgrades.apply(self._ws) + logger.debug("Upgrades applied successfully.") + except (InvalidParameterValue, NotFound) as err: + logger.warning(f"Unable to apply Upgrades due to: {err}") + + def _upload_wheel(self): + wheels = self._product_info.wheels(self._ws) + with wheels: + wheel_paths = [wheels.upload_to_wsfs()] + wheel_paths = [f"/Workspace{wheel}" for wheel in wheel_paths] + return wheel_paths + + def install(self, config: RemorphConfigs): + self._apply_upgrades() + wheel_paths: list[str] = self._upload_wheel() + if config.reconcile: + logger.info("Installing Remorph reconcile Metadata components.") + self._recon_deployment.install(config.reconcile, wheel_paths) + + def uninstall(self, config: RemorphConfigs): + # This will remove all the Remorph modules + if not self._prompts.confirm( + "Do you want to uninstall Remorph from the workspace too, this would " + "remove Remorph project folder, jobs, metadata and dashboards" + ): + return + logger.info(f"Uninstalling Remorph from {self._ws.config.host}.") + try: + self._installation.files() + except NotFound: + logger.error(f"Check if {self._installation.install_folder()} is present. Aborting uninstallation.") + return + + if config.transpile: + logging.info( + f"Won't remove transpile validation schema `{config.transpile.schema_name}` " + f"from catalog `{config.transpile.catalog_name}`. Please remove it manually." + ) + + if config.reconcile: + self._recon_deployment.uninstall(config.reconcile) + + self._installation.remove() + logger.info("Uninstallation completed successfully.") diff --git a/src/databricks/labs/remorph/deployment/job.py b/src/databricks/labs/remorph/deployment/job.py new file mode 100644 index 0000000000..97c4fae176 --- /dev/null +++ b/src/databricks/labs/remorph/deployment/job.py @@ -0,0 +1,147 @@ +import dataclasses +import logging +from datetime import datetime, timezone, timedelta +from typing import Any + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import InvalidParameterValue +from databricks.sdk.service import compute +from databricks.sdk.service.jobs import Task, PythonWheelTask, JobCluster, JobSettings, JobParameterDefinition + +from databricks.labs.remorph.config import ReconcileConfig +from databricks.labs.remorph.reconcile.constants import ReconSourceType + +logger = logging.getLogger(__name__) + +_TEST_JOBS_PURGE_TIMEOUT = timedelta(hours=1, minutes=15) + + +class JobDeployment: + def __init__( + self, + ws: WorkspaceClient, + installation: Installation, + install_state: InstallState, + product_info: ProductInfo, + ): + self._ws = ws + self._installation = installation + self._install_state = install_state + self._product_info = product_info + + def deploy_recon_job(self, name, recon_config: ReconcileConfig, remorph_wheel_path: str): + logger.info("Deploying reconciliation job.") + job_id = self._update_or_create_recon_job(name, recon_config, remorph_wheel_path) + logger.info(f"Reconciliation job deployed with job_id={job_id}") + logger.info(f"Job URL: 
{self._ws.config.host}#job/{job_id}") + self._install_state.save() + + def _update_or_create_recon_job(self, name, recon_config: ReconcileConfig, remorph_wheel_path: str) -> str: + description = "Run the reconciliation process" + task_key = "run_reconciliation" + + job_settings = self._recon_job_settings(name, task_key, description, recon_config, remorph_wheel_path) + if name in self._install_state.jobs: + try: + job_id = int(self._install_state.jobs[name]) + logger.info(f"Updating configuration for job `{name}`, job_id={job_id}") + self._ws.jobs.reset(job_id, JobSettings(**job_settings)) + return str(job_id) + except InvalidParameterValue: + del self._install_state.jobs[name] + logger.warning(f"Job `{name}` does not exist anymore for some reason") + return self._update_or_create_recon_job(name, recon_config, remorph_wheel_path) + + logger.info(f"Creating new job configuration for job `{name}`") + new_job = self._ws.jobs.create(**job_settings) + assert new_job.job_id is not None + self._install_state.jobs[name] = str(new_job.job_id) + return str(new_job.job_id) + + def _recon_job_settings( + self, + job_name: str, + task_key: str, + description: str, + recon_config: ReconcileConfig, + remorph_wheel_path: str, + ) -> dict[str, Any]: + latest_lts_spark = self._ws.clusters.select_spark_version(latest=True, long_term_support=True) + version = self._product_info.version() + version = version if not self._ws.config.is_gcp else version.replace("+", "-") + tags = {"version": f"v{version}"} + if self._is_testing(): + # Add RemoveAfter tag for test job cleanup + date_to_remove = self._get_test_purge_time() + tags.update({"RemoveAfter": date_to_remove}) + + return { + "name": self._name_with_prefix(job_name), + "tags": tags, + "job_clusters": [ + JobCluster( + job_cluster_key="Remorph_Reconciliation_Cluster", + new_cluster=compute.ClusterSpec( + data_security_mode=compute.DataSecurityMode.USER_ISOLATION, + spark_conf={}, + node_type_id=self._get_default_node_type_id(), + autoscale=compute.AutoScale(min_workers=2, max_workers=10), + spark_version=latest_lts_spark, + ), + ) + ], + "tasks": [ + self._job_recon_task( + Task( + task_key=task_key, + description=description, + job_cluster_key="Remorph_Reconciliation_Cluster", + ), + recon_config, + remorph_wheel_path, + ), + ], + "max_concurrent_runs": 2, + "parameters": [JobParameterDefinition(name="operation_name", default="reconcile")], + } + + def _job_recon_task(self, jobs_task: Task, recon_config: ReconcileConfig, remorph_wheel_path: str) -> Task: + libraries = [ + compute.Library(whl=remorph_wheel_path), + ] + source = recon_config.data_source + if source == ReconSourceType.ORACLE.value: + # TODO: Automatically fetch a version list for `ojdbc8` + oracle_driver_version = "23.4.0.24.05" + libraries.append( + compute.Library( + maven=compute.MavenLibrary(f"com.oracle.database.jdbc:ojdbc8:{oracle_driver_version}"), + ), + ) + + return dataclasses.replace( + jobs_task, + libraries=libraries, + python_wheel_task=PythonWheelTask( + package_name="databricks_labs_remorph", + entry_point="reconcile", + parameters=["{{job.parameters.[operation_name]}}"], + ), + ) + + def _is_testing(self): + return self._product_info.product_name() != "remorph" + + @staticmethod + def _get_test_purge_time() -> str: + return (datetime.now(timezone.utc) + _TEST_JOBS_PURGE_TIMEOUT).strftime("%Y%m%d%H") + + def _get_default_node_type_id(self) -> str: + return self._ws.clusters.select_node_type(local_disk=True, min_memory_gb=16) + + def _name_with_prefix(self, name: str) -> str: + 
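Job deployment reuses the update-or-create idiom seen for dashboards: the job name is the key in `InstallState.jobs`, `jobs.reset` updates an existing job, and a stale ID is forgotten and recreated. A hedged sketch of that idiom in isolation; `upsert_job` is an illustrative helper, not part of this patch:

```python
from databricks.labs.blueprint.installer import InstallState
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import InvalidParameterValue
from databricks.sdk.service.jobs import JobSettings

def upsert_job(ws: WorkspaceClient, install_state: InstallState, name: str, settings: dict) -> str:
    """Illustrative mirror of JobDeployment._update_or_create_recon_job."""
    if name in install_state.jobs:
        try:
            job_id = int(install_state.jobs[name])
            ws.jobs.reset(job_id, JobSettings(**settings))  # update the existing job in place
            return str(job_id)
        except InvalidParameterValue:
            del install_state.jobs[name]                    # stale reference: forget it and recreate
    new_job = ws.jobs.create(**settings)
    install_state.jobs[name] = str(new_job.job_id)
    return str(new_job.job_id)
```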
prefix = self._installation.product() + return f"[{prefix.upper()}] {name}" diff --git a/src/databricks/labs/remorph/deployment/recon.py b/src/databricks/labs/remorph/deployment/recon.py new file mode 100644 index 0000000000..a9db3b569f --- /dev/null +++ b/src/databricks/labs/remorph/deployment/recon.py @@ -0,0 +1,143 @@ +import logging +from importlib.resources import files + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.blueprint.wheels import find_project_root +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import InvalidParameterValue, NotFound + +import databricks.labs.remorph.resources +from databricks.labs.remorph.config import ReconcileConfig +from databricks.labs.remorph.deployment.dashboard import DashboardDeployment +from databricks.labs.remorph.deployment.job import JobDeployment +from databricks.labs.remorph.deployment.table import TableDeployment + +logger = logging.getLogger(__name__) + +_RECON_PREFIX = "Reconciliation" +RECON_JOB_NAME = f"{_RECON_PREFIX} Runner" + + +class ReconDeployment: + def __init__( + self, + ws: WorkspaceClient, + installation: Installation, + install_state: InstallState, + product_info: ProductInfo, + table_deployer: TableDeployment, + job_deployer: JobDeployment, + dashboard_deployer: DashboardDeployment, + ): + self._ws = ws + self._installation = installation + self._install_state = install_state + self._product_info = product_info + self._table_deployer = table_deployer + self._job_deployer = job_deployer + self._dashboard_deployer = dashboard_deployer + + def install(self, recon_config: ReconcileConfig | None, wheel_paths: list[str]): + if not recon_config: + logger.warning("Recon Config is empty.") + return + logger.info("Installing reconcile components.") + self._deploy_tables(recon_config) + self._deploy_dashboards(recon_config) + remorph_wheel_path = [whl for whl in wheel_paths if "remorph" in whl][0] + self._deploy_jobs(recon_config, remorph_wheel_path) + self._install_state.save() + logger.info("Installation of reconcile components completed successfully.") + + def uninstall(self, recon_config: ReconcileConfig | None): + if not recon_config: + return + logger.info("Uninstalling reconcile components.") + self._remove_dashboards() + self._remove_jobs() + logging.info( + f"Won't remove reconcile metadata schema `{recon_config.metadata_config.schema}` " + f"from catalog `{recon_config.metadata_config.catalog}`. Please remove it and the tables inside manually." + ) + logging.info( + f"Won't remove configured reconcile secret scope `{recon_config.secret_scope}`. " + f"Please remove it manually." 
+ ) + + def _deploy_tables(self, recon_config: ReconcileConfig): + logger.info("Deploying reconciliation metadata tables.") + catalog = recon_config.metadata_config.catalog + schema = recon_config.metadata_config.schema + resources = files(databricks.labs.remorph.resources) + query_dir = resources.joinpath("reconcile/queries/installation") + + sqls_to_deploy = [ + "main.sql", + "metrics.sql", + "details.sql", + "aggregate_metrics.sql", + "aggregate_details.sql", + "aggregate_rules.sql", + ] + + for sql_file in sqls_to_deploy: + table_sql_file = query_dir.joinpath(sql_file) + self._table_deployer.deploy_table_from_ddl_file(catalog, schema, sql_file.strip(".sql"), table_sql_file) + + def _deploy_dashboards(self, recon_config: ReconcileConfig): + logger.info("Deploying reconciliation dashboards.") + dashboard_base_dir = find_project_root(__file__) / "src/databricks/labs/remorph/resources/reconcile/dashboards" + self._dashboard_deployer.deploy(dashboard_base_dir, recon_config) + + def _get_dashboards(self) -> list[tuple[str, str]]: + return list(self._install_state.dashboards.items()) + + def _remove_dashboards(self): + logger.info("Removing reconciliation dashboards.") + for dashboard_ref, dashboard_id in self._get_dashboards(): + try: + logger.info(f"Removing dashboard with id={dashboard_id}.") + del self._install_state.dashboards[dashboard_ref] + self._ws.lakeview.trash(dashboard_id) + except (InvalidParameterValue, NotFound): + logger.warning(f"Dashboard with id={dashboard_id} doesn't exist anymore for some reason.") + continue + + def _deploy_jobs(self, recon_config: ReconcileConfig, remorph_wheel_path: str): + logger.info("Deploying reconciliation jobs.") + self._job_deployer.deploy_recon_job(RECON_JOB_NAME, recon_config, remorph_wheel_path) + for job_name, job_id in self._get_deprecated_jobs(): + try: + logger.info(f"Removing job_id={job_id}, as it is no longer needed.") + del self._install_state.jobs[job_name] + self._ws.jobs.delete(job_id) + except (InvalidParameterValue, NotFound): + logger.warning(f"{job_name} doesn't exist anymore for some reason.") + continue + + def _get_jobs(self) -> list[tuple[str, int]]: + return [ + (job_name, int(job_id)) + for job_name, job_id in self._install_state.jobs.items() + if job_name.startswith(_RECON_PREFIX) + ] + + def _get_deprecated_jobs(self) -> list[tuple[str, int]]: + return [ + (job_name, int(job_id)) + for job_name, job_id in self._install_state.jobs.items() + if job_name.startswith(_RECON_PREFIX) and job_name != RECON_JOB_NAME + ] + + def _remove_jobs(self): + logger.info("Removing Reconciliation Jobs.") + for job_name, job_id in self._get_jobs(): + try: + logger.info(f"Removing job {job_name} with job_id={job_id}.") + del self._install_state.jobs[job_name] + self._ws.jobs.delete(int(job_id)) + except (InvalidParameterValue, NotFound): + logger.warning(f"{job_name} doesn't exist anymore for some reason.") + continue diff --git a/src/databricks/labs/remorph/deployment/table.py b/src/databricks/labs/remorph/deployment/table.py new file mode 100644 index 0000000000..bcd7cf8502 --- /dev/null +++ b/src/databricks/labs/remorph/deployment/table.py @@ -0,0 +1,30 @@ +import logging +from importlib.abc import Traversable + +from databricks.labs.lsql.backends import SqlBackend + +logger = logging.getLogger(__name__) + + +class TableDeployment: + def __init__(self, sql_backend: SqlBackend): + self._sql_backend = sql_backend + + def deploy_table_from_ddl_file( + self, + catalog: str, + schema: str, + table_name: str, + ddl_query_filepath: Traversable, 
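Deprecated reconciliation jobs are detected purely by name: anything in `InstallState.jobs` that carries the `Reconciliation` prefix but is not the current runner gets deleted on install. A quick illustration with a made-up state dictionary:

```python
_RECON_PREFIX = "Reconciliation"
RECON_JOB_NAME = f"{_RECON_PREFIX} Runner"

# install_state.jobs maps job name -> job id (illustrative values).
jobs = {
    "Reconciliation Runner": "101",
    "Reconciliation Legacy Runner": "87",  # hypothetical leftover from an older release
    "Some Other Job": "12",
}

deprecated = [
    (name, int(job_id))
    for name, job_id in jobs.items()
    if name.startswith(_RECON_PREFIX) and name != RECON_JOB_NAME
]
assert deprecated == [("Reconciliation Legacy Runner", 87)]
```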
+ ): + """ + Deploys a table to the given catalog and schema + :param catalog: The table catalog + :param schema: The table schema + :param table_name: The table to deploy + :param ddl_query_filepath: DDL file path + """ + query = ddl_query_filepath.read_text() + logger.info(f"Deploying table {table_name} in {catalog}.{schema}") + logger.info(f"SQL Backend used for deploying table: {type(self._sql_backend).__name__}") + self._sql_backend.execute(query, catalog=catalog, schema=schema) diff --git a/src/databricks/labs/remorph/deployment/upgrade_common.py b/src/databricks/labs/remorph/deployment/upgrade_common.py new file mode 100644 index 0000000000..cbc2129d8d --- /dev/null +++ b/src/databricks/labs/remorph/deployment/upgrade_common.py @@ -0,0 +1,124 @@ +import logging +import re +from importlib.resources import files + +import databricks.labs.remorph.resources + +from databricks.labs.blueprint.tui import Prompts +from databricks.sdk import WorkspaceClient +from databricks.labs.remorph.helpers import db_sql + +logger = logging.getLogger(__name__) + + +def replace_patterns(sql_text: str) -> str: + """ + Replace the STRUCT and MAP datatypes in the SQL text with empty string + """ + # Pattern to match nested STRUCT and MAP datatypes + pattern = r'(STRUCT<[^<>]*?(?:<[^<>]*?>[^<>]*?)*>|MAP<[^<>]*?(?:<[^<>]*?>[^<>]*?)*>)' + parsed_sql_text = re.sub(pattern, "", sql_text, flags=re.DOTALL) + return parsed_sql_text + + +def extract_columns_with_datatype(sql_text: str) -> list[str]: + """ + Extract the columns with datatype from the SQL text + Example: + Input: CREATE TABLE main ( + recon_table_id BIGINT NOT NULL, + report_type STRING NOT NULL + ); + Output: [recon_table_id BIGINT NOT NULL, + report_type STRING NOT NULL] + """ + return sql_text[sql_text.index("(") + 1 : sql_text.index(")")].strip().split(",") + + +def extract_column_name(column_with_datatype: str) -> str: + """ + Extract the column name from the column with datatype. 
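`replace_patterns` exists so that commas inside STRUCT and MAP column types do not confuse the comma split in `extract_columns_with_datatype`. A quick check of the regex on an invented DDL fragment:

```python
import re

pattern = r'(STRUCT<[^<>]*?(?:<[^<>]*?>[^<>]*?)*>|MAP<[^<>]*?(?:<[^<>]*?>[^<>]*?)*>)'
ddl = "CREATE TABLE t (id BIGINT, payload STRUCT<a: INT, b: STRING>, tags MAP<STRING, STRING>)"

stripped = re.sub(pattern, "", ddl, flags=re.DOTALL)
# The commas that remain now only separate columns, so splitting on "," is safe.
assert stripped == "CREATE TABLE t (id BIGINT, payload , tags )"
```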
+ Example: + Input: \n recon_table_id BIGINT NOT NULL, + Output: recon_table_id + """ + return column_with_datatype.strip("\n").strip().split(" ")[0] + + +def table_original_query(table_name: str, full_table_name: str) -> str: + """ + Get the main table DDL from the main.sql file + :return: str + """ + resources = files(databricks.labs.remorph.resources) + query_dir = resources.joinpath("reconcile/queries/installation") + return ( + query_dir.joinpath(f"{table_name}.sql") + .read_text() + .replace(f"CREATE TABLE IF NOT EXISTS {table_name}", f"CREATE OR REPLACE TABLE {full_table_name}") + ) + + +def current_table_columns(table_name: str, full_table_name: str) -> list[str]: + """ + Extract the column names from the main table DDL + :return: column_names: list[str] + """ + main_sql = replace_patterns(table_original_query(table_name, full_table_name)) + main_table_columns = [ + extract_column_name(main_table_column) for main_table_column in extract_columns_with_datatype(main_sql) + ] + return main_table_columns + + +def installed_table_columns(ws: WorkspaceClient, table_identifier: str) -> list[str]: + """ + Fetch the column names from the installed table on Databricks Workspace using SQL Backend + :return: column_names: list[str] + """ + main_table_columns = list(db_sql.get_sql_backend(ws).fetch(f"DESC {table_identifier}")) + return [row.col_name for row in main_table_columns] + + +def check_table_mismatch( + installed_table, + current_table, +) -> bool: + # Compare the current main table columns with the installed main table columns + if len(installed_table) != len(current_table) or sorted(installed_table) != sorted(current_table): + return True + return False + + +def recreate_table_sql( + table_identifier: str, + installed_table: list[str], + current_table: list[str], + prompts: Prompts, +) -> str | None: + """ + * Verify all the current main table columns are present in the installed main table and then use CTAS to recreate the main table + * If any of the current main table columns are missing in the installed main table, prompt the user to recreate the main table: + - If the user confirms, recreate the main table using the main DDL file, else log an error message and exit + :param table_identifier: + :param installed_table: + :param current_table: + :param prompts: + :return: + """ + table_name = table_identifier.split('.')[-1] + sql: str | None = ( + f"CREATE OR REPLACE TABLE {table_identifier} AS SELECT {','.join(current_table)} FROM {table_identifier}" + ) + + if not set(current_table).issubset(installed_table): + if prompts.confirm( + f"The `{table_identifier}` table columns are not as expected. Do you want to recreate the `{table_identifier}` table?" + ): + sql = table_original_query(table_name, table_identifier) + else: + logger.error( + f"The `{table_identifier}` table columns are not as expected. Please check and recreate the `{table_identifier}` table." 
+ ) + sql = None + return sql diff --git a/src/databricks/labs/remorph/helpers/__init__.py b/src/databricks/labs/remorph/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/helpers/db_sql.py b/src/databricks/labs/remorph/helpers/db_sql.py new file mode 100644 index 0000000000..0c989b264c --- /dev/null +++ b/src/databricks/labs/remorph/helpers/db_sql.py @@ -0,0 +1,21 @@ +import logging +import os + +from databricks.labs.lsql.backends import ( + DatabricksConnectBackend, + RuntimeBackend, + SqlBackend, + StatementExecutionBackend, +) +from databricks.sdk import WorkspaceClient + +logger = logging.getLogger(__name__) + + +def get_sql_backend(ws: WorkspaceClient, warehouse_id: str | None = None) -> SqlBackend: + warehouse_id = warehouse_id or ws.config.warehouse_id + if warehouse_id: + sql_backend: SqlBackend = StatementExecutionBackend(ws, warehouse_id) + else: + sql_backend = RuntimeBackend() if "DATABRICKS_RUNTIME_VERSION" in os.environ else DatabricksConnectBackend(ws) + return sql_backend diff --git a/src/databricks/labs/remorph/helpers/execution_time.py b/src/databricks/labs/remorph/helpers/execution_time.py new file mode 100644 index 0000000000..d7a081d249 --- /dev/null +++ b/src/databricks/labs/remorph/helpers/execution_time.py @@ -0,0 +1,20 @@ +import inspect +import logging +import time +from functools import wraps + +logger = logging.getLogger(__name__) + + +def timeit(func): + @wraps(func) + def timeit_wrapper(*args, **kwargs): + start_time = time.perf_counter() + result = func(*args, **kwargs) + end_time = time.perf_counter() + total_time = end_time - start_time + name = inspect.getmodule(func).__name__.split(".")[3].capitalize() + logger.info(f"{name} Took {total_time:.4f} seconds") + return result + + return timeit_wrapper diff --git a/src/databricks/labs/remorph/helpers/file_utils.py b/src/databricks/labs/remorph/helpers/file_utils.py new file mode 100644 index 0000000000..59185246ff --- /dev/null +++ b/src/databricks/labs/remorph/helpers/file_utils.py @@ -0,0 +1,104 @@ +import codecs +from pathlib import Path +from collections.abc import Generator + + +# Optionally check to see if a string begins with a Byte Order Mark +# such a character will cause the transpiler to fail +def remove_bom(input_string: str) -> str: + """ + Removes the Byte Order Mark (BOM) from the given string if it exists. 
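When every current column is still present in the installed table, `recreate_table_sql` above preserves data with a CTAS over the existing table rather than re-running the DDL. A sketch of the generated statement for invented column lists:

```python
table_identifier = "remorph.reconcile.main"  # illustrative fully-qualified name
installed = ["recon_table_id", "report_type", "legacy_col"]
current = ["recon_table_id", "report_type"]

# All current columns exist in the installed table, so a CTAS keeps the data and drops the extras.
assert set(current).issubset(installed)
sql = (
    f"CREATE OR REPLACE TABLE {table_identifier} AS "
    f"SELECT {','.join(current)} FROM {table_identifier}"
)
assert sql == (
    "CREATE OR REPLACE TABLE remorph.reconcile.main AS "
    "SELECT recon_table_id,report_type FROM remorph.reconcile.main"
)
```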
+ :param input_string: String to remove BOM from + :return: String without BOM + """ + output_string = input_string + + # Check and remove UTF-16 (LE and BE) BOM + if input_string.startswith(codecs.BOM_UTF16_BE.decode("utf-16-be")): + output_string = input_string[len(codecs.BOM_UTF16_BE.decode("utf-16-be")) :] + elif input_string.startswith(codecs.BOM_UTF16_LE.decode("utf-16-le")): + output_string = input_string[len(codecs.BOM_UTF16_LE.decode("utf-16-le")) :] + elif input_string.startswith(codecs.BOM_UTF16.decode("utf-16")): + output_string = input_string[len(codecs.BOM_UTF16.decode("utf-16")) :] + # Check and remove UTF-32 (LE and BE) BOM + elif input_string.startswith(codecs.BOM_UTF32_BE.decode("utf-32-be")): + output_string = input_string[len(codecs.BOM_UTF32_BE.decode("utf-32-be")) :] + elif input_string.startswith(codecs.BOM_UTF32_LE.decode("utf-32-le")): + output_string = input_string[len(codecs.BOM_UTF32_LE.decode("utf-32-le")) :] + elif input_string.startswith(codecs.BOM_UTF32.decode("utf-32")): + output_string = input_string[len(codecs.BOM_UTF32.decode("utf-32")) :] + # Check and remove UTF-8 BOM + elif input_string.startswith(codecs.BOM_UTF8.decode("utf-8")): + output_string = input_string[len(codecs.BOM_UTF8.decode("utf-8")) :] + + return output_string + + +def is_sql_file(file: str | Path) -> bool: + """ + Checks if the given file is a SQL file. + + :param file: The name of the file to check. + :return: True if the file is a SQL file (i.e., its extension is either .sql or .ddl), False otherwise. + """ + file_extension = Path(file).suffix + return file_extension.lower() in {".sql", ".ddl"} + + +def make_dir(path: str | Path) -> None: + """ + Creates a directory at the specified path if it does not already exist. + + :param path: The path where the directory should be created. + """ + Path(path).mkdir(parents=True, exist_ok=True) + + +def dir_walk(root: Path): + """ + Walks the directory tree rooted at the given path, yielding a tuple containing the root directory, a list of + :param root: Path + :return: tuple of root, subdirectory , files + """ + sub_dirs = [d for d in root.iterdir() if d.is_dir()] + files = [f for f in root.iterdir() if f.is_file()] + yield root, sub_dirs, files + + for each_dir in sub_dirs: + yield from dir_walk(each_dir) + + +def get_sql_file(input_path: str | Path) -> Generator[Path, None, None]: + """ + Returns Generator that yields the names of all SQL files in the given directory. + :param input_path: Path + :return: List of SQL files + """ + for _, _, files in dir_walk(Path(input_path)): + for filename in files: + if is_sql_file(filename): + yield filename + + +def read_file(filename: str | Path) -> str: + """ + Reads the contents of the given file and returns it as a string. + :param filename: Input File Path + :return: File Contents as String + """ + # pylint: disable=unspecified-encoding + with Path(filename).open() as file: + return file.read() + + +def refactor_hexadecimal_chars(input_string: str) -> str: + """ + Updates the HexaDecimal characters ( \x1b[\\d+m ) in the given string as below. + :param input_string: String with HexaDecimal characters. ex: ( \x1b[4mWHERE\x1b[0m ) + :return: String with HexaDecimal characters refactored to arrows. 
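A few hedged examples of the helpers above, assuming they are importable from the module added in this patch:

```python
import codecs

from databricks.labs.remorph.helpers.file_utils import is_sql_file, remove_bom

assert remove_bom(codecs.BOM_UTF8.decode("utf-8") + "SELECT 1;") == "SELECT 1;"  # BOM stripped
assert remove_bom("SELECT 1;") == "SELECT 1;"                                    # no BOM: unchanged
assert is_sql_file("queries/query_01.SQL")   # suffix check is case-insensitive (.sql / .ddl)
assert not is_sql_file("notes/readme.txt")
```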
ex: ( --> WHERE <--) + """ + output_string = input_string + highlight = {"\x1b[4m": "--> ", "\x1b[0m": " <--"} + for key, value in highlight.items(): + output_string = output_string.replace(key, value) + return output_string diff --git a/src/databricks/labs/remorph/helpers/metastore.py b/src/databricks/labs/remorph/helpers/metastore.py new file mode 100644 index 0000000000..1e27136e6a --- /dev/null +++ b/src/databricks/labs/remorph/helpers/metastore.py @@ -0,0 +1,164 @@ +import functools +import logging +from itertools import chain + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound +from databricks.sdk.service.catalog import ( + CatalogInfo, + Privilege, + SchemaInfo, + SecurableType, + VolumeInfo, + VolumeType, +) + +logger = logging.getLogger(__name__) + + +class CatalogOperations: + def __init__(self, ws: WorkspaceClient): + self._ws = ws + + def get_catalog(self, name: str) -> CatalogInfo | None: + try: + return self._ws.catalogs.get(name) + except NotFound: + return None + + def get_schema(self, catalog_name: str, schema_name: str) -> SchemaInfo | None: + try: + return self._ws.schemas.get(f"{catalog_name}.{schema_name}") + except NotFound: + return None + + def get_volume(self, catalog: str, schema: str, name: str) -> VolumeInfo | None: + try: + return self._ws.volumes.read(f"{catalog}.{schema}.{name}") + except NotFound: + return None + + def create_catalog(self, name: str) -> CatalogInfo: + logger.debug(f"Creating catalog `{name}`.") + catalog_info = self._ws.catalogs.create(name) + logger.info(f"Created catalog `{name}`.") + return catalog_info + + def create_schema(self, schema_name: str, catalog_name: str) -> SchemaInfo: + logger.debug(f"Creating schema `{schema_name}` in catalog `{catalog_name}`.") + schema_info = self._ws.schemas.create(schema_name, catalog_name) + logger.info(f"Created schema `{schema_name}` in catalog `{catalog_name}`.") + return schema_info + + def create_volume( + self, + catalog: str, + schema: str, + name: str, + volume_type: VolumeType = VolumeType.MANAGED, + ) -> VolumeInfo: + logger.debug(f"Creating volume `{name}` in catalog `{catalog}` and schema `{schema}`") + volume_info = self._ws.volumes.create(catalog, schema, name, volume_type) + logger.info(f"Created volume `{name}` in catalog `{catalog}` and schema `{schema}`") + return volume_info + + def has_catalog_access( + self, + catalog: CatalogInfo, + user_name: str, + privilege_sets: tuple[set[Privilege], ...], + ) -> bool: + """ + Check if a user has access to a catalog based on ownership or a set of privileges. + :param catalog: A catalog to check access for. + :param user_name: Username to check. + :param privilege_sets: A tuple of sets, where each set contains Privilege objects. + The function checks if the user has any of these sets of privileges. For example: + ({Privilege.ALL_PRIVILEGES}, {Privilege.USE_CATALOG, Privilege.APPLY_TAG}) + In this case, the user would need either ALL_PRIVILEGES, + or both USE_CATALOG and APPLY_TAG. + """ + if user_name == catalog.owner: + return True + + return any( + self.has_privileges(user_name, SecurableType.CATALOG, catalog.name, privilege_set) + for privilege_set in privilege_sets + ) + + def has_schema_access( + self, + schema: SchemaInfo, + user_name: str, + privilege_sets: tuple[set[Privilege], ...], + ) -> bool: + """ + Check if a user has access to a schema based on ownership or a set of privileges. + :param schema: A schema to check access for. + :param user_name: Username to check. 
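Because the getters above swallow `NotFound` and return `None`, callers can express get-or-create in one line per securable. A hedged sketch; the catalog, schema, and volume names are illustrative:

```python
from databricks.sdk import WorkspaceClient

from databricks.labs.remorph.helpers.metastore import CatalogOperations

ops = CatalogOperations(WorkspaceClient())

# Getters return None when the object is missing, so `or` falls through to create.
catalog = ops.get_catalog("remorph") or ops.create_catalog("remorph")
schema = ops.get_schema("remorph", "reconcile") or ops.create_schema("reconcile", "remorph")
volume = ops.get_volume("remorph", "reconcile", "reconcile_volume") or ops.create_volume(
    "remorph", "reconcile", "reconcile_volume"
)
```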
+ :param privilege_sets: The function checks if the user has any of these sets of privileges. For example: + ({Privilege.ALL_PRIVILEGES}, {Privilege.USE_SCHEMA, Privilege.CREATE_TABLE}) + In this case, the user would need either ALL_PRIVILEGES, + or both USE_SCHEMA and CREATE_TABLE. + """ + if user_name == schema.owner: + return True + + return any( + self.has_privileges(user_name, SecurableType.SCHEMA, schema.full_name, privilege_set) + for privilege_set in privilege_sets + ) + + def has_volume_access( + self, + volume: VolumeInfo, + user_name: str, + privilege_sets: tuple[set[Privilege], ...], + ) -> bool: + """ + Check if a user has access to a volume based on ownership or a set of privileges. + :param volume: A volume to check access for. + :param user_name: Username to check. + :param privilege_sets: The function checks if the user has any of these sets of privileges. For example: + ({Privilege.ALL_PRIVILEGES}, {Privilege.READ_VOLUME, Privilege.WRITE_VOLUME}) + In this case, the user would need either ALL_PRIVILEGES, + or both READ_VOLUME and WRITE_VOLUME. + """ + if user_name == volume.owner: + return True + + return any( + self.has_privileges(user_name, SecurableType.VOLUME, volume.full_name, privilege_set) + for privilege_set in privilege_sets + ) + + def has_privileges( + self, + user: str | None, + securable_type: SecurableType, + full_name: str | None, + privileges: set[Privilege], + ) -> bool: + """ + Check if a user has a set of privileges for a securable object. + """ + assert user, "User must be provided" + assert full_name, "Full name must be provided" + user_privileges = self._get_user_privileges(user, securable_type, full_name) + result = privileges.issubset(user_privileges) + if not result: + logger.debug(f"User {user} doesn't have privilege set {privileges} for {securable_type} {full_name}") + return result + + @functools.lru_cache(maxsize=1024) + def _get_user_privileges(self, user: str, securable_type: SecurableType, full_name: str) -> set[Privilege]: + permissions = self._ws.grants.get_effective(securable_type, full_name, principal=user) + if not permissions or not permissions.privilege_assignments: + return set() + return { + p.privilege + for p in chain.from_iterable( + privilege.privileges for privilege in permissions.privilege_assignments if privilege.privileges + ) + if p.privilege + } diff --git a/src/databricks/labs/remorph/helpers/recon_config_utils.py b/src/databricks/labs/remorph/helpers/recon_config_utils.py new file mode 100644 index 0000000000..82a5a7015c --- /dev/null +++ b/src/databricks/labs/remorph/helpers/recon_config_utils.py @@ -0,0 +1,174 @@ +import logging + +from databricks.labs.blueprint.tui import Prompts +from databricks.labs.remorph.reconcile.constants import ReconSourceType +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors.platform import ResourceDoesNotExist + +logger = logging.getLogger(__name__) + + +class ReconConfigPrompts: + def __init__(self, ws: WorkspaceClient, prompts: Prompts = Prompts()): + self._source = None + self._prompts = prompts + self._ws = ws + + def _scope_exists(self, scope_name: str) -> bool: + scope_exists = scope_name in [scope.name for scope in self._ws.secrets.list_scopes()] + + if not scope_exists: + logger.error( + f"Error: Cannot find Secret Scope: `{scope_name}` in Databricks Workspace." 
+ f"\nUse `remorph configure-secrets` to setup Scope and Secrets" + ) + return False + logger.debug(f"Found Scope: `{scope_name}` in Databricks Workspace") + return True + + def _ensure_scope_exists(self, scope_name: str): + """ + Get or Create a new Scope in Databricks Workspace + :param scope_name: + """ + scope_exists = self._scope_exists(scope_name) + if not scope_exists: + allow_scope_creation = self._prompts.confirm("Do you want to create a new one?") + if not allow_scope_creation: + msg = "Scope is needed to store Secrets in Databricks Workspace" + raise SystemExit(msg) + + try: + logger.debug(f" Creating a new Scope: `{scope_name}`") + self._ws.secrets.create_scope(scope_name) + except Exception as ex: + logger.error(f"Exception while creating Scope `{scope_name}`: {ex}") + raise ex + + logger.info(f" Created a new Scope: `{scope_name}`") + logger.info(f" Using Scope: `{scope_name}`...") + + def _secret_key_exists(self, scope_name: str, secret_key: str) -> bool: + try: + self._ws.secrets.get_secret(scope_name, secret_key) + logger.info(f"Found Secret key `{secret_key}` in Scope `{scope_name}`") + return True + except ResourceDoesNotExist: + logger.debug(f"Secret key `{secret_key}` not found in Scope `{scope_name}`") + return False + + def _store_secret(self, scope_name: str, secret_key: str, secret_value: str): + try: + logger.debug(f"Storing Secret: *{secret_key}* in Scope: `{scope_name}`") + self._ws.secrets.put_secret(scope=scope_name, key=secret_key, string_value=secret_value) + except Exception as ex: + logger.error(f"Exception while storing Secret `{secret_key}`: {ex}") + raise ex + + def store_connection_secrets(self, scope_name: str, conn_details: tuple[str, dict[str, str]]): + engine = conn_details[0] + secrets = conn_details[1] + + logger.debug(f"Storing `{engine}` Connection Secrets in Scope: `{scope_name}`") + + for key, value in secrets.items(): + secret_key = key + logger.debug(f"Processing Secret: *{secret_key}*") + debug_op = "Storing" + info_op = "Stored" + if self._secret_key_exists(scope_name, secret_key): + overwrite_secret = self._prompts.confirm(f"Do you want to overwrite `{secret_key}`?") + if not overwrite_secret: + continue + debug_op = "Overwriting" + info_op = "Overwritten" + + logger.debug(f"{debug_op} Secret: *{secret_key}* in Scope: `{scope_name}`") + self._store_secret(scope_name, secret_key, value) + logger.info(f"{info_op} Secret: *{secret_key}* in Scope: `{scope_name}`") + + def prompt_source(self): + source = self._prompts.choice("Select the source", [source_type.value for source_type in ReconSourceType]) + self._source = source + return source + + def _prompt_snowflake_connection_details(self) -> tuple[str, dict[str, str]]: + """ + Prompt for Snowflake connection details + :return: tuple[str, dict[str, str]] + """ + logger.info( + f"Please answer a couple of questions to configure `{ReconSourceType.SNOWFLAKE.value}` Connection profile" + ) + + sf_url = self._prompts.question("Enter Snowflake URL") + account = self._prompts.question("Enter Account Name") + sf_user = self._prompts.question("Enter User") + sf_password = self._prompts.question("Enter Password") + sf_db = self._prompts.question("Enter Database") + sf_schema = self._prompts.question("Enter Schema") + sf_warehouse = self._prompts.question("Enter Snowflake Warehouse") + sf_role = self._prompts.question("Enter Role", default=" ") + + sf_conn_details = { + "sfUrl": sf_url, + "account": account, + "sfUser": sf_user, + "sfPassword": sf_password, + "sfDatabase": sf_db, + "sfSchema": 
sf_schema, + "sfWarehouse": sf_warehouse, + "sfRole": sf_role, + } + + sf_conn_dict = (ReconSourceType.SNOWFLAKE.value, sf_conn_details) + return sf_conn_dict + + def _prompt_oracle_connection_details(self) -> tuple[str, dict[str, str]]: + """ + Prompt for Oracle connection details + :return: tuple[str, dict[str, str]] + """ + logger.info( + f"Please answer a couple of questions to configure `{ReconSourceType.ORACLE.value}` Connection profile" + ) + user = self._prompts.question("Enter User") + password = self._prompts.question("Enter Password") + host = self._prompts.question("Enter host") + port = self._prompts.question("Enter port") + database = self._prompts.question("Enter database/SID") + + oracle_conn_details = {"user": user, "password": password, "host": host, "port": port, "database": database} + + oracle_conn_dict = (ReconSourceType.ORACLE.value, oracle_conn_details) + return oracle_conn_dict + + def _connection_details(self): + """ + Prompt for connection details based on the source + :return: None + """ + logger.debug(f"Prompting for `{self._source}` connection details") + match self._source: + case ReconSourceType.SNOWFLAKE.value: + return self._prompt_snowflake_connection_details() + case ReconSourceType.ORACLE.value: + return self._prompt_oracle_connection_details() + + def prompt_and_save_connection_details(self): + """ + Prompt for connection details and save them as Secrets in Databricks Workspace + """ + # prompt for connection_details only if source is other than Databricks + if self._source == ReconSourceType.DATABRICKS.value: + logger.info("*Databricks* as a source is supported only for **Hive MetaStore (HMS) setup**") + return + + # Prompt for secret scope + scope_name = self._prompts.question("Enter Secret Scope name") + self._ensure_scope_exists(scope_name) + + # Prompt for connection details + connection_details = self._connection_details() + logger.debug(f"Storing `{self._source}` connection details as Secrets in Databricks Workspace...") + self.store_connection_secrets(scope_name, connection_details) diff --git a/src/databricks/labs/remorph/helpers/validation.py b/src/databricks/labs/remorph/helpers/validation.py new file mode 100644 index 0000000000..af6d473fdc --- /dev/null +++ b/src/databricks/labs/remorph/helpers/validation.py @@ -0,0 +1,101 @@ +import logging +from io import StringIO + +from databricks.labs.lsql.backends import SqlBackend +from databricks.labs.remorph.config import TranspileConfig, ValidationResult +from databricks.sdk.errors.base import DatabricksError + +logger = logging.getLogger(__name__) + + +class Validator: + """ + The Validator class is used to validate SQL queries. + """ + + def __init__(self, sql_backend: SqlBackend): + self._sql_backend = sql_backend + + def validate_format_result(self, config: TranspileConfig, input_sql: str) -> ValidationResult: + """ + Validates the SQL query and formats the result. + + This function validates the SQL query based on the provided configuration. If the query is valid, + it appends a semicolon to the end of the query. If the query is not valid, it formats the error message. + + Parameters: + - config (MorphConfig): The configuration for the validation. + - input_sql (str): The SQL query to be validated. + + Returns: + - tuple: A tuple containing the result of the validation and the exception message (if any). 
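Connection details travel as an `(engine, secrets)` tuple, and each dictionary key becomes one secret in the chosen scope. A hedged sketch of storing a Snowflake profile; every value below is a placeholder, and the call prompts before overwriting an existing key:

```python
from databricks.labs.blueprint.tui import Prompts
from databricks.sdk import WorkspaceClient

from databricks.labs.remorph.helpers.recon_config_utils import ReconConfigPrompts
from databricks.labs.remorph.reconcile.constants import ReconSourceType

snowflake_conn = (
    ReconSourceType.SNOWFLAKE.value,
    {
        "sfUrl": "https://example.snowflakecomputing.com",  # placeholder values only
        "account": "example-account",
        "sfUser": "recon_user",
        "sfPassword": "********",
        "sfDatabase": "SALES",
        "sfSchema": "PUBLIC",
        "sfWarehouse": "COMPUTE_WH",
        "sfRole": "ANALYST",
    },
)

prompts = ReconConfigPrompts(WorkspaceClient(), Prompts())
prompts.store_connection_secrets("remorph_snowflake", snowflake_conn)
```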
+ """ + logger.debug(f"Validating query with catalog {config.catalog_name} and schema {config.schema_name}") + (is_valid, exception_type, exception_msg) = self._query( + self._sql_backend, + input_sql, + config.catalog_name, + config.schema_name, + ) + if is_valid: + result = input_sql + "\n;\n" + if exception_type is not None: + exception_msg = f"[{exception_type.upper()}]: {exception_msg}" + else: + query = "" + if "[UNRESOLVED_ROUTINE]" in str(exception_msg): + query = input_sql + buffer = StringIO() + buffer.write("-------------- Exception Start-------------------\n") + buffer.write("/* \n") + buffer.write(str(exception_msg)) + buffer.write("\n */ \n") + buffer.write(query) + buffer.write("\n ---------------Exception End --------------------\n") + + result = buffer.getvalue() + + return ValidationResult(result, exception_msg) + + def _query( + self, sql_backend: SqlBackend, query: str, catalog: str, schema: str + ) -> tuple[bool, str | None, str | None]: + """ + Validate a given SQL query using the provided SQL backend + + Parameters: + - query (str): The SQL query to be validated. + - sql_backend (SqlBackend): The SQL backend to be used for validation. + + Returns: + - tuple: A tuple containing a boolean indicating whether the query is valid or not, + and a string containing a success message or an exception message. + """ + # When variables is mentioned Explain fails we need way to replace them before explain is executed. + explain_query = f'EXPLAIN {query.replace("${", "`{").replace("}", "}`").replace("``", "`")}' + try: + rows = list(sql_backend.fetch(explain_query, catalog=catalog, schema=schema)) + if not rows: + return False, "error", "No results returned from explain query." + + if "Error occurred during query planning" in rows[0].asDict().get("plan", ""): + error_details = rows[1].asDict().get("plan", "Unknown error.") if len(rows) > 1 else "Unknown error." + raise DatabricksError(error_details) + return True, None, None + except DatabricksError as dbe: + err_msg = str(dbe) + if "[PARSE_SYNTAX_ERROR]" in err_msg: + logger.debug(f"Syntax Exception : NOT IGNORED. 
Flag as syntax error: {err_msg}") + return False, "error", err_msg + if "[UNRESOLVED_ROUTINE]" in err_msg: + logger.debug(f"Analysis Exception : NOT IGNORED: Flag as Function Missing error {err_msg}") + return False, "error", err_msg + if "[TABLE_OR_VIEW_NOT_FOUND]" in err_msg or "[TABLE_OR_VIEW_ALREADY_EXISTS]" in err_msg: + logger.debug(f"Analysis Exception : IGNORED: {err_msg}") + return True, "warning", err_msg + if "Hive support is required to CREATE Hive TABLE (AS SELECT).;" in err_msg: + logger.debug(f"Analysis Exception : IGNORED: {err_msg}") + return True, "warning", err_msg + + logger.debug(f"Unknown Exception: {err_msg}") + return False, "error", err_msg diff --git a/src/databricks/labs/remorph/install.py b/src/databricks/labs/remorph/install.py new file mode 100644 index 0000000000..3402ee6a80 --- /dev/null +++ b/src/databricks/labs/remorph/install.py @@ -0,0 +1,299 @@ +import dataclasses +import logging +import os +import webbrowser + +from databricks.labs.blueprint.entrypoint import get_logger, is_in_debug +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installation import SerdeError +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.tui import Prompts +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, PermissionDenied + +from databricks.labs.remorph.__about__ import __version__ +from databricks.labs.remorph.config import ( + TranspileConfig, + ReconcileConfig, + SQLGLOT_DIALECTS, + DatabaseConfig, + RemorphConfigs, + ReconcileMetadataConfig, +) +from databricks.labs.remorph.contexts.application import ApplicationContext +from databricks.labs.remorph.deployment.configurator import ResourceConfigurator +from databricks.labs.remorph.deployment.installation import WorkspaceInstallation +from databricks.labs.remorph.reconcile.constants import ReconReportType, ReconSourceType + +logger = logging.getLogger(__name__) + +TRANSPILER_WAREHOUSE_PREFIX = "Remorph Transpiler Validation" +MODULES = sorted({"transpile", "reconcile", "all"}) + + +class WorkspaceInstaller: + def __init__( + self, + ws: WorkspaceClient, + prompts: Prompts, + installation: Installation, + install_state: InstallState, + product_info: ProductInfo, + resource_configurator: ResourceConfigurator, + workspace_installation: WorkspaceInstallation, + environ: dict[str, str] | None = None, + ): + self._ws = ws + self._prompts = prompts + self._installation = installation + self._install_state = install_state + self._product_info = product_info + self._resource_configurator = resource_configurator + self._ws_installation = workspace_installation + + if not environ: + environ = dict(os.environ.items()) + + if "DATABRICKS_RUNTIME_VERSION" in environ: + msg = "WorkspaceInstaller is not supposed to be executed in Databricks Runtime" + raise SystemExit(msg) + + def run( + self, + config: RemorphConfigs | None = None, + ) -> RemorphConfigs: + logger.info(f"Installing Remorph v{self._product_info.version()}") + if not config: + config = self.configure() + if self._is_testing(): + return config + self._ws_installation.install(config) + logger.info("Installation completed successfully! 
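The `${...}` rewrite in `_query` deserves a concrete example: widget-style variables would make `EXPLAIN` fail, so they are converted to backtick-quoted placeholders before validation:

```python
query = "SELECT * FROM ${schema}.orders WHERE region = ${region}"

# Same replace chain as _query: ${var} becomes `{var}` so query planning can proceed.
explain_query = f'EXPLAIN {query.replace("${", "`{").replace("}", "}`").replace("``", "`")}'
assert explain_query == "EXPLAIN SELECT * FROM `{schema}`.orders WHERE region = `{region}`"
```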
Please refer to the documentation for the next steps.") + return config + + def configure(self, module: str | None = None) -> RemorphConfigs: + selected_module = module or self._prompts.choice("Select a module to configure:", MODULES) + match selected_module: + case "transpile": + logger.info("Configuring remorph `transpile`.") + return RemorphConfigs(self._configure_transpile(), None) + case "reconcile": + logger.info("Configuring remorph `reconcile`.") + return RemorphConfigs(None, self._configure_reconcile()) + case "all": + logger.info("Configuring remorph `transpile` and `reconcile`.") + return RemorphConfigs( + self._configure_transpile(), + self._configure_reconcile(), + ) + case _: + raise ValueError(f"Invalid input: {selected_module}") + + def _is_testing(self): + return self._product_info.product_name() != "remorph" + + def _configure_transpile(self) -> TranspileConfig: + try: + self._installation.load(TranspileConfig) + logger.info("Remorph `transpile` is already installed on this workspace.") + if not self._prompts.confirm("Do you want to override the existing installation?"): + raise SystemExit( + "Remorph `transpile` is already installed and no override has been requested. Exiting..." + ) + except NotFound: + logger.info("Couldn't find existing `transpile` installation") + except (PermissionDenied, SerdeError, ValueError, AttributeError): + install_dir = self._installation.install_folder() + logger.warning( + f"Existing `transpile` installation at {install_dir} is corrupted. Continuing new installation..." + ) + + config = self._configure_new_transpile_installation() + logger.info("Finished configuring remorph `transpile`.") + return config + + def _configure_new_transpile_installation(self) -> TranspileConfig: + default_config = self._prompt_for_new_transpile_installation() + runtime_config = None + catalog_name = "remorph" + schema_name = "transpiler" + if not default_config.skip_validation: + catalog_name = self._configure_catalog() + schema_name = self._configure_schema(catalog_name, "transpile") + self._has_necessary_access(catalog_name, schema_name) + runtime_config = self._configure_runtime() + + config = dataclasses.replace( + default_config, + catalog_name=catalog_name, + schema_name=schema_name, + sdk_config=runtime_config, + ) + self._save_config(config) + return config + + def _prompt_for_new_transpile_installation(self) -> TranspileConfig: + logger.info("Please answer a few questions to configure remorph `transpile`") + source = self._prompts.choice("Select the source:", list(SQLGLOT_DIALECTS.keys())) + input_sql = self._prompts.question("Enter input SQL path (directory/file)") + output_folder = self._prompts.question("Enter output directory", default="transpiled") + run_validation = self._prompts.confirm( + "Would you like to validate the syntax and semantics of the transpiled queries?" 
+ ) + + return TranspileConfig( + source_dialect=source, + skip_validation=(not run_validation), + mode="current", # mode will not have a prompt as this is a hidden flag + input_source=input_sql, + output_folder=output_folder, + ) + + def _configure_catalog( + self, + ) -> str: + return self._resource_configurator.prompt_for_catalog_setup() + + def _configure_schema( + self, + catalog: str, + default_schema_name: str, + ) -> str: + return self._resource_configurator.prompt_for_schema_setup( + catalog, + default_schema_name, + ) + + def _configure_runtime(self) -> dict[str, str]: + if self._prompts.confirm("Do you want to use SQL Warehouse for validation?"): + warehouse_id = self._resource_configurator.prompt_for_warehouse_setup(TRANSPILER_WAREHOUSE_PREFIX) + return {"warehouse_id": warehouse_id} + + if self._ws.config.cluster_id: + logger.info(f"Using cluster {self._ws.config.cluster_id} for validation") + return {"cluster_id": self._ws.config.cluster_id} + + cluster_id = self._prompts.question("Enter a valid cluster_id to proceed") + return {"cluster_id": cluster_id} + + def _configure_reconcile(self) -> ReconcileConfig: + try: + self._installation.load(ReconcileConfig) + logger.info("Remorph `reconcile` is already installed on this workspace.") + if not self._prompts.confirm("Do you want to override the existing installation?"): + raise SystemExit( + "Remorph `reconcile` is already installed and no override has been requested. Exiting..." + ) + except NotFound: + logger.info("Couldn't find existing `reconcile` installation") + except (PermissionDenied, SerdeError, ValueError, AttributeError): + install_dir = self._installation.install_folder() + logger.warning( + f"Existing `reconcile` installation at {install_dir} is corrupted. Continuing new installation..." 
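The saved `sdk_config` for transpile validation carries either a `warehouse_id` or a `cluster_id`, never both. A small sketch of that reduction; `runtime_config` is an illustrative stand-in for `_configure_runtime`, with the interactive prompt replaced by a placeholder:

```python
def runtime_config(use_warehouse: bool, warehouse_id: str, cluster_id: str | None) -> dict[str, str]:
    """Illustrative reduction of WorkspaceInstaller._configure_runtime."""
    if use_warehouse:
        return {"warehouse_id": warehouse_id}
    if cluster_id:
        return {"cluster_id": cluster_id}
    return {"cluster_id": "<prompted interactively>"}  # placeholder for the prompt fallback

assert runtime_config(True, "abc123", None) == {"warehouse_id": "abc123"}
assert runtime_config(False, "abc123", "0717-cluster") == {"cluster_id": "0717-cluster"}
```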
+ ) + + config = self._configure_new_reconcile_installation() + logger.info("Finished configuring remorph `reconcile`.") + return config + + def _configure_new_reconcile_installation(self) -> ReconcileConfig: + default_config = self._prompt_for_new_reconcile_installation() + self._save_config(default_config) + return default_config + + def _prompt_for_new_reconcile_installation(self) -> ReconcileConfig: + logger.info("Please answer a few questions to configure remorph `reconcile`") + data_source = self._prompts.choice( + "Select the Data Source:", [source_type.value for source_type in ReconSourceType] + ) + report_type = self._prompts.choice( + "Select the report type:", [report_type.value for report_type in ReconReportType] + ) + scope_name = self._prompts.question( + f"Enter Secret scope name to store `{data_source.capitalize()}` connection details / secrets", + default=f"remorph_{data_source}", + ) + + db_config = self._prompt_for_reconcile_database_config(data_source) + metadata_config = self._prompt_for_reconcile_metadata_config() + + return ReconcileConfig( + data_source=data_source, + report_type=report_type, + secret_scope=scope_name, + database_config=db_config, + metadata_config=metadata_config, + ) + + def _prompt_for_reconcile_database_config(self, source) -> DatabaseConfig: + source_catalog = None + if source == ReconSourceType.SNOWFLAKE.value: + source_catalog = self._prompts.question(f"Enter source catalog name for `{source.capitalize()}`") + + schema_prompt = f"Enter source schema name for `{source.capitalize()}`" + if source in {ReconSourceType.ORACLE.value}: + schema_prompt = f"Enter source database name for `{source.capitalize()}`" + + source_schema = self._prompts.question(schema_prompt) + target_catalog = self._prompts.question("Enter target catalog name for Databricks") + target_schema = self._prompts.question("Enter target schema name for Databricks") + + return DatabaseConfig( + source_schema=source_schema, + target_catalog=target_catalog, + target_schema=target_schema, + source_catalog=source_catalog, + ) + + def _prompt_for_reconcile_metadata_config(self) -> ReconcileMetadataConfig: + logger.info("Configuring reconcile metadata.") + catalog = self._configure_catalog() + schema = self._configure_schema( + catalog, + "reconcile", + ) + volume = self._configure_volume(catalog, schema, "reconcile_volume") + self._has_necessary_access(catalog, schema, volume) + return ReconcileMetadataConfig(catalog=catalog, schema=schema, volume=volume) + + def _configure_volume( + self, + catalog: str, + schema: str, + default_volume_name: str, + ) -> str: + return self._resource_configurator.prompt_for_volume_setup( + catalog, + schema, + default_volume_name, + ) + + def _save_config(self, config: TranspileConfig | ReconcileConfig): + logger.info(f"Saving configuration file {config.__file__}") + self._installation.save(config) + ws_file_url = self._installation.workspace_link(config.__file__) + if self._prompts.confirm(f"Open config file {ws_file_url} in the browser?"): + webbrowser.open(ws_file_url) + + def _has_necessary_access(self, catalog_name: str, schema_name: str, volume_name: str | None = None): + self._resource_configurator.has_necessary_access(catalog_name, schema_name, volume_name) + + +if __name__ == "__main__": + logger = get_logger(__file__) + logger.setLevel("INFO") + if is_in_debug(): + logging.getLogger("databricks").setLevel(logging.DEBUG) + + app_context = ApplicationContext(WorkspaceClient(product="remorph", product_version=__version__)) + installer = 
WorkspaceInstaller( + app_context.workspace_client, + app_context.prompts, + app_context.installation, + app_context.install_state, + app_context.product_info, + app_context.resource_configurator, + app_context.workspace_installation, + ) + installer.run() diff --git a/src/databricks/labs/remorph/intermediate/__init__.py b/src/databricks/labs/remorph/intermediate/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/intermediate/dag.py b/src/databricks/labs/remorph/intermediate/dag.py new file mode 100644 index 0000000000..18c4cfd822 --- /dev/null +++ b/src/databricks/labs/remorph/intermediate/dag.py @@ -0,0 +1,88 @@ +import logging + +logger = logging.getLogger(__name__) + + +class Node: + def __init__(self, name: str): + self.name = name.lower() + self.children: list[str] = [] + self.parents: list[str] = [] + + def add_parent(self, node: str) -> None: + self.parents.append(node) + + def add_child(self, node: str) -> None: + self.children.append(node) + + def __repr__(self) -> str: + return f"Node({self.name}, {self.children})" + + +class DAG: + def __init__(self): + self.nodes: dict[str, Node] = {} + + def add_node(self, node_name: str) -> None: + if node_name not in self.nodes and node_name not in {None, "none"}: + self.nodes[node_name.lower()] = Node(node_name.lower()) + + def add_edge(self, parent_name: str, child_name: str) -> None: + parent_name = parent_name.lower() if parent_name is not None else None + child_name = child_name.lower() if child_name is not None else None + logger.debug(f"Adding edge: {parent_name} -> {child_name}") + if parent_name not in self.nodes: + self.add_node(parent_name) + if child_name not in self.nodes: + self.add_node(child_name) + + if child_name is not None: + self.nodes[parent_name].add_child(child_name) + self.nodes[child_name].add_parent(parent_name) + + def identify_immediate_parents(self, table_name: str) -> list[str]: + table_name = table_name.lower() # convert to lower() case + if table_name not in self.nodes: + logger.debug(f"Table with the name {table_name} not found in the DAG") + return [] + + return list(self.nodes[table_name].parents) + + def identify_immediate_children(self, table_name: str) -> list[str]: + table_name = table_name.lower() # convert to lower() case + if table_name not in self.nodes: + logger.debug(f"Table with the name {table_name} not found in the DAG") + return [] + + return list(self.nodes[table_name].children) + + def _is_root_node(self, node_name: str) -> bool: + return len(self.identify_immediate_parents(node_name)) == 0 + + def walk_bfs(self, node: Node, level: int) -> set: + tables_at_level = set() + queue = [(node, 0)] # The queue for the BFS. Each element is a tuple (node, level). 
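+        # Illustrative note (hypothetical node names): because children are enqueued with
+        # node_level + 1, walk_bfs(self.nodes["a"], 1) on a DAG with edges a -> b and a -> c
+        # returns {"b", "c"}, while level 0 returns {"a"} itself; names are stored
+        # lower-cased by Node/add_node.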
+ while queue: + current_node, node_level = queue.pop(0) + + if node_level == level: + tables_at_level.add(current_node.name) + elif node_level > level: + break + + for child_name in self.identify_immediate_children(current_node.name): + queue.append((self.nodes[child_name], node_level + 1)) + return tables_at_level + + def identify_root_tables(self, level: int) -> set: + all_nodes = set(self.nodes.values()) + root_tables_at_level = set() + + for node in all_nodes: + if self._is_root_node(node.name): + root_tables_at_level.update(self.walk_bfs(node, level)) + + return root_tables_at_level + + def __repr__(self) -> str: + return str({node_name: str(node) for node_name, node in self.nodes.items()}) diff --git a/src/databricks/labs/remorph/intermediate/engine_adapter.py b/src/databricks/labs/remorph/intermediate/engine_adapter.py new file mode 100644 index 0000000000..3cb7c16339 --- /dev/null +++ b/src/databricks/labs/remorph/intermediate/engine_adapter.py @@ -0,0 +1,27 @@ +import logging +from pathlib import Path + +from sqlglot.dialects.dialect import Dialect + +from databricks.labs.remorph.transpiler.sqlglot.sqlglot_engine import SqlglotEngine + +logger = logging.getLogger(__name__) + + +class EngineAdapter: + def __init__(self, dialect: Dialect): + self.dialect = dialect + + def select_engine(self, input_type: str): + if input_type.lower() not in {"sqlglot"}: + msg = f"Unsupported input type: {input_type}" + logger.error(msg) + raise ValueError(msg) + return SqlglotEngine(self.dialect) + + def parse_sql_content(self, dag, sql_content: str, file_name: str | Path, engine: str): + # Not added type hints for dag as it is a cyclic import + parser = self.select_engine(engine) + for root_table, child in parser.parse_sql_content(sql_content, file_name): + dag.add_node(child) + dag.add_edge(root_table, child) diff --git a/src/databricks/labs/remorph/intermediate/root_tables.py b/src/databricks/labs/remorph/intermediate/root_tables.py new file mode 100644 index 0000000000..25d639a60f --- /dev/null +++ b/src/databricks/labs/remorph/intermediate/root_tables.py @@ -0,0 +1,39 @@ +import logging +from pathlib import Path + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.helpers.file_utils import ( + get_sql_file, + is_sql_file, + read_file, +) +from databricks.labs.remorph.intermediate.dag import DAG +from databricks.labs.remorph.intermediate.engine_adapter import EngineAdapter + +logger = logging.getLogger(__name__) + + +class RootTableIdentifier: + def __init__(self, source_dialect: str, input_path: str | Path): + self.source_dialect = source_dialect + self.input_path = input_path + self.engine_adapter = EngineAdapter(get_dialect(source_dialect)) + + def generate_lineage(self, engine="sqlglot") -> DAG: + dag = DAG() + + # when input is sql file then parse the file + if is_sql_file(self.input_path): + filename = self.input_path + logger.debug(f"Generating Lineage file: {filename}") + sql_content = read_file(filename) + self.engine_adapter.parse_sql_content(dag, sql_content, filename, engine) + return dag # return after processing the file + + # when the input is a directory + for filename in get_sql_file(self.input_path): + logger.debug(f"Generating Lineage file: {filename}") + sql_content = read_file(filename) + self.engine_adapter.parse_sql_content(dag, sql_content, filename, engine) + + return dag diff --git a/src/databricks/labs/remorph/jvmproxy.py b/src/databricks/labs/remorph/jvmproxy.py new file mode 100644 index 0000000000..2887390914 --- /dev/null +++ 
b/src/databricks/labs/remorph/jvmproxy.py @@ -0,0 +1,55 @@ +import logging +import os +import sys +import subprocess + +from databricks.labs.blueprint.entrypoint import find_project_root +from databricks.labs.blueprint.cli import App + + +def proxy_command(app: App, command: str): + def fn(**_): + proxy = JvmProxy() + proxy.run() + + fn.__name__ = command + fn.__doc__ = f"Proxy to run {command} in JVM" + app.command(fn) + + +class JvmProxy: + def __init__(self): + self._root = find_project_root(__file__) + databricks_logger = logging.getLogger("databricks") + self._debug = databricks_logger.level == logging.DEBUG + + def _recompile(self): + subprocess.run( + ["mvn", "compile", "-f", f'{self._root}/pom.xml'], + stdout=sys.stdout, + stderr=sys.stderr, + check=True, + ) + + def run(self): + if self._debug: + self._recompile() + classpath = self._root / 'core/target/classpath.txt' + classes = self._root / 'core/target/scala-2.12/classes' + # TODO: use the os-specific path separator + args = [ + "java", + "--class-path", + f'{classes.as_posix()}:{classpath.read_text()}', + "com.databricks.labs.remorph.Main", + sys.argv[1], + ] + with subprocess.Popen( + args, + stdin=sys.stdin, + stdout=sys.stdout, + stderr=sys.stderr, + env=os.environ.copy(), + text=True, + ) as process: + return process.wait() diff --git a/src/databricks/labs/remorph/lineage.py b/src/databricks/labs/remorph/lineage.py new file mode 100644 index 0000000000..e32630d79c --- /dev/null +++ b/src/databricks/labs/remorph/lineage.py @@ -0,0 +1,41 @@ +import datetime +import logging +from pathlib import Path + +from databricks.labs.remorph.intermediate.dag import DAG +from databricks.labs.remorph.intermediate.root_tables import RootTableIdentifier + +logger = logging.getLogger(__name__) + + +def _generate_dot_file_contents(dag: DAG) -> str: + _lineage_str = "flowchart TD\n" + for node_name, node in dag.nodes.items(): + if node.parents: + for parent in node.parents: + _lineage_str += f" {node_name.capitalize()} --> {parent.capitalize()}\n" + else: + # Include nodes without parents to ensure they appear in the diagram + _lineage_str += f" {node_name.capitalize()}\n" + return _lineage_str + + +def lineage_generator(source: str, input_sql: str, output_folder: str): + input_sql_path = Path(input_sql) + output_folder = output_folder if output_folder.endswith('/') else output_folder + '/' + + msg = f"Processing for SQLs at this location: {input_sql_path}" + logger.info(msg) + root_table_identifier = RootTableIdentifier(source, input_sql_path) + generated_dag = root_table_identifier.generate_lineage() + lineage_file_content = _generate_dot_file_contents(generated_dag) + + date_str = datetime.datetime.now().strftime("%d%m%y") + + output_filename = Path(f"{output_folder}lineage_{date_str}.dot") + if output_filename.exists(): + logger.warning(f'The output file already exists and will be replaced: {output_filename}') + logger.info(f"Attempting to write the lineage to {output_filename}") + with output_filename.open('w', encoding='utf-8') as f: + f.write(lineage_file_content) + logger.info(f"Succeeded to write the lineage to {output_filename}") diff --git a/src/databricks/labs/remorph/reconcile/__init__.py b/src/databricks/labs/remorph/reconcile/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/reconcile/compare.py b/src/databricks/labs/remorph/reconcile/compare.py new file mode 100644 index 0000000000..f719d4168d --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/compare.py @@ -0,0 +1,402 
@@ +import logging +from functools import reduce +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import col, expr, lit + +from databricks.labs.remorph.reconcile.exception import ColumnMismatchException +from databricks.labs.remorph.reconcile.recon_capture import ( + ReconIntermediatePersist, +) +from databricks.labs.remorph.reconcile.recon_config import ( + DataReconcileOutput, + MismatchOutput, + AggregateRule, + ColumnMapping, +) + +logger = logging.getLogger(__name__) + +_HASH_COLUMN_NAME = "hash_value_recon" +_SAMPLE_ROWS = 50 + + +def raise_column_mismatch_exception(msg: str, source_missing: list[str], target_missing: list[str]) -> Exception: + error_msg = ( + f"{msg}\n" + f"columns missing in source: {','.join(source_missing) if source_missing else None}\n" + f"columns missing in target: {','.join(target_missing) if target_missing else None}\n" + ) + return ColumnMismatchException(error_msg) + + +def _generate_join_condition(source_alias, target_alias, key_columns): + conditions = [ + col(f"{source_alias}.{key_column}").eqNullSafe(col(f"{target_alias}.{key_column}")) + for key_column in key_columns + ] + return reduce(lambda a, b: a & b, conditions) + + +def reconcile_data( + source: DataFrame, + target: DataFrame, + key_columns: list[str], + report_type: str, + spark: SparkSession, + path: str, +) -> DataReconcileOutput: + source_alias = "src" + target_alias = "tgt" + if report_type not in {"data", "all"}: + key_columns = [_HASH_COLUMN_NAME] + df = ( + source.alias(source_alias) + .join( + other=target.alias(target_alias), + on=_generate_join_condition(source_alias, target_alias, key_columns), + how="full", + ) + .selectExpr( + *[f'{source_alias}.{col_name} as {source_alias}_{col_name}' for col_name in source.columns], + *[f'{target_alias}.{col_name} as {target_alias}_{col_name}' for col_name in target.columns], + ) + ) + + # Write unmatched df to volume + df = ReconIntermediatePersist(spark, path).write_and_read_unmatched_df_with_volumes(df) + logger.warning(f"Unmatched data is written to {path} successfully") + + mismatch = _get_mismatch_data(df, source_alias, target_alias) if report_type in {"all", "data"} else None + + missing_in_src = ( + df.filter(col(f"{source_alias}_{_HASH_COLUMN_NAME}").isNull()) + .select( + *[ + col(col_name).alias(col_name.replace(f'{target_alias}_', '').lower()) + for col_name in df.columns + if col_name.startswith(f'{target_alias}_') + ] + ) + .drop(_HASH_COLUMN_NAME) + ) + + missing_in_tgt = ( + df.filter(col(f"{target_alias}_{_HASH_COLUMN_NAME}").isNull()) + .select( + *[ + col(col_name).alias(col_name.replace(f'{source_alias}_', '').lower()) + for col_name in df.columns + if col_name.startswith(f'{source_alias}_') + ] + ) + .drop(_HASH_COLUMN_NAME) + ) + mismatch_count = 0 + if mismatch: + mismatch_count = mismatch.count() + + missing_in_src_count = missing_in_src.count() + missing_in_tgt_count = missing_in_tgt.count() + + return DataReconcileOutput( + mismatch_count=mismatch_count, + missing_in_src_count=missing_in_src_count, + missing_in_tgt_count=missing_in_tgt_count, + missing_in_src=missing_in_src.limit(_SAMPLE_ROWS), + missing_in_tgt=missing_in_tgt.limit(_SAMPLE_ROWS), + mismatch=MismatchOutput(mismatch_df=mismatch), + ) + + +def _get_mismatch_data(df: DataFrame, src_alias: str, tgt_alias: str) -> DataFrame: + return ( + df.filter( + (col(f"{src_alias}_{_HASH_COLUMN_NAME}").isNotNull()) + & (col(f"{tgt_alias}_{_HASH_COLUMN_NAME}").isNotNull()) + ) + .withColumn( + "hash_match", + 
col(f"{src_alias}_{_HASH_COLUMN_NAME}") == col(f"{tgt_alias}_{_HASH_COLUMN_NAME}"), + ) + .filter(col("hash_match") == lit(False)) + .select( + *[ + col(col_name).alias(col_name.replace(f'{src_alias}_', '').lower()) + for col_name in df.columns + if col_name.startswith(f'{src_alias}_') + ] + ) + .drop(_HASH_COLUMN_NAME) + ) + + +def capture_mismatch_data_and_columns(source: DataFrame, target: DataFrame, key_columns: list[str]) -> MismatchOutput: + source_columns = source.columns + target_columns = target.columns + + if source_columns != target_columns: + message = "source and target should have same columns for capturing the mismatch data" + source_missing = [column for column in target_columns if column not in source_columns] + target_missing = [column for column in source_columns if column not in target_columns] + raise raise_column_mismatch_exception(message, source_missing, target_missing) + + check_columns = [column for column in source_columns if column not in key_columns] + mismatch_df = _get_mismatch_df(source, target, key_columns, check_columns) + mismatch_columns = _get_mismatch_columns(mismatch_df, check_columns) + return MismatchOutput(mismatch_df, mismatch_columns) + + +def _get_mismatch_columns(df: DataFrame, columns: list[str]): + # Collect the DataFrame to a local variable + local_df = df.collect() + mismatch_columns = [] + for column in columns: + # Check if any row has False in the column + if any(not row[column + "_match"] for row in local_df): + mismatch_columns.append(column) + return mismatch_columns + + +def _get_mismatch_df(source: DataFrame, target: DataFrame, key_columns: list[str], column_list: list[str]): + source_aliased = [col('base.' + column).alias(column + '_base') for column in column_list] + target_aliased = [col('compare.' + column).alias(column + '_compare') for column in column_list] + + match_expr = [expr(f"{column}_base=={column}_compare").alias(column + "_match") for column in column_list] + key_cols = [col(column) for column in key_columns] + select_expr = key_cols + source_aliased + target_aliased + match_expr + + filter_columns = " and ".join([column + "_match" for column in column_list]) + filter_expr = ~expr(filter_columns) + + mismatch_df = ( + source.alias('base') + .join(other=target.alias('compare'), on=key_columns, how="inner") + .select(*select_expr) + .filter(filter_expr) + ) + compare_columns = [column for column in mismatch_df.columns if column not in key_columns] + return mismatch_df.select(*key_columns + sorted(compare_columns)) + + +def alias_column_str(alias: str, columns: list[str]) -> list[str]: + return [f"{alias}.{column}" for column in columns] + + +def _generate_agg_join_condition(source_alias: str, target_alias: str, key_columns: list[str]): + join_columns: list[ColumnMapping] = [ + ColumnMapping(source_name=f"source_group_by_{key_col}", target_name=f"target_group_by_{key_col}") + for key_col in key_columns + ] + conditions = [ + col(f"{source_alias}.{mapping.source_name}").eqNullSafe(col(f"{target_alias}.{mapping.target_name}")) + for mapping in join_columns + ] + return reduce(lambda a, b: a & b, conditions) + + +def _agg_conditions( + cols: list[ColumnMapping] | None, + condition_type: str = "group_filter", + op_type: str = "and", +): + """ + Generate conditions for aggregated data comparison based on the condition type + and reduces it based on the operator (and, or) + + e.g., cols = [(source_min_col1, target_min_col1)] + 1. 
condition_type = "group_filter" + source_group_by_col1 is not null and target_group_by_col1 is not null + 2. condition_type = "select" + source_min_col1 == target_min_col1 + 3. condition_type = "missing_in_src" + source_min_col1 is null + 4. condition_type = "missing_in_tgt" + target_min_col1 is null + + :param cols: List of columns to compare + :param condition_type: Type of condition to generate + :param op_type: and, or + :return: Reduced column expressions + """ + assert cols, "Columns must be specified for aggregation conditions" + + if condition_type == "group_filter": + conditions_list = [ + (col(f"{mapping.source_name}").isNotNull() & col(f"{mapping.target_name}").isNotNull()) for mapping in cols + ] + elif condition_type == "select": + conditions_list = [col(f"{mapping.source_name}") == col(f"{mapping.target_name}") for mapping in cols] + elif condition_type == "missing_in_src": + conditions_list = [col(f"{mapping.source_name}").isNull() for mapping in cols] + elif condition_type == "missing_in_tgt": + conditions_list = [col(f"{mapping.target_name}").isNull() for mapping in cols] + else: + raise ValueError(f"Invalid condition type: {condition_type}") + + return reduce(lambda a, b: a & b if op_type == "and" else a | b, conditions_list) + + +def _generate_match_columns(select_cols: list[ColumnMapping]): + """ + Generate match columns for the given select columns + e.g., select_cols = [(source_min_col1, target_min_col1), (source_count_col3, target_count_col3)] + |--------------------------------------|---------------------| + | match_min_col1 | match_count_col3 | + |--------------------------------------|--------------------| + source_min_col1 == target_min_col1 | source_count_col3 == target_count_col3 + --------------------------------------|---------------------| + + :param select_cols: + :return: + """ + items = [] + for mapping in select_cols: + match_col_name = mapping.source_name.replace("source_", "match_") + items.append((match_col_name, col(f"{mapping.source_name}") == col(f"{mapping.target_name}"))) + return items + + +def _get_mismatch_agg_data( + df: DataFrame, + select_cols: list[ColumnMapping], + group_cols: list[ColumnMapping] | None, +) -> DataFrame: + # TODO: Integrate with _get_mismatch_data function + """ + For each rule select columns, generate a match column to compare the aggregated data between Source and Target + + e.g., select_cols = [(source_min_col1, target_min_col1), (source_count_col3, target_count_col3)] + + source_min_col1 | target_min_col1 | match_min_col1 | agg_data_match | + -----------------|--------------------|----------------|-------------------| + 11 | 12 |source_min_col1 == target_min_col1 | False | + + :param df: Joined DataFrame with aggregated data from Source and Target + :param select_cols: Rule specific select columns + :param group_cols: Rule specific group by columns, if any + :return: DataFrame with match__ and agg_data_match columns + to identify the aggregate data mismatch between Source and Target + """ + df_with_match_cols = df + + if group_cols: + # Filter Conditions are in the format of: source_group_by_col1 is not null and target_group_by_col1 is not null + filter_conditions = _agg_conditions(group_cols) + df_with_match_cols = df_with_match_cols.filter(filter_conditions) + + # Generate match columns for the select columns. 
e.g., match__ + for match_column_name, match_column in _generate_match_columns(select_cols): + df_with_match_cols = df_with_match_cols.withColumn(match_column_name, match_column) + + # e.g., source_min_col1 == target_min_col1 and source_count_col3 == target_count_col3 + select_conditions = _agg_conditions(select_cols, "select") + + return df_with_match_cols.withColumn("agg_data_match", select_conditions).filter( + col("agg_data_match") == lit(False) + ) + + +def reconcile_agg_data_per_rule( + joined_df: DataFrame, + source_columns: list[str], + target_columns: list[str], + rule: AggregateRule, +) -> DataReconcileOutput: + """ " + Generates the reconciliation output for the given rule + """ + # Generates select columns in the format of: + # [(source_min_col1, target_min_col1), (source_count_col3, target_count_col3) ... ] + + rule_select_columns = [ + ColumnMapping( + source_name=f"source_{rule.agg_type}_{rule.agg_column}", + target_name=f"target_{rule.agg_type}_{rule.agg_column}", + ) + ] + + rule_group_columns = None + if rule.group_by_columns: + rule_group_columns = [ + ColumnMapping(source_name=f"source_group_by_{group_col}", target_name=f"target_group_by_{group_col}") + for group_col in rule.group_by_columns + ] + rule_select_columns.extend(rule_group_columns) + + df_rule_columns = [] + for mapping in rule_select_columns: + df_rule_columns.extend([mapping.source_name, mapping.target_name]) + + joined_df_with_rule_cols = joined_df.selectExpr(*df_rule_columns) + + # Data mismatch between Source and Target aggregated data + mismatch = _get_mismatch_agg_data(joined_df_with_rule_cols, rule_select_columns, rule_group_columns) + + # Data missing in Source DataFrame + rule_target_columns = set(target_columns).intersection([mapping.target_name for mapping in rule_select_columns]) + + missing_in_src = joined_df_with_rule_cols.filter(_agg_conditions(rule_select_columns, "missing_in_src")).select( + *rule_target_columns + ) + + # Data missing in Target DataFrame + rule_source_columns = set(source_columns).intersection([mapping.source_name for mapping in rule_select_columns]) + missing_in_tgt = joined_df_with_rule_cols.filter(_agg_conditions(rule_select_columns, "missing_in_tgt")).select( + *rule_source_columns + ) + + mismatch_count = 0 + if mismatch: + mismatch_count = mismatch.count() + + rule_reconcile_output = DataReconcileOutput( + mismatch_count=mismatch_count, + missing_in_src_count=missing_in_src.count(), + missing_in_tgt_count=missing_in_tgt.count(), + missing_in_src=missing_in_src.limit(_SAMPLE_ROWS), + missing_in_tgt=missing_in_tgt.limit(_SAMPLE_ROWS), + mismatch=MismatchOutput(mismatch_df=mismatch), + ) + + return rule_reconcile_output + + +def join_aggregate_data( + source: DataFrame, + target: DataFrame, + key_columns: list[str] | None, + spark: SparkSession, + path: str, +) -> DataFrame: + # TODO: Integrate with reconcile_data function + + source_alias = "src" + target_alias = "tgt" + + # Generates group by columns in the format of: + # [(source_group_by_col1, target_group_by_col1), (source_group_by_col2, target_group_by_col2) ... 
] + + if key_columns: + # If there are Group By columns, do Full join on the grouped columns + df = source.alias(source_alias).join( + other=target.alias(target_alias), + on=_generate_agg_join_condition(source_alias, target_alias, key_columns), + how="full", + ) + else: + # If there is no Group By condition, do Cross join as there is only one record + df = source.alias(source_alias).join( + other=target.alias(target_alias), + how="cross", + ) + + joined_df = df.selectExpr( + *source.columns, + *target.columns, + ) + + # Write the joined df to volume path + joined_volume_df = ReconIntermediatePersist(spark, path).write_and_read_unmatched_df_with_volumes(joined_df).cache() + logger.warning(f"Unmatched data is written to {path} successfully") + + return joined_volume_df diff --git a/src/databricks/labs/remorph/reconcile/connectors/__init__.py b/src/databricks/labs/remorph/reconcile/connectors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/reconcile/connectors/data_source.py b/src/databricks/labs/remorph/reconcile/connectors/data_source.py new file mode 100644 index 0000000000..120e176a0a --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/connectors/data_source.py @@ -0,0 +1,72 @@ +import logging +from abc import ABC, abstractmethod + +from pyspark.sql import DataFrame + +from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema + +logger = logging.getLogger(__name__) + + +class DataSource(ABC): + + @abstractmethod + def read_data( + self, + catalog: str | None, + schema: str, + table: str, + query: str, + options: JdbcReaderOptions | None, + ) -> DataFrame: + return NotImplemented + + @abstractmethod + def get_schema( + self, + catalog: str | None, + schema: str, + table: str, + ) -> list[Schema]: + return NotImplemented + + @classmethod + def log_and_throw_exception(cls, exception: Exception, fetch_type: str, query: str): + error_msg = f"Runtime exception occurred while fetching {fetch_type} using {query} : {exception}" + logger.warning(error_msg) + raise DataSourceRuntimeException(error_msg) from exception + + +class MockDataSource(DataSource): + + def __init__( + self, + dataframe_repository: dict[tuple[str, str, str], DataFrame], + schema_repository: dict[tuple[str, str, str], list[Schema]], + exception: Exception = RuntimeError("Mock Exception"), + ): + self._dataframe_repository: dict[tuple[str, str, str], DataFrame] = dataframe_repository + self._schema_repository: dict[tuple[str, str, str], list[Schema]] = schema_repository + self._exception = exception + + def read_data( + self, + catalog: str | None, + schema: str, + table: str, + query: str, + options: JdbcReaderOptions | None, + ) -> DataFrame: + catalog_str = catalog if catalog else "" + mock_df = self._dataframe_repository.get((catalog_str, schema, query)) + if not mock_df: + return self.log_and_throw_exception(self._exception, "data", f"({catalog}, {schema}, {query})") + return mock_df + + def get_schema(self, catalog: str | None, schema: str, table: str) -> list[Schema]: + catalog_str = catalog if catalog else "" + mock_schema = self._schema_repository.get((catalog_str, schema, table)) + if not mock_schema: + return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})") + return mock_schema diff --git a/src/databricks/labs/remorph/reconcile/connectors/databricks.py 
b/src/databricks/labs/remorph/reconcile/connectors/databricks.py new file mode 100644 index 0000000000..b866413a1a --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/connectors/databricks.py @@ -0,0 +1,87 @@ +import logging +import re +from datetime import datetime + +from pyspark.errors import PySparkException +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import col +from sqlglot import Dialect + +from databricks.labs.remorph.reconcile.connectors.data_source import DataSource +from databricks.labs.remorph.reconcile.connectors.secrets import SecretsMixin +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema +from databricks.sdk import WorkspaceClient + +logger = logging.getLogger(__name__) + + +def _get_schema_query(catalog: str, schema: str, table: str): + # TODO: Ensure that the target_catalog in the configuration is not set to "hive_metastore". The source_catalog + # can only be set to "hive_metastore" if the source type is "databricks". + if schema == "global_temp": + return f"describe table global_temp.{table}" + if catalog == "hive_metastore": + return f"describe table {catalog}.{schema}.{table}" + + query = f"""select + lower(column_name) as col_name, + full_data_type as data_type + from {catalog}.information_schema.columns + where lower(table_catalog)='{catalog}' + and lower(table_schema)='{schema}' + and lower(table_name) ='{table}' + order by col_name""" + return re.sub(r'\s+', ' ', query) + + +class DatabricksDataSource(DataSource, SecretsMixin): + + def __init__( + self, + engine: Dialect, + spark: SparkSession, + ws: WorkspaceClient, + secret_scope: str, + ): + self._engine = engine + self._spark = spark + self._ws = ws + self._secret_scope = secret_scope + + def read_data( + self, + catalog: str | None, + schema: str, + table: str, + query: str, + options: JdbcReaderOptions | None, + ) -> DataFrame: + namespace_catalog = "hive_metastore" if not catalog else catalog + if schema == "global_temp": + namespace_catalog = "global_temp" + else: + namespace_catalog = f"{namespace_catalog}.{schema}" + table_with_namespace = f"{namespace_catalog}.{table}" + table_query = query.replace(":tbl", table_with_namespace) + try: + df = self._spark.sql(table_query) + return df.select([col(column).alias(column.lower()) for column in df.columns]) + except (RuntimeError, PySparkException) as e: + return self.log_and_throw_exception(e, "data", table_query) + + def get_schema( + self, + catalog: str | None, + schema: str, + table: str, + ) -> list[Schema]: + catalog_str = catalog if catalog else "hive_metastore" + schema_query = _get_schema_query(catalog_str, schema, table) + try: + logger.debug(f"Fetching schema using query: \n`{schema_query}`") + logger.info(f"Fetching Schema: Started at: {datetime.now()}") + schema_metadata = self._spark.sql(schema_query).where("col_name not like '#%'").distinct().collect() + logger.info(f"Schema fetched successfully. 
Completed at: {datetime.now()}") + return [Schema(field.col_name.lower(), field.data_type.lower()) for field in schema_metadata] + except (RuntimeError, PySparkException) as e: + return self.log_and_throw_exception(e, "schema", schema_query) diff --git a/src/databricks/labs/remorph/reconcile/connectors/jdbc_reader.py b/src/databricks/labs/remorph/reconcile/connectors/jdbc_reader.py new file mode 100644 index 0000000000..c5b7c9336a --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/connectors/jdbc_reader.py @@ -0,0 +1,30 @@ +from pyspark.sql import SparkSession + +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions + + +class JDBCReaderMixin: + _spark: SparkSession + + # TODO update the url + def _get_jdbc_reader(self, query, jdbc_url, driver): + driver_class = { + "oracle": "oracle.jdbc.driver.OracleDriver", + "snowflake": "net.snowflake.client.jdbc.SnowflakeDriver", + } + return ( + self._spark.read.format("jdbc") + .option("url", jdbc_url) + .option("driver", driver_class.get(driver, driver)) + .option("dbtable", f"({query}) tmp") + ) + + @staticmethod + def _get_jdbc_reader_options(options: JdbcReaderOptions): + return { + "numPartitions": options.number_partitions, + "partitionColumn": options.partition_column, + "lowerBound": options.lower_bound, + "upperBound": options.upper_bound, + "fetchsize": options.fetch_size, + } diff --git a/src/databricks/labs/remorph/reconcile/connectors/oracle.py b/src/databricks/labs/remorph/reconcile/connectors/oracle.py new file mode 100644 index 0000000000..59e2b03cba --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/connectors/oracle.py @@ -0,0 +1,100 @@ +import re +import logging +from datetime import datetime + +from pyspark.errors import PySparkException +from pyspark.sql import DataFrame, DataFrameReader, SparkSession +from sqlglot import Dialect + +from databricks.labs.remorph.reconcile.connectors.data_source import DataSource +from databricks.labs.remorph.reconcile.connectors.jdbc_reader import JDBCReaderMixin +from databricks.labs.remorph.reconcile.connectors.secrets import SecretsMixin +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema +from databricks.sdk import WorkspaceClient + +logger = logging.getLogger(__name__) + + +class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin): + _DRIVER = "oracle" + _SCHEMA_QUERY = """select column_name, case when (data_precision is not null + and data_scale <> 0) + then data_type || '(' || data_precision || ',' || data_scale || ')' + when (data_precision is not null and data_scale = 0) + then data_type || '(' || data_precision || ')' + when data_precision is null and (lower(data_type) in ('date') or + lower(data_type) like 'timestamp%') then data_type + when CHAR_LENGTH == 0 then data_type + else data_type || '(' || CHAR_LENGTH || ')' + end data_type + FROM ALL_TAB_COLUMNS + WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}'""" + + def __init__( + self, + engine: Dialect, + spark: SparkSession, + ws: WorkspaceClient, + secret_scope: str, + ): + self._engine = engine + self._spark = spark + self._ws = ws + self._secret_scope = secret_scope + + @property + def get_jdbc_url(self) -> str: + return ( + f"jdbc:{OracleDataSource._DRIVER}:thin:{self._get_secret('user')}" + f"/{self._get_secret('password')}@//{self._get_secret('host')}" + f":{self._get_secret('port')}/{self._get_secret('database')}" + ) + + def read_data( + self, + catalog: str | None, + schema: str, + table: str, + query: str, + options: 
JdbcReaderOptions | None, + ) -> DataFrame: + table_query = query.replace(":tbl", f"{schema}.{table}") + try: + if options is None: + return self.reader(table_query).options(**self._get_timestamp_options()).load() + reader_options = self._get_jdbc_reader_options(options) | self._get_timestamp_options() + return self.reader(table_query).options(**reader_options).load() + except (RuntimeError, PySparkException) as e: + return self.log_and_throw_exception(e, "data", table_query) + + def get_schema( + self, + catalog: str | None, + schema: str, + table: str, + ) -> list[Schema]: + schema_query = re.sub( + r'\s+', + ' ', + OracleDataSource._SCHEMA_QUERY.format(table=table, owner=schema), + ) + try: + logger.debug(f"Fetching schema using query: \n`{schema_query}`") + logger.info(f"Fetching Schema: Started at: {datetime.now()}") + schema_metadata = self.reader(schema_query).load().collect() + logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}") + return [Schema(field.column_name.lower(), field.data_type.lower()) for field in schema_metadata] + except (RuntimeError, PySparkException) as e: + return self.log_and_throw_exception(e, "schema", schema_query) + + @staticmethod + def _get_timestamp_options() -> dict[str, str]: + return { + "oracle.jdbc.mapDateToTimestamp": "False", + "sessionInitStatement": "BEGIN dbms_session.set_nls('nls_date_format', " + "'''YYYY-MM-DD''');dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD " + "HH24:MI:SS''');END;", + } + + def reader(self, query: str) -> DataFrameReader: + return self._get_jdbc_reader(query, self.get_jdbc_url, OracleDataSource._DRIVER) diff --git a/src/databricks/labs/remorph/reconcile/connectors/secrets.py b/src/databricks/labs/remorph/reconcile/connectors/secrets.py new file mode 100644 index 0000000000..3b93bbc4e1 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/connectors/secrets.py @@ -0,0 +1,30 @@ +import base64 +import logging + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound + +logger = logging.getLogger(__name__) + + +class SecretsMixin: + _ws: WorkspaceClient + _secret_scope: str + + def _get_secret(self, secret_key: str) -> str: + """Get the secret value given a secret scope & secret key. 
Log a warning if secret does not exist""" + try: + # Return the decoded secret value in string format + secret = self._ws.secrets.get_secret(self._secret_scope, secret_key) + assert secret.value is not None + return base64.b64decode(secret.value).decode("utf-8") + except NotFound as e: + raise NotFound(f'Secret does not exist with scope: {self._secret_scope} and key: {secret_key} : {e}') from e + except UnicodeDecodeError as e: + raise UnicodeDecodeError( + "utf-8", + secret_key.encode(), + 0, + 1, + f"Secret {self._secret_scope}/{secret_key} has Base64 bytes that cannot be decoded to utf-8 string: {e}.", + ) from e diff --git a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py new file mode 100644 index 0000000000..9177eb7f94 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py @@ -0,0 +1,166 @@ +import logging +import re +from datetime import datetime + +from pyspark.errors import PySparkException +from pyspark.sql import DataFrame, DataFrameReader, SparkSession +from pyspark.sql.functions import col +from sqlglot import Dialect +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + +from databricks.labs.remorph.reconcile.connectors.data_source import DataSource +from databricks.labs.remorph.reconcile.connectors.jdbc_reader import JDBCReaderMixin +from databricks.labs.remorph.reconcile.connectors.secrets import SecretsMixin +from databricks.labs.remorph.reconcile.exception import InvalidSnowflakePemPrivateKey +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound + +logger = logging.getLogger(__name__) + + +class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin): + _DRIVER = "snowflake" + """ + * INFORMATION_SCHEMA: + - see https://docs.snowflake.com/en/sql-reference/info-schema#considerations-for-replacing-show-commands-with-information-schema-views + * DATA: + - only unquoted identifiers are treated as case-insensitive and are stored in uppercase. + - for quoted identifiers refer: + https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers + * ORDINAL_POSITION: + - indicates the sequential order of a column within a table or view, + starting from 1 based on the order of column definition. 
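+    * _SCHEMA_QUERY:
+      - get_schema() below formats this query with the schema name upper-cased and
+        interpolates the catalog directly into {catalog}.INFORMATION_SCHEMA.COLUMNS;
+        the table is matched via lower(table_name), so table names are expected in lower case.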
+ """ + _SCHEMA_QUERY = """select column_name, + case + when numeric_precision is not null and numeric_scale is not null + then + concat(data_type, '(', numeric_precision, ',' , numeric_scale, ')') + when lower(data_type) = 'text' + then + concat('varchar', '(', CHARACTER_MAXIMUM_LENGTH, ')') + else data_type + end as data_type + from {catalog}.INFORMATION_SCHEMA.COLUMNS + where lower(table_name)='{table}' and table_schema = '{schema}' + order by ordinal_position""" + + def __init__( + self, + engine: Dialect, + spark: SparkSession, + ws: WorkspaceClient, + secret_scope: str, + ): + self._engine = engine + self._spark = spark + self._ws = ws + self._secret_scope = secret_scope + + @property + def get_jdbc_url(self) -> str: + return ( + f"jdbc:{SnowflakeDataSource._DRIVER}://{self._get_secret('sfAccount')}.snowflakecomputing.com" + f"/?user={self._get_secret('sfUser')}&password={self._get_secret('sfPassword')}" + f"&db={self._get_secret('sfDatabase')}&schema={self._get_secret('sfSchema')}" + f"&warehouse={self._get_secret('sfWarehouse')}&role={self._get_secret('sfRole')}" + ) + + @staticmethod + def get_private_key(pem_private_key: str) -> str: + try: + private_key_bytes = pem_private_key.encode("UTF-8") + p_key = serialization.load_pem_private_key( + private_key_bytes, + password=None, + backend=default_backend(), + ) + pkb = p_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + pkb_str = pkb.decode("UTF-8") + # Remove the first and last lines (BEGIN/END markers) + private_key_pem_lines = pkb_str.strip().split('\n')[1:-1] + # Join the lines to form the base64 encoded string + private_key_pem_str = ''.join(private_key_pem_lines) + return private_key_pem_str + except Exception as e: + message = f"Failed to load or process the provided PEM private key. --> {e}" + logger.error(message) + raise InvalidSnowflakePemPrivateKey(message) from e + + def read_data( + self, + catalog: str | None, + schema: str, + table: str, + query: str, + options: JdbcReaderOptions | None, + ) -> DataFrame: + table_query = query.replace(":tbl", f"{catalog}.{schema}.{table}") + try: + if options is None: + df = self.reader(table_query).load() + else: + options = self._get_jdbc_reader_options(options) + df = ( + self._get_jdbc_reader(table_query, self.get_jdbc_url, SnowflakeDataSource._DRIVER) + .options(**options) + .load() + ) + return df.select([col(column).alias(column.lower()) for column in df.columns]) + except (RuntimeError, PySparkException) as e: + return self.log_and_throw_exception(e, "data", table_query) + + def get_schema( + self, + catalog: str | None, + schema: str, + table: str, + ) -> list[Schema]: + """ + Fetch the Schema from the INFORMATION_SCHEMA.COLUMNS table in Snowflake. + + If the user's current role does not have the necessary privileges to access the specified + Information Schema object, RunTimeError will be raised: + "SQL access control error: Insufficient privileges to operate on schema 'INFORMATION_SCHEMA' " + """ + schema_query = re.sub( + r'\s+', + ' ', + SnowflakeDataSource._SCHEMA_QUERY.format(catalog=catalog, schema=schema.upper(), table=table), + ) + try: + logger.debug(f"Fetching schema using query: \n`{schema_query}`") + logger.info(f"Fetching Schema: Started at: {datetime.now()}") + schema_metadata = self.reader(schema_query).load().collect() + logger.info(f"Schema fetched successfully. 
Completed at: {datetime.now()}") + return [Schema(field.COLUMN_NAME.lower(), field.DATA_TYPE.lower()) for field in schema_metadata] + except (RuntimeError, PySparkException) as e: + return self.log_and_throw_exception(e, "schema", schema_query) + + def reader(self, query: str) -> DataFrameReader: + options = { + "sfUrl": self._get_secret('sfUrl'), + "sfUser": self._get_secret('sfUser'), + "sfDatabase": self._get_secret('sfDatabase'), + "sfSchema": self._get_secret('sfSchema'), + "sfWarehouse": self._get_secret('sfWarehouse'), + "sfRole": self._get_secret('sfRole'), + } + try: + options["pem_private_key"] = SnowflakeDataSource.get_private_key(self._get_secret('pem_private_key')) + except (NotFound, KeyError): + logger.warning("pem_private_key not found. Checking for sfPassword") + try: + options["sfPassword"] = self._get_secret('sfPassword') + except (NotFound, KeyError) as e: + message = "sfPassword and pem_private_key not found. Either one is required for snowflake auth." + logger.error(message) + raise NotFound(message) from e + + return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options) diff --git a/src/databricks/labs/remorph/reconcile/connectors/source_adapter.py b/src/databricks/labs/remorph/reconcile/connectors/source_adapter.py new file mode 100644 index 0000000000..04d8c4401d --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/connectors/source_adapter.py @@ -0,0 +1,26 @@ +from pyspark.sql import SparkSession +from sqlglot import Dialect + +from databricks.labs.remorph.reconcile.connectors.data_source import DataSource +from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource +from databricks.labs.remorph.reconcile.connectors.oracle import OracleDataSource +from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource +from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks +from databricks.labs.remorph.transpiler.sqlglot.parsers.oracle import Oracle +from databricks.labs.remorph.transpiler.sqlglot.parsers.snowflake import Snowflake +from databricks.sdk import WorkspaceClient + + +def create_adapter( + engine: Dialect, + spark: SparkSession, + ws: WorkspaceClient, + secret_scope: str, +) -> DataSource: + if isinstance(engine, Snowflake): + return SnowflakeDataSource(engine, spark, ws, secret_scope) + if isinstance(engine, Oracle): + return OracleDataSource(engine, spark, ws, secret_scope) + if isinstance(engine, Databricks): + return DatabricksDataSource(engine, spark, ws, secret_scope) + raise ValueError(f"Unsupported source type --> {engine}") diff --git a/src/databricks/labs/remorph/reconcile/constants.py b/src/databricks/labs/remorph/reconcile/constants.py new file mode 100644 index 0000000000..91d30d035d --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/constants.py @@ -0,0 +1,27 @@ +from enum import Enum, auto + + +class AutoName(Enum): + """ + This class is used to auto generate the enum values based on the name of the enum in lower case + + Reference: https://docs.python.org/3/howto/enum.html#enum-advanced-tutorial + """ + + @staticmethod + # pylint: disable-next=bad-dunder-name + def _generate_next_value_(name, start, count, last_values): # noqa ARG004 + return name.lower() + + +class ReconSourceType(AutoName): + SNOWFLAKE = auto() + ORACLE = auto() + DATABRICKS = auto() + + +class ReconReportType(AutoName): + DATA = auto() + SCHEMA = auto() + ROW = auto() + ALL = auto() diff --git 
a/src/databricks/labs/remorph/reconcile/exception.py b/src/databricks/labs/remorph/reconcile/exception.py new file mode 100644 index 0000000000..913a32a2d6 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/exception.py @@ -0,0 +1,42 @@ +from pyspark.errors import PySparkException +from databricks.labs.remorph.reconcile.recon_config import ReconcileOutput + + +class ColumnMismatchException(Exception): + """Raise the error when there is a mismatch in source and target column names""" + + +class DataSourceRuntimeException(Exception): + """Raise the error when there is a runtime exception thrown in DataSource""" + + +class WriteToTableException(Exception): + """Raise the error when there is a runtime exception thrown while writing data to table""" + + +class InvalidInputException(ValueError): + """Raise the error when the input is invalid""" + + +class ReconciliationException(Exception): + """Raise the error when there is an error occurred during reconciliation""" + + def __init__(self, message: str, reconcile_output: ReconcileOutput | None = None): + self._reconcile_output = reconcile_output + super().__init__(message, reconcile_output) + + @property + def reconcile_output(self) -> ReconcileOutput | None: + return self._reconcile_output + + +class ReadAndWriteWithVolumeException(PySparkException): + """Raise the error when there is a runtime exception thrown while writing data to volume""" + + +class CleanFromVolumeException(PySparkException): + """Raise the error when there is a runtime exception thrown while cleaning data from volume""" + + +class InvalidSnowflakePemPrivateKey(Exception): + """Raise the error when the input private key is invalid""" diff --git a/src/databricks/labs/remorph/reconcile/execute.py b/src/databricks/labs/remorph/reconcile/execute.py new file mode 100644 index 0000000000..f54c8722ae --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/execute.py @@ -0,0 +1,896 @@ +import logging +import sys +import os +from datetime import datetime +from uuid import uuid4 + +from pyspark.errors import PySparkException +from pyspark.sql import DataFrame, SparkSession +from sqlglot import Dialect + +from databricks.labs.remorph.config import ( + DatabaseConfig, + TableRecon, + get_dialect, + ReconcileConfig, + ReconcileMetadataConfig, +) +from databricks.labs.remorph.reconcile.compare import ( + capture_mismatch_data_and_columns, + reconcile_data, + join_aggregate_data, + reconcile_agg_data_per_rule, +) +from databricks.labs.remorph.reconcile.connectors.data_source import DataSource +from databricks.labs.remorph.reconcile.connectors.source_adapter import create_adapter +from databricks.labs.remorph.reconcile.exception import ( + DataSourceRuntimeException, + InvalidInputException, + ReconciliationException, +) +from databricks.labs.remorph.reconcile.query_builder.aggregate_query import AggregateQueryBuilder +from databricks.labs.remorph.reconcile.query_builder.count_query import CountQueryBuilder +from databricks.labs.remorph.reconcile.query_builder.hash_query import HashQueryBuilder +from databricks.labs.remorph.reconcile.query_builder.sampling_query import ( + SamplingQueryBuilder, +) +from databricks.labs.remorph.reconcile.query_builder.threshold_query import ( + ThresholdQueryBuilder, +) +from databricks.labs.remorph.reconcile.recon_capture import ( + ReconCapture, + generate_final_reconcile_output, + ReconIntermediatePersist, + generate_final_reconcile_aggregate_output, +) +from databricks.labs.remorph.reconcile.recon_config import ( + DataReconcileOutput, + 
ReconcileOutput, + ReconcileProcessDuration, + Schema, + SchemaReconcileOutput, + Table, + ThresholdOutput, + ReconcileRecordCount, + AggregateQueryOutput, + AggregateQueryRules, +) +from databricks.labs.remorph.reconcile.schema_compare import SchemaCompare +from databricks.labs.remorph.transpiler.execute import verify_workspace_client +from databricks.sdk import WorkspaceClient +from databricks.labs.blueprint.installation import Installation +from databricks.connect import DatabricksSession + +logger = logging.getLogger(__name__) +_SAMPLE_ROWS = 50 + +RECONCILE_OPERATION_NAME = "reconcile" +AGG_RECONCILE_OPERATION_NAME = "aggregates-reconcile" + + +def validate_input(input_value: str, list_of_value: set, message: str): + if input_value not in list_of_value: + error_message = f"{message} --> {input_value} is not one of {list_of_value}" + logger.error(error_message) + raise InvalidInputException(error_message) + + +def main(*argv) -> None: + logger.debug(f"Arguments received: {argv}") + + assert len(sys.argv) == 2, f"Invalid number of arguments: {len(sys.argv)}," f" Operation name must be specified." + operation_name = sys.argv[1] + + assert operation_name in { + RECONCILE_OPERATION_NAME, + AGG_RECONCILE_OPERATION_NAME, + }, f"Invalid option: {operation_name}" + + w = WorkspaceClient() + + installation = Installation.assume_user_home(w, "remorph") + + reconcile_config = installation.load(ReconcileConfig) + + catalog_or_schema = ( + reconcile_config.database_config.source_catalog + if reconcile_config.database_config.source_catalog + else reconcile_config.database_config.source_schema + ) + filename = f"recon_config_{reconcile_config.data_source}_{catalog_or_schema}_{reconcile_config.report_type}.json" + + logger.info(f"Loading {filename} from Databricks Workspace...") + + table_recon = installation.load(type_ref=TableRecon, filename=filename) + + if operation_name == AGG_RECONCILE_OPERATION_NAME: + return _trigger_reconcile_aggregates(w, table_recon, reconcile_config) + + return _trigger_recon(w, table_recon, reconcile_config) + + +def _trigger_recon( + w: WorkspaceClient, + table_recon: TableRecon, + reconcile_config: ReconcileConfig, +): + try: + recon_output = recon( + ws=w, + spark=DatabricksSession.builder.getOrCreate(), + table_recon=table_recon, + reconcile_config=reconcile_config, + ) + logger.info(f"recon_output: {recon_output}") + logger.info(f"recon_id: {recon_output.recon_id}") + except ReconciliationException as e: + logger.error(f"Error while running recon: {e.reconcile_output}") + raise e + + +def _trigger_reconcile_aggregates( + ws: WorkspaceClient, + table_recon: TableRecon, + reconcile_config: ReconcileConfig, +): + """ + Triggers the reconciliation process for aggregated data between source and target tables. + Supported Aggregate functions: MIN, MAX, COUNT, SUM, AVG, MEAN, MODE, PERCENTILE, STDDEV, VARIANCE, MEDIAN + + This function attempts to reconcile aggregate data based on the configurations provided. It logs the outcome + of the reconciliation process, including any errors encountered during execution. + + Parameters: + - ws (WorkspaceClient): The workspace client used to interact with Databricks workspaces. + - table_recon (TableRecon): Configuration for the table reconciliation process, including source and target details. + - reconcile_config (ReconcileConfig): General configuration for the reconciliation process, + including database and table settings. 
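+
+    Note:
+    - main() dispatches here when the job is launched with the "aggregates-reconcile"
+      operation name, after loading the table configuration from the installation file
+      recon_config_<data_source>_<catalog_or_schema>_<report_type>.json.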
+ + Raises: + - ReconciliationException: If an error occurs during the reconciliation process, it is caught and re-raised + after logging the error details. + """ + try: + recon_output = reconcile_aggregates( + ws=ws, + spark=DatabricksSession.builder.getOrCreate(), + table_recon=table_recon, + reconcile_config=reconcile_config, + ) + logger.info(f"recon_output: {recon_output}") + logger.info(f"recon_id: {recon_output.recon_id}") + except ReconciliationException as e: + logger.error(f"Error while running aggregate reconcile: {str(e)}") + raise e + + +def recon( + ws: WorkspaceClient, + spark: SparkSession, + table_recon: TableRecon, + reconcile_config: ReconcileConfig, + local_test_run: bool = False, +) -> ReconcileOutput: + """[EXPERIMENTAL] Reconcile the data between the source and target tables.""" + # verify the workspace client and add proper product and version details + # TODO For now we are utilising the + # verify_workspace_client from transpile/execute.py file. Later verify_workspace_client function has to be + # refactored + + ws_client: WorkspaceClient = verify_workspace_client(ws) + + # validate the report type + report_type = reconcile_config.report_type.lower() + logger.info(f"report_type: {report_type}, data_source: {reconcile_config.data_source} ") + validate_input(report_type, {"schema", "data", "row", "all"}, "Invalid report type") + + source, target = initialise_data_source( + engine=get_dialect(reconcile_config.data_source), + spark=spark, + ws=ws_client, + secret_scope=reconcile_config.secret_scope, + ) + + recon_id = str(uuid4()) + # initialise the Reconciliation + reconciler = Reconciliation( + source, + target, + reconcile_config.database_config, + report_type, + SchemaCompare(spark=spark), + get_dialect(reconcile_config.data_source), + spark, + metadata_config=reconcile_config.metadata_config, + ) + + # initialise the recon capture class + recon_capture = ReconCapture( + database_config=reconcile_config.database_config, + recon_id=recon_id, + report_type=report_type, + source_dialect=get_dialect(reconcile_config.data_source), + ws=ws_client, + spark=spark, + metadata_config=reconcile_config.metadata_config, + local_test_run=local_test_run, + ) + + for table_conf in table_recon.tables: + recon_process_duration = ReconcileProcessDuration(start_ts=str(datetime.now()), end_ts=None) + schema_reconcile_output = SchemaReconcileOutput(is_valid=True) + data_reconcile_output = DataReconcileOutput() + try: + src_schema, tgt_schema = _get_schema( + source=source, target=target, table_conf=table_conf, database_config=reconcile_config.database_config + ) + except DataSourceRuntimeException as e: + schema_reconcile_output = SchemaReconcileOutput(is_valid=False, exception=str(e)) + else: + if report_type in {"schema", "all"}: + schema_reconcile_output = _run_reconcile_schema( + reconciler=reconciler, table_conf=table_conf, src_schema=src_schema, tgt_schema=tgt_schema + ) + logger.warning("Schema comparison is completed.") + + if report_type in {"data", "row", "all"}: + data_reconcile_output = _run_reconcile_data( + reconciler=reconciler, table_conf=table_conf, src_schema=src_schema, tgt_schema=tgt_schema + ) + logger.warning(f"Reconciliation for '{report_type}' report completed.") + + recon_process_duration.end_ts = str(datetime.now()) + # Persist the data to the delta tables + recon_capture.start( + data_reconcile_output=data_reconcile_output, + schema_reconcile_output=schema_reconcile_output, + table_conf=table_conf, + recon_process_duration=recon_process_duration, + 
record_count=reconciler.get_record_count(table_conf, report_type), + ) + if report_type != "schema": + ReconIntermediatePersist( + spark=spark, path=generate_volume_path(table_conf, reconcile_config.metadata_config) + ).clean_unmatched_df_from_volume() + + return _verify_successful_reconciliation( + generate_final_reconcile_output( + recon_id=recon_id, + spark=spark, + metadata_config=reconcile_config.metadata_config, + local_test_run=local_test_run, + ) + ) + + +def _verify_successful_reconciliation( + reconcile_output: ReconcileOutput, operation_name: str = "reconcile" +) -> ReconcileOutput: + for table_output in reconcile_output.results: + if table_output.exception_message or ( + table_output.status.column is False + or table_output.status.row is False + or table_output.status.schema is False + or table_output.status.aggregate is False + ): + raise ReconciliationException( + f" Reconciliation failed for one or more tables. Please check the recon metrics for more details." + f" **{operation_name}** failed.", + reconcile_output=reconcile_output, + ) + + logger.info("Reconciliation completed successfully.") + return reconcile_output + + +def generate_volume_path(table_conf: Table, metadata_config: ReconcileMetadataConfig): + catalog = metadata_config.catalog + schema = metadata_config.schema + return f"/Volumes/{catalog}/{schema}/{metadata_config.volume}/{table_conf.source_name}_{table_conf.target_name}/" + + +def initialise_data_source( + ws: WorkspaceClient, + spark: SparkSession, + engine: Dialect, + secret_scope: str, +): + source = create_adapter(engine=engine, spark=spark, ws=ws, secret_scope=secret_scope) + target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws, secret_scope=secret_scope) + + return source, target + + +def _get_missing_data( + reader: DataSource, + sampler: SamplingQueryBuilder, + missing_df: DataFrame, + catalog: str, + schema: str, + table_name: str, +) -> DataFrame: + sample_query = sampler.build_query(missing_df) + return reader.read_data( + catalog=catalog, + schema=schema, + table=table_name, + query=sample_query, + options=None, + ) + + +def reconcile_aggregates( + ws: WorkspaceClient, + spark: SparkSession, + table_recon: TableRecon, + reconcile_config: ReconcileConfig, + local_test_run: bool = False, +): + """[EXPERIMENTAL] Reconcile the aggregated data between the source and target tables. + for e.g., COUNT, SUM, AVG of columns between source and target with or without any specific key/group by columns + Supported Aggregate functions: MIN, MAX, COUNT, SUM, AVG, MEAN, MODE, PERCENTILE, STDDEV, VARIANCE, MEDIAN + """ + # verify the workspace client and add proper product and version details + # TODO For now we are utilising the + # verify_workspace_client from transpile/execute.py file. Later verify_workspace_client function has to be + # refactored + + ws_client: WorkspaceClient = verify_workspace_client(ws) + + report_type = "" + if report_type: + logger.info(f"report_type: {report_type}") + logger.info(f"data_source: {reconcile_config.data_source}") + + # Read the reconcile_config and initialise the source and target data sources. 
Target is always Databricks + source, target = initialise_data_source( + engine=get_dialect(reconcile_config.data_source), + spark=spark, + ws=ws_client, + secret_scope=reconcile_config.secret_scope, + ) + + # Generate Unique recon_id for every run + recon_id = str(uuid4()) + + # initialise the Reconciliation + reconciler = Reconciliation( + source, + target, + reconcile_config.database_config, + report_type, + SchemaCompare(spark=spark), + get_dialect(reconcile_config.data_source), + spark, + metadata_config=reconcile_config.metadata_config, + ) + + # initialise the recon capture class + recon_capture = ReconCapture( + database_config=reconcile_config.database_config, + recon_id=recon_id, + report_type=report_type, + source_dialect=get_dialect(reconcile_config.data_source), + ws=ws_client, + spark=spark, + metadata_config=reconcile_config.metadata_config, + local_test_run=local_test_run, + ) + + # Get the Aggregated Reconciliation Output for each table + for table_conf in table_recon.tables: + recon_process_duration = ReconcileProcessDuration(start_ts=str(datetime.now()), end_ts=None) + try: + src_schema, tgt_schema = _get_schema( + source=source, + target=target, + table_conf=table_conf, + database_config=reconcile_config.database_config, + ) + except DataSourceRuntimeException as e: + raise ReconciliationException(message=str(e)) from e + + assert table_conf.aggregates, "Aggregates must be defined for Aggregates Reconciliation" + + table_reconcile_agg_output_list: list[AggregateQueryOutput] = _run_reconcile_aggregates( + reconciler=reconciler, + table_conf=table_conf, + src_schema=src_schema, + tgt_schema=tgt_schema, + ) + + recon_process_duration.end_ts = str(datetime.now()) + + # Persist the data to the delta tables + recon_capture.store_aggregates_metrics( + reconcile_agg_output_list=table_reconcile_agg_output_list, + table_conf=table_conf, + recon_process_duration=recon_process_duration, + ) + + ( + ReconIntermediatePersist( + spark=spark, + path=generate_volume_path(table_conf, reconcile_config.metadata_config), + ).clean_unmatched_df_from_volume() + ) + + return _verify_successful_reconciliation( + generate_final_reconcile_aggregate_output( + recon_id=recon_id, + spark=spark, + metadata_config=reconcile_config.metadata_config, + local_test_run=local_test_run, + ), + operation_name=AGG_RECONCILE_OPERATION_NAME, + ) + + +class Reconciliation: + + def __init__( + self, + source: DataSource, + target: DataSource, + database_config: DatabaseConfig, + report_type: str, + schema_comparator: SchemaCompare, + source_engine: Dialect, + spark: SparkSession, + metadata_config: ReconcileMetadataConfig, + ): + self._source = source + self._target = target + self._report_type = report_type + self._database_config = database_config + self._schema_comparator = schema_comparator + self._target_engine = get_dialect("databricks") + self._source_engine = source_engine + self._spark = spark + self._metadata_config = metadata_config + + def reconcile_data( + self, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], + ) -> DataReconcileOutput: + data_reconcile_output = self._get_reconcile_output(table_conf, src_schema, tgt_schema) + reconcile_output = data_reconcile_output + if self._report_type in {"data", "all"}: + reconcile_output = self._get_sample_data(table_conf, data_reconcile_output, src_schema, tgt_schema) + if table_conf.get_threshold_columns("source"): + reconcile_output.threshold_output = self._reconcile_threshold_data(table_conf, src_schema, tgt_schema) + + if 
self._report_type == "row" and table_conf.get_threshold_columns("source"): + logger.warning("Threshold comparison is ignored for 'row' report type") + + return reconcile_output + + def reconcile_schema( + self, + src_schema: list[Schema], + tgt_schema: list[Schema], + table_conf: Table, + ): + return self._schema_comparator.compare(src_schema, tgt_schema, self._source_engine, table_conf) + + def reconcile_aggregates( + self, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], + ) -> list[AggregateQueryOutput]: + return self._get_reconcile_aggregate_output(table_conf, src_schema, tgt_schema) + + def _get_reconcile_output( + self, + table_conf, + src_schema, + tgt_schema, + ): + src_hash_query = HashQueryBuilder(table_conf, src_schema, "source", self._source_engine).build_query( + report_type=self._report_type + ) + tgt_hash_query = HashQueryBuilder(table_conf, tgt_schema, "target", self._source_engine).build_query( + report_type=self._report_type + ) + src_data = self._source.read_data( + catalog=self._database_config.source_catalog, + schema=self._database_config.source_schema, + table=table_conf.source_name, + query=src_hash_query, + options=table_conf.jdbc_reader_options, + ) + tgt_data = self._target.read_data( + catalog=self._database_config.target_catalog, + schema=self._database_config.target_schema, + table=table_conf.target_name, + query=tgt_hash_query, + options=table_conf.jdbc_reader_options, + ) + + volume_path = generate_volume_path(table_conf, self._metadata_config) + return reconcile_data( + source=src_data, + target=tgt_data, + key_columns=table_conf.join_columns, + report_type=self._report_type, + spark=self._spark, + path=volume_path, + ) + + def _get_reconcile_aggregate_output( + self, + table_conf, + src_schema, + tgt_schema, + ): + """ + Creates a single Query, for the aggregates having the same group by columns. (Ex: 1) + If there are no group by columns, all the aggregates are clubbed together in a single query. (Ex: 2) + Examples: + 1. { + "type": "MIN", + "agg_cols": ["COL1"], + "group_by_cols": ["COL4"] + }, + { + "type": "MAX", + "agg_cols": ["COL2"], + "group_by_cols": ["COL9"] + }, + { + "type": "COUNT", + "agg_cols": ["COL2"], + "group_by_cols": ["COL9"] + }, + { + "type": "AVG", + "agg_cols": ["COL3"], + "group_by_cols": ["COL4"] + }, + Query 1: SELECT MIN(COL1), AVG(COL3) FROM :table GROUP BY COL4 + Rules: ID | Aggregate Type | Column | Group By Column + #1, MIN, COL1, COL4 + #2, AVG, COL3, COL4 + ------------------------------------------------------- + Query 2: SELECT MAX(COL2), COUNT(COL2) FROM :table GROUP BY COL9 + Rules: ID | Aggregate Type | Column | Group By Column + #1, MAX, COL2, COL9 + #2, COUNT, COL2, COL9 + 2. 
{ + "type": "MAX", + "agg_cols": ["COL1"] + }, + { + "type": "SUM", + "agg_cols": ["COL2"] + }, + { + "type": "MAX", + "agg_cols": ["COL3"] + } + Query: SELECT MAX(COL1), SUM(COL2), MAX(COL3) FROM :table + Rules: ID | Aggregate Type | Column | Group By Column + #1, MAX, COL1, + #2, SUM, COL2, + #3, MAX, COL3, + """ + + src_query_builder = AggregateQueryBuilder( + table_conf, + src_schema, + "source", + self._source_engine, + ) + + # build Aggregate queries for source, + src_agg_queries: list[AggregateQueryRules] = src_query_builder.build_queries() + + # There could be one or more queries per table based on the group by columns + + # build Aggregate queries for target(Databricks), + tgt_agg_queries: list[AggregateQueryRules] = AggregateQueryBuilder( + table_conf, + tgt_schema, + "target", + self._target_engine, + ).build_queries() + + volume_path = generate_volume_path(table_conf, self._metadata_config) + + table_agg_output: list[AggregateQueryOutput] = [] + + # Iterate over the grouped aggregates and reconcile the data + # Zip all the keys, read the source, target data for each Aggregate query + # and reconcile on the aggregate data + # For e.g., (source_query_GRP1, target_query_GRP1), (source_query_GRP2, target_query_GRP2) + for src_query_with_rules, tgt_query_with_rules in zip(src_agg_queries, tgt_agg_queries): + # For each Aggregate query, read the Source and Target Data and add a hash column + + rules_reconcile_output: list[AggregateQueryOutput] = [] + src_data = None + tgt_data = None + joined_df = None + data_source_exception = None + try: + src_data = self._source.read_data( + catalog=self._database_config.source_catalog, + schema=self._database_config.source_schema, + table=table_conf.source_name, + query=src_query_with_rules.query, + options=table_conf.jdbc_reader_options, + ) + tgt_data = self._target.read_data( + catalog=self._database_config.target_catalog, + schema=self._database_config.target_schema, + table=table_conf.target_name, + query=tgt_query_with_rules.query, + options=table_conf.jdbc_reader_options, + ) + # Join the Source and Target Aggregated data + joined_df = join_aggregate_data( + source=src_data, + target=tgt_data, + key_columns=src_query_with_rules.group_by_columns, + spark=self._spark, + path=f"{volume_path}{src_query_with_rules.group_by_columns_as_str}", + ) + except DataSourceRuntimeException as e: + data_source_exception = e + + # For each Aggregated Query, reconcile the data based on the rule + for rule in src_query_with_rules.rules: + if data_source_exception: + rule_reconcile_output = DataReconcileOutput(exception=str(data_source_exception)) + else: + rule_reconcile_output = reconcile_agg_data_per_rule( + joined_df, src_data.columns, tgt_data.columns, rule + ) + rules_reconcile_output.append(AggregateQueryOutput(rule=rule, reconcile_output=rule_reconcile_output)) + + # For each table, there could be many Aggregated queries. 
+ # Collect the list of Rule Reconcile output per each Aggregate query and append it to the list + table_agg_output.extend(rules_reconcile_output) + return table_agg_output + + def _get_sample_data( + self, + table_conf, + reconcile_output, + src_schema, + tgt_schema, + ): + mismatch = None + missing_in_src = None + missing_in_tgt = None + + if ( + reconcile_output.mismatch_count > 0 + or reconcile_output.missing_in_src_count > 0 + or reconcile_output.missing_in_tgt_count > 0 + ): + src_sampler = SamplingQueryBuilder(table_conf, src_schema, "source", self._source_engine) + tgt_sampler = SamplingQueryBuilder(table_conf, tgt_schema, "target", self._target_engine) + if reconcile_output.mismatch_count > 0: + mismatch = self._get_mismatch_data( + src_sampler, + tgt_sampler, + reconcile_output.mismatch.mismatch_df, + table_conf.join_columns, + table_conf.source_name, + table_conf.target_name, + ) + + if reconcile_output.missing_in_src_count > 0: + missing_in_src = _get_missing_data( + self._target, + tgt_sampler, + reconcile_output.missing_in_src, + self._database_config.target_catalog, + self._database_config.target_schema, + table_conf.target_name, + ) + + if reconcile_output.missing_in_tgt_count > 0: + missing_in_tgt = _get_missing_data( + self._source, + src_sampler, + reconcile_output.missing_in_tgt, + self._database_config.source_catalog, + self._database_config.source_schema, + table_conf.source_name, + ) + + return DataReconcileOutput( + mismatch=mismatch, + mismatch_count=reconcile_output.mismatch_count, + missing_in_src_count=reconcile_output.missing_in_src_count, + missing_in_tgt_count=reconcile_output.missing_in_tgt_count, + missing_in_src=missing_in_src, + missing_in_tgt=missing_in_tgt, + ) + + def _get_mismatch_data( + self, + src_sampler, + tgt_sampler, + mismatch, + key_columns, + src_table: str, + tgt_table: str, + ): + df = mismatch.limit(_SAMPLE_ROWS).cache() + src_mismatch_sample_query = src_sampler.build_query(df) + tgt_mismatch_sample_query = tgt_sampler.build_query(df) + + src_data = self._source.read_data( + catalog=self._database_config.source_catalog, + schema=self._database_config.source_schema, + table=src_table, + query=src_mismatch_sample_query, + options=None, + ) + tgt_data = self._target.read_data( + catalog=self._database_config.target_catalog, + schema=self._database_config.target_schema, + table=tgt_table, + query=tgt_mismatch_sample_query, + options=None, + ) + + return capture_mismatch_data_and_columns(source=src_data, target=tgt_data, key_columns=key_columns) + + def _reconcile_threshold_data( + self, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], + ): + + src_data, tgt_data = self._get_threshold_data(table_conf, src_schema, tgt_schema) + + source_view = f"source_{table_conf.source_name}_df_threshold_vw" + target_view = f"target_{table_conf.target_name}_df_threshold_vw" + + src_data.createOrReplaceTempView(source_view) + tgt_data.createOrReplaceTempView(target_view) + + return self._compute_threshold_comparison(table_conf, src_schema) + + def _get_threshold_data( + self, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], + ) -> tuple[DataFrame, DataFrame]: + src_threshold_query = ThresholdQueryBuilder( + table_conf, src_schema, "source", self._source_engine + ).build_threshold_query() + tgt_threshold_query = ThresholdQueryBuilder( + table_conf, tgt_schema, "target", self._target_engine + ).build_threshold_query() + + src_data = self._source.read_data( + catalog=self._database_config.source_catalog, + 
schema=self._database_config.source_schema, + table=table_conf.source_name, + query=src_threshold_query, + options=table_conf.jdbc_reader_options, + ) + tgt_data = self._target.read_data( + catalog=self._database_config.target_catalog, + schema=self._database_config.target_schema, + table=table_conf.target_name, + query=tgt_threshold_query, + options=table_conf.jdbc_reader_options, + ) + + return src_data, tgt_data + + def _compute_threshold_comparison(self, table_conf: Table, src_schema: list[Schema]) -> ThresholdOutput: + threshold_comparison_query = ThresholdQueryBuilder( + table_conf, src_schema, "target", self._target_engine + ).build_comparison_query() + + threshold_result = self._target.read_data( + catalog=self._database_config.target_catalog, + schema=self._database_config.target_schema, + table=table_conf.target_name, + query=threshold_comparison_query, + options=table_conf.jdbc_reader_options, + ) + threshold_columns = table_conf.get_threshold_columns("source") + failed_where_cond = " OR ".join([name + "_match = 'Failed'" for name in threshold_columns]) + mismatched_df = threshold_result.filter(failed_where_cond) + mismatched_count = mismatched_df.count() + threshold_df = None + if mismatched_count > 0: + threshold_df = mismatched_df.limit(_SAMPLE_ROWS) + + return ThresholdOutput(threshold_df=threshold_df, threshold_mismatch_count=mismatched_count) + + def get_record_count(self, table_conf: Table, report_type: str) -> ReconcileRecordCount: + if report_type != "schema": + source_count_query = CountQueryBuilder(table_conf, "source", self._source_engine).build_query() + target_count_query = CountQueryBuilder(table_conf, "target", self._target_engine).build_query() + source_count = self._source.read_data( + catalog=self._database_config.source_catalog, + schema=self._database_config.source_schema, + table=table_conf.source_name, + query=source_count_query, + options=None, + ).collect()[0]["count"] + target_count = self._target.read_data( + catalog=self._database_config.target_catalog, + schema=self._database_config.target_schema, + table=table_conf.target_name, + query=target_count_query, + options=None, + ).collect()[0]["count"] + + return ReconcileRecordCount(source=int(source_count), target=int(target_count)) + return ReconcileRecordCount() + + +def _get_schema( + source: DataSource, + target: DataSource, + table_conf: Table, + database_config: DatabaseConfig, +) -> tuple[list[Schema], list[Schema]]: + src_schema = source.get_schema( + catalog=database_config.source_catalog, + schema=database_config.source_schema, + table=table_conf.source_name, + ) + tgt_schema = target.get_schema( + catalog=database_config.target_catalog, + schema=database_config.target_schema, + table=table_conf.target_name, + ) + + return src_schema, tgt_schema + + +def _run_reconcile_data( + reconciler: Reconciliation, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], +) -> DataReconcileOutput: + try: + return reconciler.reconcile_data(table_conf=table_conf, src_schema=src_schema, tgt_schema=tgt_schema) + except DataSourceRuntimeException as e: + return DataReconcileOutput(exception=str(e)) + + +def _run_reconcile_schema( + reconciler: Reconciliation, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], +): + try: + return reconciler.reconcile_schema(table_conf=table_conf, src_schema=src_schema, tgt_schema=tgt_schema) + except PySparkException as e: + return SchemaReconcileOutput(is_valid=False, exception=str(e)) + + +def _run_reconcile_aggregates( + 
reconciler: Reconciliation, + table_conf: Table, + src_schema: list[Schema], + tgt_schema: list[Schema], +) -> list[AggregateQueryOutput]: + try: + return reconciler.reconcile_aggregates(table_conf, src_schema, tgt_schema) + except DataSourceRuntimeException as e: + return [AggregateQueryOutput(reconcile_output=DataReconcileOutput(exception=str(e)), rule=None)] + + +if __name__ == "__main__": + if "DATABRICKS_RUNTIME_VERSION" not in os.environ: + raise SystemExit("Only intended to run in Databricks Runtime") + main(*sys.argv) diff --git a/src/databricks/labs/remorph/reconcile/query_builder/__init__.py b/src/databricks/labs/remorph/reconcile/query_builder/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/reconcile/query_builder/aggregate_query.py b/src/databricks/labs/remorph/reconcile/query_builder/aggregate_query.py new file mode 100644 index 0000000000..4aed6128ad --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/query_builder/aggregate_query.py @@ -0,0 +1,287 @@ +import logging +from itertools import groupby +from operator import attrgetter + +import sqlglot.expressions as exp + +from databricks.labs.remorph.reconcile.query_builder.base import QueryBuilder +from databricks.labs.remorph.reconcile.query_builder.expression_generator import ( + build_column, +) +from databricks.labs.remorph.reconcile.recon_config import ( + Aggregate, + AggregateQueryRules, + AggregateRule, +) + +logger = logging.getLogger(__name__) + + +class AggregateQueryBuilder(QueryBuilder): + + def _get_mapping_col(self, col: str) -> str: + """ + Get the column mapping for the given column based on the layer + + Examples: + Input :: col: "COL1", mapping: "{source: COL1, target: COLUMN1}", layer: "source" + + Returns -> "COLUMN1" + + :param col: Column Name + :return: Mapped Column Name if found, else Column Name + """ + # apply column mapping, ex: "{source: pid, target: product_id}" + column_with_mapping = self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer) + if self.layer == "target": + column_with_mapping = self.table_conf.get_layer_src_to_tgt_col_mapping(col, self.layer) + return column_with_mapping + + def _get_mapping_cols_with_alias(self, cols_list: list[str], agg_type: str): + """ + Creates a Column Expression for each [Mapped] Column with Agg_Type+Original_Column as Alias + + Examples: + Input :: cols_list: ["COL1", "COL2"], agg_type: ["MAX"] \n + Returns -> ["column1 AS max<#>col1", "column2 AS max<#>col2] + + :param cols_list: List of aggregate columns + :param agg_type: MIN, MAX, COUNT, AVG + :return: list[Expression] - List of Column Expressions with Alias + """ + cols_with_mapping: list[exp.Expression] = [] + for col in cols_list: + column_expr = build_column( + this=f"{self._get_mapping_col(col)}", alias=f"{agg_type.lower()}<#>{col.lower()}" + ) + cols_with_mapping.append(column_expr) + return cols_with_mapping + + def _agg_query_cols_with_alias(self, transformed_cols: list[exp.Expression]): + cols_with_alias = [] + + for transformed_col in transformed_cols: + # Split the alias defined above as agg_type(min, max etc..), original column (pid) + agg_type, org_col_name = transformed_col.alias.split("<#>") + + # Create a new alias with layer, agg_type and original column name, + # ex: source_min_pid, target_max_product_id + layer_agg_type_col_alias = f"{self.layer}_{agg_type}_{org_col_name}".lower() + + # Get the Transformed column name without the alias + col_name = transformed_col.sql().replace(f"AS {transformed_col.alias}", 
'').strip() + + # Create a new Column Expression with the new alias, + # ex: MIN(pid) AS source_min_pid, MIN(product_id) AS target_min_pid + column_name = f"{col_name}" if agg_type == "group_by" else f"{agg_type}({col_name})" + col_with_alias = build_column(this=column_name, alias=layer_agg_type_col_alias) + cols_with_alias.append(col_with_alias) + + return cols_with_alias + + def _get_layer_query(self, group_list: list[Aggregate]) -> AggregateQueryRules: + """ + Builds the query based on the layer: + * Creates an Expression using + - 'select' columns with alias for the aggregate columns + - 'filters' (where) based on the layer + - 'group by' if group_by_columns are defined + * Generates and returns the SQL query using the above Expression and Dialect + - query Aggregate rules + + Examples: + 1.Input :: group_list: [Aggregate(type="Max", agg_cols=["col2", "col3"], group_by_columns=["col1"]), + Aggregate(type="Sum", agg_cols=["col1", "col2"], group_by_columns=["col1"])] + Returns -> SELECT max(col2) AS src_max_col2, max(col3) AS src_max_col3, + sum(col1) AS src_sum_col1, sum(col2) AS src_sum_col2 + FROM :tbl + WHERE col1 IS NOT NULL + GROUP BY col1 + 2. + group_list: [Aggregate(type="avg", agg_cols=["col4"])] + :layer: "tgt" + :returns -> SELECT avg(col4) AS tgt_avg_col4 FROM :tbl + + :param group_list: List of Aggregate objects with same Group by columns + :return: str - SQL Query + """ + cols_with_mapping: list[exp.Expression] = [] + # Generates a Single Query for multiple aggregates with the same group_by_columns, + # refer to Example 1 + query_agg_rules = [] + processed_rules: dict[str, str] = {} + for agg in group_list: + + # Skip duplicate rules + # Example: {min_grp1+__+grp2 : col1+__+col2}, key = min_grp1+__+grp2 + key = f"{agg.type}_{agg.group_by_columns_as_str}" + if key in processed_rules: + existing_rule = processed_rules.get(key) + if existing_rule == agg.agg_columns_as_str: + logger.info( + f"Skipping duplicate rule for key: {key}, value: {agg.agg_columns_as_str}," + f" layer: {self.layer}" + ) + continue + processed_rules[key] = agg.agg_columns_as_str + + # Get the rules for each aggregate and append to the query_agg_rules list + query_agg_rules.extend(self._build_aggregate_rules(agg)) + + # Get the mapping with alias for aggregate columns and append to the cols_with_mapping list + cols_with_mapping.extend(self._get_mapping_cols_with_alias(agg.agg_columns, agg.type)) + + # Apply user transformations on Select columns + # Example: {column_name: creation_date, source: creation_date, target: to_date(creation_date,'yyyy-mm-dd')} + select_cols_with_transform = ( + self._apply_user_transformation(cols_with_mapping) if self.user_transformations else cols_with_mapping + ) + + # Transformed columns + select_cols_with_alias = self._agg_query_cols_with_alias(select_cols_with_transform) + query_exp = exp.select(*select_cols_with_alias).from_(":tbl").where(self.filter) + + assert group_list[0], "At least, one item must be present in the group_list." 
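+ # At this point query_exp has the SELECT ... FROM :tbl WHERE ... shape shown in the docstring examples, but without any GROUP BY; it is rebuilt below when group_by_columns are defined.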
+ + # Apply Group by if group_by_columns are defined + if group_list[0].group_by_columns: + group_by_cols_with_mapping = self._get_mapping_cols_with_alias(group_list[0].group_by_columns, "GROUP_BY") + + # Apply user transformations on group_by_columns, + # ex: {column_name: creation_date, source: creation_date, target: to_date(creation_date,'yyyy-mm-dd')} + group_by_cols_with_transform = ( + self._apply_user_transformation(group_by_cols_with_mapping) + if self.user_transformations + else group_by_cols_with_mapping + ) + + select_group_by_cols_with_alias = self._agg_query_cols_with_alias(group_by_cols_with_transform) + + # Group by column doesn't support alias (GROUP BY to_date(COL1, 'yyyy-MM-dd') AS col1) throws error + group_by_col_without_alias = [ + build_column(this=group_by_col_with_alias.sql().split(" AS ")[0].strip()) + for group_by_col_with_alias in select_group_by_cols_with_alias + if " AS " in group_by_col_with_alias.sql() + ] + + query_exp = ( + exp.select(*select_cols_with_alias + select_group_by_cols_with_alias) + .from_(":tbl") + .where(self.filter) + .group_by(*group_by_col_without_alias) + ) + + agg_query_rules = AggregateQueryRules( + layer=self.layer, + group_by_columns=group_list[0].group_by_columns, + group_by_columns_as_str=group_list[0].group_by_columns_as_str, + query=query_exp.sql(dialect=self.engine), + rules=query_agg_rules, + ) + return agg_query_rules + + def grouped_aggregates(self): + """ + Group items based on group_by_columns_keys: + Example: + aggregates = [ + Aggregate(type="Min", agg_cols=["c_nation_str", "col2"], + group_by_columns=["col3"]), + Aggregate(type="Max", agg_cols=["col2", "col3"], group_by_columns=["col1"]), + Aggregate(type="avg", agg_cols=["col4"]), + Aggregate(type="sum", agg_cols=["col3", "col6"], group_by_columns=["col1"]), + ] + output: + * key: NA with index 1 + - Aggregate(agg_cols=['col4'], type='avg', group_by_columns=None, group_by_columns_as_str='NA') + * key: col1 with index 2 + - Aggregate(agg_cols=['col2', 'col3'], type='Max', group_by_columns=['col1'], + group_by_columns_as_str='col1') + - Aggregate(agg_cols=['col3', 'col6'], type='sum', group_by_columns=['col1'], + group_by_columns_as_str='col1') + * key: col3 with index 3 + - Aggregate(agg_cols=['c_nation_str', 'col2'], type='Min', group_by_columns=['col3'], + group_by_columns_as_str='col3') + """ + _aggregates: list[Aggregate] = [] + + assert self.aggregates, "Aggregates config must be defined to build the queries." 
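+ # NOTE: itertools.groupby only merges consecutive items, so the aggregates are sorted on the group-by key below before grouping.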
+ self._validate(self.aggregates, "Aggregates config must be defined to build the queries.") + + if self.aggregates: + _aggregates = self.aggregates + + # Sort the aggregates based on group_by_columns_as_str + _aggregates.sort(key=attrgetter("group_by_columns_as_str")) + + return groupby(_aggregates, key=attrgetter("group_by_columns_as_str")) + + @classmethod + def _build_aggregate_rules(cls, agg: Aggregate) -> list[AggregateRule]: + """ + Builds the rules for each aggregate column in the given Aggregate object + + Example: + Input :: Aggregate: { + "type": "MIN", + "agg_cols": ["COL1", "COL2"], + "group_by_columns": ["GRP1", "GRP2"] + } + Returns -> [AggregateRule(rule_id=hash(min_col1_grp1_grp2), + query=SELECT {rule_id} as rule_id, + 'min' as agg_type, + 'col1' as agg_column, + ('grp1', 'grp2') as group_by_columns), + + AggregateRule(rule_id=hash(min_col2_grp1_grp2), + query=SELECT {rule_id} as rule_id, + 'min' as agg_type, + 'col2' as agg_column, + ('grp1', 'grp2') as group_by_columns)] + :param agg: Aggregate + :return: list[AggregateRule] + """ + + return [ + AggregateRule( + agg_type=agg.type, + agg_column=agg_col, + group_by_columns=agg.group_by_columns, + group_by_columns_as_str=agg.group_by_columns_as_str, + ) + for agg_col in agg.agg_columns + ] + + def build_queries(self) -> list[AggregateQueryRules]: + """ + Generates the Source and Target Queries for the list of Aggregate objects + * Group items based on group_by_columns_keys and for each group, + generates the query_with_rules for both Source and Target Dialects + * Generates 2 Queries (Source, Target) for each unique group_by_columns_keys + + Examples: + 1. [Aggregate(type="avg", agg_cols=["col4"])] + { + "src_query_1": "SELECT avg(col4) AS src_avg_col4 FROM :tbl" + } + { + "tgt_query_1": "SELECT avg(col4) AS tgt_avg_col4 FROM :tbl" + } + 2.
[Aggregate(type="Max", agg_cols=["col3"], group_by_columns=["col1"]), + Aggregate(type="Sum", agg_cols=["col2"], group_by_columns=["col4"])] + { + "src_query_1": "SELECT max(col3) AS src_max_col3 FROM :tbl GROUP BY col1" + "src_query_2": "SELECT sum(col2) AS src_sum_col2 FROM :tbl GROUP BY col4" + } + { + "tgt_query_1": "SELECT max(col3) AS tgt_max_col3 FROM :tbl GROUP BY col1" + "tgt_query_2": "SELECT sum(col2) AS tgt_sum_col2 FROM :tbl GROUP BY col4" + } + :return: Dictionary with Source and Target Queries + """ + query_with_rules_list = [] + for key, group in self.grouped_aggregates(): + logger.info(f"Building Query and Rules for key: {key}, layer: {self.layer}") + query_with_rules_list.append(self._get_layer_query(list(group))) + + return query_with_rules_list diff --git a/src/databricks/labs/remorph/reconcile/query_builder/base.py b/src/databricks/labs/remorph/reconcile/query_builder/base.py new file mode 100644 index 0000000000..c718344b1c --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/query_builder/base.py @@ -0,0 +1,138 @@ +import logging +from abc import ABC + +import sqlglot.expressions as exp +from sqlglot import Dialect, parse_one + +from databricks.labs.remorph.config import SQLGLOT_DIALECTS +from databricks.labs.remorph.reconcile.exception import InvalidInputException +from databricks.labs.remorph.reconcile.query_builder.expression_generator import ( + DataType_transform_mapping, + transform_expression, +) +from databricks.labs.remorph.reconcile.recon_config import Schema, Table, Aggregate + +logger = logging.getLogger(__name__) + + +class QueryBuilder(ABC): + def __init__( + self, + table_conf: Table, + schema: list[Schema], + layer: str, + engine: Dialect, + ): + self._table_conf = table_conf + self._schema = schema + self._layer = layer + self._engine = engine + + @property + def engine(self) -> Dialect: + return self._engine + + @property + def layer(self) -> str: + return self._layer + + @property + def schema(self) -> list[Schema]: + return self._schema + + @property + def table_conf(self) -> Table: + return self._table_conf + + @property + def select_columns(self) -> set[str]: + return self.table_conf.get_select_columns(self._schema, self._layer) + + @property + def threshold_columns(self) -> set[str]: + return self.table_conf.get_threshold_columns(self._layer) + + @property + def join_columns(self) -> set[str] | None: + return self.table_conf.get_join_columns(self._layer) + + @property + def drop_columns(self) -> set[str]: + return self._table_conf.get_drop_columns(self._layer) + + @property + def partition_column(self) -> set[str]: + return self._table_conf.get_partition_column(self._layer) + + @property + def filter(self) -> str | None: + return self._table_conf.get_filter(self._layer) + + @property + def user_transformations(self) -> dict[str, str]: + return self._table_conf.get_transformation_dict(self._layer) + + @property + def aggregates(self) -> list[Aggregate] | None: + return self.table_conf.aggregates + + def add_transformations(self, aliases: list[exp.Expression], source: Dialect) -> list[exp.Expression]: + if self.user_transformations: + alias_with_user_transforms = self._apply_user_transformation(aliases) + default_transform_schema: list[Schema] = list( + filter(lambda sch: sch.column_name not in self.user_transformations.keys(), self.schema) + ) + return self._apply_default_transformation(alias_with_user_transforms, default_transform_schema, source) + return self._apply_default_transformation(aliases, self.schema, source) + + def 
_apply_user_transformation(self, aliases: list[exp.Expression]) -> list[exp.Expression]: + with_transform = [] + for alias in aliases: + with_transform.append(alias.transform(self._user_transformer, self.user_transformations)) + return with_transform + + @staticmethod + def _user_transformer(node: exp.Expression, user_transformations: dict[str, str]) -> exp.Expression: + if isinstance(node, exp.Column) and user_transformations: + column_name = node.name + if column_name in user_transformations.keys(): + return parse_one(user_transformations.get(column_name, column_name)) + return node + + def _apply_default_transformation( + self, aliases: list[exp.Expression], schema: list[Schema], source: Dialect + ) -> list[exp.Expression]: + with_transform = [] + for alias in aliases: + with_transform.append(alias.transform(self._default_transformer, schema, source)) + return with_transform + + @staticmethod + def _default_transformer(node: exp.Expression, schema: list[Schema], source: Dialect) -> exp.Expression: + + def _get_transform(datatype: str): + source_dialects = [source_key for source_key, dialect in SQLGLOT_DIALECTS.items() if dialect == source] + source_dialect = source_dialects[0] if source_dialects else "universal" + + source_mapping = DataType_transform_mapping.get(source_dialect, {}) + + if source_mapping.get(datatype.upper()) is not None: + return source_mapping.get(datatype.upper()) + if source_mapping.get("default") is not None: + return source_mapping.get("default") + + return DataType_transform_mapping.get("universal", {}).get("default") + + schema_dict = {v.column_name: v.data_type for v in schema} + if isinstance(node, exp.Column): + column_name = node.name + if column_name in schema_dict.keys(): + transform = _get_transform(schema_dict.get(column_name, column_name)) + return transform_expression(node, transform) + return node + + def _validate(self, field: set[str] | list[str] | None, message: str): + if field is None: + message = f"Exception for {self.table_conf.target_name} target table in {self.layer} layer --> {message}" + logger.error(message) + raise InvalidInputException(message) diff --git a/src/databricks/labs/remorph/reconcile/query_builder/count_query.py b/src/databricks/labs/remorph/reconcile/query_builder/count_query.py new file mode 100644 index 0000000000..98215836f4 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/query_builder/count_query.py @@ -0,0 +1,33 @@ +import logging + +from sqlglot import Dialect +from sqlglot import expressions as exp + +from databricks.labs.remorph.reconcile.query_builder.expression_generator import build_column, build_literal +from databricks.labs.remorph.reconcile.recon_config import Table + +logger = logging.getLogger(__name__) + + +class CountQueryBuilder: + + def __init__( + self, + table_conf: Table, + layer: str, + engine: Dialect, + ): + self._table_conf = table_conf + self._layer = layer + self._engine = engine + + def build_query(self): + select_clause = build_column(this=exp.Count(this=build_literal(this="1", is_string=False)), alias="count") + count_query = ( + exp.select(select_clause) + .from_(":tbl") + .where(self._table_conf.get_filter(self._layer)) + .sql(dialect=self._engine) + ) + logger.info(f"Record Count Query for {self._layer}: {count_query}") + return count_query diff --git a/src/databricks/labs/remorph/reconcile/query_builder/expression_generator.py b/src/databricks/labs/remorph/reconcile/query_builder/expression_generator.py new file mode 100644 index 0000000000..aad37f276d --- /dev/null +++ 
b/src/databricks/labs/remorph/reconcile/query_builder/expression_generator.py @@ -0,0 +1,260 @@ +from collections.abc import Callable +from functools import partial + +from pyspark.sql.types import DataType, NumericType +from sqlglot import Dialect +from sqlglot import expressions as exp + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.recon_config import HashAlgoMapping + + +def _apply_func_expr(expr: exp.Expression, expr_func: Callable, **kwargs) -> exp.Expression: + is_terminal = isinstance(expr, exp.Column) + new_expr = expr.copy() + for node in new_expr.dfs(): + if isinstance(node, exp.Column): + column_name = node.name + table_name = node.table + func = expr_func(this=exp.Column(this=column_name, table=table_name), **kwargs) + if is_terminal: + return func + node.replace(func) + return new_expr + + +def concat(expr: list[exp.Expression]) -> exp.Expression: + return exp.Concat(expressions=expr, safe=True) + + +def sha2(expr: exp.Expression, num_bits: str, is_expr: bool = False) -> exp.Expression: + if is_expr: + return exp.SHA2(this=expr, length=exp.Literal(this=num_bits, is_string=False)) + return _apply_func_expr(expr, exp.SHA2, length=exp.Literal(this=num_bits, is_string=False)) + + +def lower(expr: exp.Expression, is_expr: bool = False) -> exp.Expression: + if is_expr: + return exp.Lower(this=expr) + return _apply_func_expr(expr, exp.Lower) + + +def coalesce(expr: exp.Expression, default="0", is_string=False) -> exp.Expression: + expressions = [exp.Literal(this=default, is_string=is_string)] + return _apply_func_expr(expr, exp.Coalesce, expressions=expressions) + + +def trim(expr: exp.Expression) -> exp.Trim | exp.Expression: + return _apply_func_expr(expr, exp.Trim) + + +def json_format(expr: exp.Expression, options: dict[str, str] | None = None) -> exp.Expression: + return _apply_func_expr(expr, exp.JSONFormat, options=options) + + +def sort_array(expr: exp.Expression, asc=True) -> exp.Expression: + return _apply_func_expr(expr, exp.SortArray, asc=exp.Boolean(this=asc)) + + +def to_char(expr: exp.Expression, to_format=None, nls_param=None) -> exp.Expression: + if to_format: + return _apply_func_expr( + expr, exp.ToChar, format=exp.Literal(this=to_format, is_string=True), nls_param=nls_param + ) + return _apply_func_expr(expr, exp.ToChar) + + +def array_to_string( + expr: exp.Expression, + delimiter: str = ",", + is_string=True, + null_replacement: str | None = None, + is_null_replace=True, +) -> exp.Expression: + if null_replacement: + return _apply_func_expr( + expr, + exp.ArrayToString, + expression=[exp.Literal(this=delimiter, is_string=is_string)], + null=exp.Literal(this=null_replacement, is_string=is_null_replace), + ) + return _apply_func_expr(expr, exp.ArrayToString, expression=[exp.Literal(this=delimiter, is_string=is_string)]) + + +def array_sort(expr: exp.Expression, asc=True) -> exp.Expression: + return _apply_func_expr(expr, exp.ArraySort, expression=exp.Boolean(this=asc)) + + +def anonymous(expr: exp.Column, func: str, is_expr: bool = False) -> exp.Expression: + """ + + This function used in cases where the sql functions are not available in sqlGlot expressions + Example: + >>> from sqlglot import parse_one + >>> print(repr(parse_one('select unix_timestamp(col1)'))) + + the above code gives you a Select Expression of Anonymous function. 
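+ Note: the "{}" placeholder in func is replaced with the column name (table-qualified when the column carries a table reference).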
+ + To achieve the same,we can use the function as below: + eg: + >>> expr = parse_one("select col1 from dual") + >>> transformed_expr=anonymous(expr,"unix_timestamp({})") + >>> print(transformed_expr) + 'SELECT UNIX_TIMESTAMP(col1) FROM DUAL' + + """ + if is_expr: + return exp.Column(this=func.format(expr)) + is_terminal = isinstance(expr, exp.Column) + new_expr = expr.copy() + for node in new_expr.dfs(): + if isinstance(node, exp.Column): + name = f"{node.table}.{node.name}" if node.table else node.name + anonymous_func = exp.Column(this=func.format(name)) + if is_terminal: + return anonymous_func + node.replace(anonymous_func) + return new_expr + + +def build_column(this: exp.ExpOrStr, table_name="", quoted=False, alias=None) -> exp.Expression: + if alias: + if isinstance(this, str): + return exp.Alias( + this=exp.Column(this=this, table=table_name), alias=exp.Identifier(this=alias, quoted=quoted) + ) + return exp.Alias(this=this, alias=exp.Identifier(this=alias, quoted=quoted)) + return exp.Column(this=exp.Identifier(this=this, quoted=quoted), table=table_name) + + +def build_literal(this: exp.ExpOrStr, alias=None, quoted=False, is_string=True) -> exp.Expression: + if alias: + return exp.Alias( + this=exp.Literal(this=this, is_string=is_string), alias=exp.Identifier(this=alias, quoted=quoted) + ) + return exp.Literal(this=this, is_string=is_string) + + +def transform_expression( + expr: exp.Expression, + funcs: list[Callable[[exp.Expression], exp.Expression]], +) -> exp.Expression: + for func in funcs: + expr = func(expr) + assert isinstance(expr, exp.Expression), ( + f"Func returned an instance of type [{type(expr)}], " "should have been Expression." + ) + return expr + + +def get_hash_transform( + source: Dialect, + layer: str, +): + dialect_algo = Dialect_hash_algo_mapping.get(source) + if not dialect_algo: + raise ValueError(f"Source {source} is not supported. Please add it to Dialect_hash_algo_mapping dictionary.") + + layer_algo = getattr(dialect_algo, layer, None) + if not layer_algo: + raise ValueError( + f"Layer {layer} is not supported for source {source}. Please add it to Dialect_hash_algo_mapping dictionary." 
+ ) + return [layer_algo] + + +def build_from_clause(table_name: str, table_alias: str | None = None) -> exp.From: + return exp.From(this=exp.Table(this=exp.Identifier(this=table_name), alias=table_alias)) + + +def build_join_clause( + table_name: str, + join_columns: list, + source_table_alias: str | None = None, + target_table_alias: str | None = None, + kind: str = "inner", + func: Callable = exp.NullSafeEQ, +) -> exp.Join: + join_conditions = [] + for column in join_columns: + join_condition = func( + this=exp.Column(this=column, table=source_table_alias), + expression=exp.Column(this=column, table=target_table_alias), + ) + join_conditions.append(join_condition) + + # Combine all join conditions with AND + on_condition: exp.NullSafeEQ | exp.And = join_conditions[0] + for condition in join_conditions[1:]: + on_condition = exp.And(this=on_condition, expression=condition) + + return exp.Join( + this=exp.Table(this=exp.Identifier(this=table_name), alias=target_table_alias), kind=kind, on=on_condition + ) + + +def build_sub( + left_column_name: str, + right_column_name: str, + left_table_name: str | None = None, + right_table_name: str | None = None, +) -> exp.Sub: + return exp.Sub( + this=build_column(left_column_name, left_table_name), + expression=build_column(right_column_name, right_table_name), + ) + + +def build_where_clause(where_clause: list[exp.Expression], condition_type: str = "or") -> exp.Expression: + func = exp.Or if condition_type == "or" else exp.And + # Start with a default + combined_expression: exp.Expression = exp.Paren(this=func(this='1 = 1', expression='1 = 1')) + + # Loop through the expressions and combine them with OR + for expression in where_clause: + combined_expression = func(this=combined_expression, expression=expression) + + return combined_expression + + +def build_if(this: exp.Expression, true: exp.Expression, false: exp.Expression | None = None) -> exp.If: + return exp.If(this=this, true=true, false=false) + + +def build_between(this: exp.Expression, low: exp.Expression, high: exp.Expression) -> exp.Between: + return exp.Between(this=this, low=low, high=high) + + +def _get_is_string(column_types_dict: dict[str, DataType], column_name: str) -> bool: + if isinstance(column_types_dict.get(column_name), NumericType): + return False + return True + + +DataType_transform_mapping: dict[str, dict[str, list[partial[exp.Expression]]]] = { + "universal": {"default": [partial(coalesce, default='_null_recon_', is_string=True), partial(trim)]}, + "snowflake": {exp.DataType.Type.ARRAY.value: [partial(array_to_string), partial(array_sort)]}, + "oracle": { + exp.DataType.Type.NCHAR.value: [partial(anonymous, func="NVL(TRIM(TO_CHAR({})),'_null_recon_')")], + exp.DataType.Type.NVARCHAR.value: [partial(anonymous, func="NVL(TRIM(TO_CHAR({})),'_null_recon_')")], + }, + "databricks": { + exp.DataType.Type.ARRAY.value: [partial(anonymous, func="CONCAT_WS(',', SORT_ARRAY({}))")], + }, +} + +sha256_partial = partial(sha2, num_bits="256", is_expr=True) +Dialect_hash_algo_mapping: dict[Dialect, HashAlgoMapping] = { + get_dialect("snowflake"): HashAlgoMapping( + source=sha256_partial, + target=sha256_partial, + ), + get_dialect("oracle"): HashAlgoMapping( + source=partial(anonymous, func="RAWTOHEX(STANDARD_HASH({}, 'SHA256'))", is_expr=True), + target=sha256_partial, + ), + get_dialect("databricks"): HashAlgoMapping( + source=sha256_partial, + target=sha256_partial, + ), +} diff --git a/src/databricks/labs/remorph/reconcile/query_builder/hash_query.py 
b/src/databricks/labs/remorph/reconcile/query_builder/hash_query.py new file mode 100644 index 0000000000..ddd9bec045 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/query_builder/hash_query.py @@ -0,0 +1,87 @@ +import logging + +import sqlglot.expressions as exp +from sqlglot import Dialect + +from databricks.labs.remorph.reconcile.query_builder.base import QueryBuilder +from databricks.labs.remorph.reconcile.query_builder.expression_generator import ( + build_column, + concat, + get_hash_transform, + lower, + transform_expression, +) +from databricks.labs.remorph.config import get_dialect + +logger = logging.getLogger(__name__) + + +def _hash_transform( + node: exp.Expression, + source: Dialect, + layer: str, +): + transform = get_hash_transform(source, layer) + return transform_expression(node, transform) + + +_HASH_COLUMN_NAME = "hash_value_recon" + + +class HashQueryBuilder(QueryBuilder): + + def build_query(self, report_type: str) -> str: + + if report_type != 'row': + self._validate(self.join_columns, f"Join Columns are compulsory for {report_type} type") + + _join_columns = self.join_columns if self.join_columns else set() + hash_cols = sorted((_join_columns | self.select_columns) - self.threshold_columns - self.drop_columns) + + key_cols = hash_cols if report_type == "row" else sorted(_join_columns | self.partition_column) + + cols_with_alias = [ + build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer)) + for col in key_cols + ] + + # in case if we have column mapping, we need to sort the target columns in the order of source columns to get + # same hash value + hash_cols_with_alias = [ + {"this": col, "alias": self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer)} + for col in hash_cols + ] + sorted_hash_cols_with_alias = sorted(hash_cols_with_alias, key=lambda column: column["alias"]) + hashcols_sorted_as_src_seq = [column["this"] for column in sorted_hash_cols_with_alias] + + key_cols_with_transform = ( + self._apply_user_transformation(cols_with_alias) if self.user_transformations else cols_with_alias + ) + hash_col_with_transform = [self._generate_hash_algorithm(hashcols_sorted_as_src_seq, _HASH_COLUMN_NAME)] + + dialect = self.engine if self.layer == "source" else get_dialect("databricks") + res = ( + exp.select(*hash_col_with_transform + key_cols_with_transform) + .from_(":tbl") + .where(self.filter) + .sql(dialect=dialect) + ) + + logger.info(f"Hash Query for {self.layer}: {res}") + return res + + def _generate_hash_algorithm( + self, + cols: list[str], + column_alias: str, + ) -> exp.Expression: + cols_with_alias = [build_column(this=col, alias=None) for col in cols] + cols_with_transform = self.add_transformations( + cols_with_alias, self.engine if self.layer == "source" else get_dialect("databricks") + ) + col_exprs = exp.select(*cols_with_transform).iter_expressions() + concat_expr = concat(list(col_exprs)) + + hash_expr = concat_expr.transform(_hash_transform, self.engine, self.layer).transform(lower, is_expr=True) + + return build_column(hash_expr, alias=column_alias) diff --git a/src/databricks/labs/remorph/reconcile/query_builder/sampling_query.py b/src/databricks/labs/remorph/reconcile/query_builder/sampling_query.py new file mode 100644 index 0000000000..8fd83885a2 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/query_builder/sampling_query.py @@ -0,0 +1,105 @@ +import logging + +import sqlglot.expressions as exp +from pyspark.sql import DataFrame +from sqlglot import select + +from 
databricks.labs.remorph.config import get_key_from_dialect +from databricks.labs.remorph.reconcile.query_builder.base import QueryBuilder +from databricks.labs.remorph.reconcile.query_builder.expression_generator import ( + build_column, + build_literal, + _get_is_string, + build_join_clause, + trim, + coalesce, +) + +_SAMPLE_ROWS = 50 + +logger = logging.getLogger(__name__) + + +def _union_concat( + unions: list[exp.Select], + result: exp.Union | exp.Select, + cnt=0, +) -> exp.Select | exp.Union: + if len(unions) == 1: + return result + if cnt == len(unions) - 2: + return exp.union(result, unions[cnt + 1]) + cnt = cnt + 1 + res = exp.union(result, unions[cnt]) + return _union_concat(unions, res, cnt) + + +class SamplingQueryBuilder(QueryBuilder): + def build_query(self, df: DataFrame): + self._validate(self.join_columns, "Join Columns are compulsory for sampling query") + join_columns = self.join_columns if self.join_columns else set() + if self.layer == "source": + key_cols = sorted(join_columns) + else: + key_cols = sorted(self.table_conf.get_tgt_to_src_col_mapping_list(join_columns)) + keys_df = df.select(*key_cols) + with_clause = self._get_with_clause(keys_df) + + cols = sorted((join_columns | self.select_columns) - self.threshold_columns - self.drop_columns) + + cols_with_alias = [ + build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer)) + for col in cols + ] + + sql_with_transforms = self.add_transformations(cols_with_alias, self.engine) + query_sql = select(*sql_with_transforms).from_(":tbl").where(self.filter) + if self.layer == "source": + with_select = [build_column(this=col, table_name="src") for col in sorted(cols)] + else: + with_select = [ + build_column(this=col, table_name="src") + for col in sorted(self.table_conf.get_tgt_to_src_col_mapping_list(cols)) + ] + + join_clause = SamplingQueryBuilder._get_join_clause(key_cols) + + query = ( + with_clause.with_(alias="src", as_=query_sql) + .select(*with_select) + .from_("src") + .join(join_clause) + .sql(dialect=self.engine) + ) + logger.info(f"Sampling Query for {self.layer}: {query}") + return query + + @classmethod + def _get_join_clause(cls, key_cols: list): + return ( + build_join_clause( + "recon", key_cols, source_table_alias="src", target_table_alias="recon", kind="inner", func=exp.EQ + ) + .transform(coalesce, default="_null_recon_", is_string=True) + .transform(trim) + ) + + def _get_with_clause(self, df: DataFrame) -> exp.Select: + union_res = [] + for row in df.take(_SAMPLE_ROWS): + column_types = [(str(f.name).lower(), f.dataType) for f in df.schema.fields] + column_types_dict = dict(column_types) + row_select = [ + ( + build_literal(this=str(value), alias=col, is_string=_get_is_string(column_types_dict, col)) + if value is not None + else exp.Alias(this=exp.Null(), alias=col) + ) + for col, value in zip(df.columns, row) + ] + if get_key_from_dialect(self.engine) == "oracle": + union_res.append(select(*row_select).from_("dual")) + else: + union_res.append(select(*row_select)) + union_statements = _union_concat(union_res, union_res[0], 0) + return exp.Select().with_(alias='recon', as_=union_statements) diff --git a/src/databricks/labs/remorph/reconcile/query_builder/threshold_query.py b/src/databricks/labs/remorph/reconcile/query_builder/threshold_query.py new file mode 100644 index 0000000000..44bc155f9f --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/query_builder/threshold_query.py @@ -0,0 +1,231 @@ +import logging + +from sqlglot import expressions as exp +from 
sqlglot import select + +from databricks.labs.remorph.reconcile.query_builder.base import QueryBuilder +from databricks.labs.remorph.reconcile.query_builder.expression_generator import ( + anonymous, + build_between, + build_column, + build_from_clause, + build_if, + build_join_clause, + build_literal, + build_sub, + build_where_clause, + coalesce, +) +from databricks.labs.remorph.reconcile.recon_config import ColumnThresholds +from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks + +logger = logging.getLogger(__name__) + + +class ThresholdQueryBuilder(QueryBuilder): + # Comparison query + def build_comparison_query(self) -> str: + self._validate( + self.table_conf.get_join_columns("source"), "Join Columns are compulsory for threshold comparison query" + ) + join_columns = ( + self.table_conf.get_join_columns("source") if self.table_conf.get_join_columns("source") else set() + ) + select_clause, where = self._generate_select_where_clause(join_columns) + from_clause, join_clause = self._generate_from_and_join_clause(join_columns) + # for threshold comparison query the dialect is always Databricks + query = select(*select_clause).from_(from_clause).join(join_clause).where(where).sql(dialect=Databricks) + logger.info(f"Threshold Comparison query: {query}") + return query + + def _generate_select_where_clause(self, join_columns) -> tuple[list[exp.Expression], exp.Expression]: + thresholds: list[ColumnThresholds] = ( + self.table_conf.column_thresholds if self.table_conf.column_thresholds else [] + ) + select_clause = [] + where_clause = [] + + # threshold columns + for threshold in thresholds: + column = threshold.column_name + base = exp.Paren( + this=build_sub( + left_column_name=column, + left_table_name="source", + right_column_name=column, + right_table_name="databricks", + ) + ).transform(coalesce) + + select_exp, where = self._build_expression_type(threshold, base) + select_clause.extend(select_exp) + where_clause.append(where) + # join columns + for column in sorted(join_columns): + select_clause.append(build_column(this=column, alias=f"{column}_source", table_name="source")) + where = build_where_clause(where_clause) + + return select_clause, where + + @classmethod + def _build_expression_alias_components( + cls, + threshold: ColumnThresholds, + base: exp.Expression, + ) -> tuple[list[exp.Expression], exp.Expression]: + select_clause = [] + column = threshold.column_name + select_clause.append( + build_column(this=column, alias=f"{column}_source", table_name="source").transform(coalesce) + ) + select_clause.append( + build_column(this=column, alias=f"{column}_databricks", table_name="databricks").transform(coalesce) + ) + where_clause = exp.NEQ(this=base, expression=exp.Literal(this="0", is_string=False)) + return select_clause, where_clause + + def _build_expression_type( + self, + threshold: ColumnThresholds, + base: exp.Expression, + ) -> tuple[list[exp.Expression], exp.Expression]: + column = threshold.column_name + # default expressions + select_clause, where_clause = self._build_expression_alias_components(threshold, base) + + if threshold.get_type() in {"number_absolute", "datetime"}: + if threshold.get_type() == "datetime": + # unix_timestamp expression only if it is datetime + select_clause = [expression.transform(anonymous, "unix_timestamp({})") for expression in select_clause] + base = base.transform(anonymous, "unix_timestamp({})") + where_clause = exp.NEQ(this=base, expression=exp.Literal(this="0", is_string=False)) + + # absolute 
threshold + func = self._build_threshold_absolute_case + elif threshold.get_type() == "number_percentage": + # percentage threshold + func = self._build_threshold_percentage_case + else: + error_message = f"Threshold type {threshold.get_type()} not supported for column {column}" + logger.error(error_message) + raise ValueError(error_message) + + select_clause.append(build_column(this=func(base=base, threshold=threshold), alias=f"{column}_match")) + + return select_clause, where_clause + + def _generate_from_and_join_clause(self, join_columns) -> tuple[exp.From, exp.Join]: + source_view = f"source_{self.table_conf.source_name}_df_threshold_vw" + target_view = f"target_{self.table_conf.target_name}_df_threshold_vw" + + from_clause = build_from_clause(source_view, "source") + join_clause = build_join_clause( + table_name=target_view, + source_table_alias="source", + target_table_alias="databricks", + join_columns=sorted(join_columns), + ) + + return from_clause, join_clause + + @classmethod + def _build_threshold_absolute_case( + cls, + base: exp.Expression, + threshold: ColumnThresholds, + ) -> exp.Case: + eq_if = build_if( + this=exp.EQ(this=base, expression=build_literal(this="0", is_string=False)), + true=exp.Literal(this="Match", is_string=True), + ) + + between_base = build_between( + this=base, + low=build_literal(threshold.lower_bound.replace("%", ""), is_string=False), + high=build_literal(threshold.upper_bound.replace("%", ""), is_string=False), + ) + + between_if = build_if( + this=between_base, + true=exp.Literal(this="Warning", is_string=True), + ) + return exp.Case(ifs=[eq_if, between_if], default=exp.Literal(this="Failed", is_string=True)) + + @classmethod + def _build_threshold_percentage_case( + cls, + base: exp.Expression, + threshold: ColumnThresholds, + ) -> exp.Case: + eq_if = exp.If( + this=exp.EQ(this=base, expression=build_literal(this="0", is_string=False)), + true=exp.Literal(this="Match", is_string=True), + ) + + denominator = build_if( + this=exp.Or( + this=exp.EQ( + this=exp.Column(this=threshold.column_name, table="databricks"), + expression=exp.Literal(this='0', is_string=False), + ), + expression=exp.Is( + this=exp.Column( + this=exp.Identifier(this=threshold.column_name, quoted=False), + table=exp.Identifier(this='databricks'), + ), + expression=exp.Null(), + ), + ), + true=exp.Literal(this="1", is_string=False), + false=exp.Column(this=threshold.column_name, table="databricks"), + ) + + division = exp.Div(this=base, expression=denominator, typed=False, safe=False) + percentage = exp.Mul(this=exp.Paren(this=division), expression=exp.Literal(this="100", is_string=False)) + between_base = build_between( + this=percentage, + low=build_literal(threshold.lower_bound.replace("%", ""), is_string=False), + high=build_literal(threshold.upper_bound.replace("%", ""), is_string=False), + ) + + between_if = build_if( + this=between_base, + true=exp.Literal(this="Warning", is_string=True), + ) + return exp.Case(ifs=[eq_if, between_if], default=exp.Literal(this="Failed", is_string=True)) + + def build_threshold_query(self) -> str: + """ + This method builds a threshold query based on the configuration of the table and the columns involved. + + The query is constructed by selecting the necessary columns (partition, join, and threshold columns) + from a specified table. Any transformations specified in the table configuration are applied to the + selected columns. The query also includes a WHERE clause based on the filter defined in the table configuration. 
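+ Illustrative example: with join column "id", threshold column "amount", and no transformations or column mappings, the generated SQL is roughly "SELECT id AS id, amount AS amount FROM :tbl" (plus the WHERE filter, when one is configured).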
+ + The resulting query is then converted to a SQL string using the dialect of the source database. + + Returns: + str: The SQL string representation of the threshold query. + """ + # key column expression + self._validate(self.join_columns, "Join Columns are compulsory for threshold query") + join_columns = self.join_columns if self.join_columns else set() + keys: list[str] = sorted(self.partition_column.union(join_columns)) + keys_select_alias = [ + build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer)) + for col in keys + ] + keys_expr = self._apply_user_transformation(keys_select_alias) + + # threshold column expression + threshold_alias = [ + build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer)) + for col in sorted(self.threshold_columns) + ] + thresholds_expr = threshold_alias + if self.user_transformations: + thresholds_expr = self._apply_user_transformation(threshold_alias) + + query = (select(*keys_expr + thresholds_expr).from_(":tbl").where(self.filter)).sql(dialect=self.engine) + logger.info(f"Threshold Query for {self.layer}: {query}") + return query diff --git a/src/databricks/labs/remorph/reconcile/recon_capture.py b/src/databricks/labs/remorph/reconcile/recon_capture.py new file mode 100644 index 0000000000..31b873c445 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/recon_capture.py @@ -0,0 +1,635 @@ +import logging +from datetime import datetime +from functools import reduce + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import col, collect_list, create_map, lit +from pyspark.sql.types import StringType, StructField, StructType +from pyspark.errors import PySparkException +from sqlglot import Dialect + +from databricks.labs.remorph.config import DatabaseConfig, Table, get_key_from_dialect, ReconcileMetadataConfig +from databricks.labs.remorph.reconcile.exception import ( + WriteToTableException, + ReadAndWriteWithVolumeException, + CleanFromVolumeException, +) +from databricks.labs.remorph.reconcile.recon_config import ( + DataReconcileOutput, + ReconcileOutput, + ReconcileProcessDuration, + ReconcileTableOutput, + SchemaReconcileOutput, + StatusOutput, + TableThresholds, + ReconcileRecordCount, + AggregateQueryOutput, +) +from databricks.sdk import WorkspaceClient + +logger = logging.getLogger(__name__) + +_RECON_TABLE_NAME = "main" +_RECON_METRICS_TABLE_NAME = "metrics" +_RECON_DETAILS_TABLE_NAME = "details" +_RECON_AGGREGATE_RULES_TABLE_NAME = "aggregate_rules" +_RECON_AGGREGATE_METRICS_TABLE_NAME = "aggregate_metrics" +_RECON_AGGREGATE_DETAILS_TABLE_NAME = "aggregate_details" +_SAMPLE_ROWS = 50 + + +class ReconIntermediatePersist: + + def __init__(self, spark: SparkSession, path: str): + self.spark = spark + self.path = path + + def _write_unmatched_df_to_volumes( + self, + unmatched_df: DataFrame, + ) -> None: + unmatched_df.write.format("parquet").mode("overwrite").save(self.path) + + def _read_unmatched_df_from_volumes(self) -> DataFrame: + return self.spark.read.format("parquet").load(self.path) + + def clean_unmatched_df_from_volume(self): + try: + # TODO: for now we are overwriting the intermediate cache path. 
We should delete the volume in future + # workspace_client.dbfs.get_status(path) + # workspace_client.dbfs.delete(path, recursive=True) + empty_df = self.spark.createDataFrame([], schema=StructType([StructField("empty", StringType(), True)])) + empty_df.write.format("parquet").mode("overwrite").save(self.path) + logger.warning(f"Unmatched DF cleaned up from {self.path} successfully.") + except PySparkException as e: + message = f"Error cleaning up unmatched DF from {self.path} volumes --> {e}" + logger.error(message) + raise CleanFromVolumeException(message) from e + + def write_and_read_unmatched_df_with_volumes( + self, + unmatched_df: DataFrame, + ) -> DataFrame: + try: + self._write_unmatched_df_to_volumes(unmatched_df) + return self._read_unmatched_df_from_volumes() + except PySparkException as e: + message = f"Exception in reading or writing unmatched DF with volumes {self.path} --> {e}" + logger.error(message) + raise ReadAndWriteWithVolumeException(message) from e + + +def _write_df_to_delta(df: DataFrame, table_name: str, mode="append"): + try: + df.write.mode(mode).saveAsTable(table_name) + logger.info(f"Data written to {table_name} successfully.") + except Exception as e: + message = f"Error writing data to {table_name}: {e}" + logger.error(message) + raise WriteToTableException(message) from e + + +def generate_final_reconcile_output( + recon_id: str, + spark: SparkSession, + metadata_config: ReconcileMetadataConfig = ReconcileMetadataConfig(), + local_test_run: bool = False, +) -> ReconcileOutput: + _db_prefix = "default" if local_test_run else f"{metadata_config.catalog}.{metadata_config.schema}" + recon_df = spark.sql( + f""" + SELECT + CASE + WHEN COALESCE(MAIN.SOURCE_TABLE.CATALOG, '') <> '' THEN CONCAT(MAIN.SOURCE_TABLE.CATALOG, '.', MAIN.SOURCE_TABLE.SCHEMA, '.', MAIN.SOURCE_TABLE.TABLE_NAME) + ELSE CONCAT(MAIN.SOURCE_TABLE.SCHEMA, '.', MAIN.SOURCE_TABLE.TABLE_NAME) + END AS SOURCE_TABLE, + CONCAT(MAIN.TARGET_TABLE.CATALOG, '.', MAIN.TARGET_TABLE.SCHEMA, '.', MAIN.TARGET_TABLE.TABLE_NAME) AS TARGET_TABLE, + CASE WHEN lower(MAIN.report_type) in ('all', 'row', 'data') THEN + CASE + WHEN METRICS.recon_metrics.row_comparison.missing_in_source = 0 AND METRICS.recon_metrics.row_comparison.missing_in_target = 0 THEN TRUE + ELSE FALSE + END + ELSE NULL END AS ROW, + CASE WHEN lower(MAIN.report_type) in ('all', 'data') THEN + CASE + WHEN (METRICS.run_metrics.status = true) or + (METRICS.recon_metrics.column_comparison.absolute_mismatch = 0 AND METRICS.recon_metrics.column_comparison.threshold_mismatch = 0 AND METRICS.recon_metrics.column_comparison.mismatch_columns = '') THEN TRUE + ELSE FALSE + END + ELSE NULL END AS COLUMN, + CASE WHEN lower(MAIN.report_type) in ('all', 'schema') THEN + CASE + WHEN METRICS.recon_metrics.schema_comparison = true THEN TRUE + ELSE FALSE + END + ELSE NULL END AS SCHEMA, + METRICS.run_metrics.exception_message AS EXCEPTION_MESSAGE + FROM + {_db_prefix}.{_RECON_TABLE_NAME} MAIN + INNER JOIN + {_db_prefix}.{_RECON_METRICS_TABLE_NAME} METRICS + ON + (MAIN.recon_table_id = METRICS.recon_table_id) + WHERE + MAIN.recon_id = '{recon_id}' + """ + ) + table_output = [] + for row in recon_df.collect(): + if row.EXCEPTION_MESSAGE is not None and row.EXCEPTION_MESSAGE != "": + table_output.append( + ReconcileTableOutput( + target_table_name=row.TARGET_TABLE, + source_table_name=row.SOURCE_TABLE, + status=StatusOutput(), + exception_message=row.EXCEPTION_MESSAGE, + ) + ) + else: + table_output.append( + ReconcileTableOutput( + target_table_name=row.TARGET_TABLE, 
+ source_table_name=row.SOURCE_TABLE, + status=StatusOutput(row=row.ROW, column=row.COLUMN, schema=row.SCHEMA), + exception_message=row.EXCEPTION_MESSAGE, + ) + ) + final_reconcile_output = ReconcileOutput(recon_id=recon_id, results=table_output) + logger.info(f"Final reconcile output: {final_reconcile_output}") + return final_reconcile_output + + +def generate_final_reconcile_aggregate_output( + recon_id: str, + spark: SparkSession, + metadata_config: ReconcileMetadataConfig = ReconcileMetadataConfig(), + local_test_run: bool = False, +) -> ReconcileOutput: + _db_prefix = "default" if local_test_run else f"{metadata_config.catalog}.{metadata_config.schema}" + recon_df = spark.sql( + f""" + SELECT source_table, + target_table, + EVERY(status) AS status, + ARRAY_JOIN(COLLECT_SET(exception_message), '\n') AS exception_message + FROM + (SELECT + IF(ISNULL(main.source_table.catalog) + , CONCAT_WS('.', main.source_table.schema, main.source_table.table_name) + , CONCAT_WS('.', main.source_table.catalog, main.source_table.schema, main.source_table.table_name)) AS source_table, + CONCAT_WS('.', main.target_table.catalog, main.target_table.schema, main.target_table.table_name) AS target_table, + IF(metrics.run_metrics.status='true', TRUE , FALSE) AS status, + metrics.run_metrics.exception_message AS exception_message + FROM + {_db_prefix}.{_RECON_TABLE_NAME} main + INNER JOIN + {_db_prefix}.{_RECON_AGGREGATE_METRICS_TABLE_NAME} metrics + ON + (MAIN.recon_table_id = METRICS.recon_table_id + AND MAIN.operation_name = 'aggregates-reconcile') + WHERE + MAIN.recon_id = '{recon_id}' + ) + GROUP BY source_table, target_table; + """ + ) + table_output = [] + for row in recon_df.collect(): + if row.exception_message is not None and row.exception_message != "": + table_output.append( + ReconcileTableOutput( + target_table_name=row.target_table, + source_table_name=row.source_table, + status=StatusOutput(), + exception_message=row.exception_message, + ) + ) + else: + table_output.append( + ReconcileTableOutput( + target_table_name=row.target_table, + source_table_name=row.source_table, + status=StatusOutput(aggregate=row.status), + exception_message=row.exception_message, + ) + ) + final_reconcile_output = ReconcileOutput(recon_id=recon_id, results=table_output) + logger.info(f"Final reconcile output: {final_reconcile_output}") + return final_reconcile_output + + +class ReconCapture: + + def __init__( + self, + database_config: DatabaseConfig, + recon_id: str, + report_type: str, + source_dialect: Dialect, + ws: WorkspaceClient, + spark: SparkSession, + metadata_config: ReconcileMetadataConfig = ReconcileMetadataConfig(), + local_test_run: bool = False, + ): + self.database_config = database_config + self.recon_id = recon_id + self.report_type = report_type + self.source_dialect = source_dialect + self.ws = ws + self.spark = spark + self._db_prefix = "default" if local_test_run else f"{metadata_config.catalog}.{metadata_config.schema}" + + def _generate_recon_main_id( + self, + table_conf: Table, + ) -> int: + full_source_table = ( + f"{self.database_config.source_schema}.{table_conf.source_name}" + if self.database_config.source_catalog is None + else f"{self.database_config.source_catalog}.{self.database_config.source_schema}.{table_conf.source_name}" + ) + full_target_table = ( + f"{self.database_config.target_catalog}.{self.database_config.target_schema}.{table_conf.target_name}" + ) + return hash(f"{self.recon_id}{full_source_table}{full_target_table}") + + def _insert_into_main_table( + self, + 
recon_table_id: int, + table_conf: Table, + recon_process_duration: ReconcileProcessDuration, + operation_name: str = "reconcile", + ) -> None: + source_dialect_key = get_key_from_dialect(self.source_dialect) + df = self.spark.sql( + f""" + select {recon_table_id} as recon_table_id, + '{self.recon_id}' as recon_id, + case + when '{source_dialect_key}' = 'databricks' then 'Databricks' + when '{source_dialect_key}' = 'snowflake' then 'Snowflake' + when '{source_dialect_key}' = 'oracle' then 'Oracle' + else '{source_dialect_key}' + end as source_type, + named_struct( + 'catalog', case when '{self.database_config.source_catalog}' = 'None' then null else '{self.database_config.source_catalog}' end, + 'schema', '{self.database_config.source_schema}', + 'table_name', '{table_conf.source_name}' + ) as source_table, + named_struct( + 'catalog', '{self.database_config.target_catalog}', + 'schema', '{self.database_config.target_schema}', + 'table_name', '{table_conf.target_name}' + ) as target_table, + '{self.report_type}' as report_type, + '{operation_name}' as operation_name, + cast('{recon_process_duration.start_ts}' as timestamp) as start_ts, + cast('{recon_process_duration.end_ts}' as timestamp) as end_ts + """ + ) + _write_df_to_delta(df, f"{self._db_prefix}.{_RECON_TABLE_NAME}") + + @classmethod + def _is_mismatch_within_threshold_limits( + cls, data_reconcile_output: DataReconcileOutput, table_conf: Table, record_count: ReconcileRecordCount + ): + total_mismatch_count = ( + data_reconcile_output.mismatch_count + data_reconcile_output.threshold_output.threshold_mismatch_count + ) + logger.info(f"total_mismatch_count : {total_mismatch_count}") + logger.warning(f"reconciled_record_count : {record_count}") + # if the mismatch count is 0 then no need of checking bounds. 
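+ # Worked example (assumed numbers) for the percentage mode handled below: with a "mismatch" table threshold of "0%".."5%" and 1,000 source records, up to 50 mismatched rows (absolute and threshold mismatches combined) are still treated as within bounds.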
+ if total_mismatch_count == 0: + return True + # pull out table thresholds + thresholds: list[TableThresholds] = ( + [threshold for threshold in table_conf.table_thresholds if threshold.model == "mismatch"] + if table_conf.table_thresholds + else [] + ) + # if not table thresholds are provided return false + if not thresholds: + return False + + res = None + for threshold in thresholds: + mode = threshold.get_mode() + lower_bound = int(threshold.lower_bound.replace("%", "")) + upper_bound = int(threshold.upper_bound.replace("%", "")) + if mode == "absolute": + res = lower_bound <= total_mismatch_count <= upper_bound + if mode == "percentage": + lower_bound = int(round((lower_bound / 100) * record_count.source)) + upper_bound = int(round((upper_bound / 100) * record_count.source)) + res = lower_bound <= total_mismatch_count <= upper_bound + + return res + + def _insert_into_metrics_table( + self, + recon_table_id: int, + data_reconcile_output: DataReconcileOutput, + schema_reconcile_output: SchemaReconcileOutput, + table_conf: Table, + record_count: ReconcileRecordCount, + ) -> None: + status = False + if data_reconcile_output.exception in {None, ''} and schema_reconcile_output.exception in {None, ''}: + status = ( + # validate for both exact mismatch and threshold mismatch + self._is_mismatch_within_threshold_limits( + data_reconcile_output=data_reconcile_output, table_conf=table_conf, record_count=record_count + ) + and data_reconcile_output.missing_in_src_count == 0 + and data_reconcile_output.missing_in_tgt_count == 0 + and schema_reconcile_output.is_valid + ) + + exception_msg = "" + if schema_reconcile_output.exception is not None: + exception_msg = schema_reconcile_output.exception.replace("'", '').replace('"', '') + if data_reconcile_output.exception is not None: + exception_msg = data_reconcile_output.exception.replace("'", '').replace('"', '') + + insertion_time = str(datetime.now()) + mismatch_columns = [] + if data_reconcile_output.mismatch and data_reconcile_output.mismatch.mismatch_columns: + mismatch_columns = data_reconcile_output.mismatch.mismatch_columns + + df = self.spark.sql( + f""" + select {recon_table_id} as recon_table_id, + named_struct( + 'row_comparison', case when '{self.report_type.lower()}' in ('all', 'row', 'data') + and '{exception_msg}' = '' then + named_struct( + 'missing_in_source', cast({data_reconcile_output.missing_in_src_count} as bigint), + 'missing_in_target', cast({data_reconcile_output.missing_in_tgt_count} as bigint) + ) else null end, + 'column_comparison', case when '{self.report_type.lower()}' in ('all', 'data') + and '{exception_msg}' = '' then + named_struct( + 'absolute_mismatch', cast({data_reconcile_output.mismatch_count} as bigint), + 'threshold_mismatch', cast({data_reconcile_output.threshold_output.threshold_mismatch_count} as bigint), + 'mismatch_columns', '{",".join(mismatch_columns)}' + ) else null end, + 'schema_comparison', case when '{self.report_type.lower()}' in ('all', 'schema') + and '{exception_msg}' = '' then + {schema_reconcile_output.is_valid} else null end + ) as recon_metrics, + named_struct( + 'status', {status}, + 'run_by_user', '{self.ws.current_user.me().user_name}', + 'exception_message', "{exception_msg}" + ) as run_metrics, + cast('{insertion_time}' as timestamp) as inserted_ts + """ + ) + _write_df_to_delta(df, f"{self._db_prefix}.{_RECON_METRICS_TABLE_NAME}") + + @classmethod + def _create_map_column( + cls, + recon_table_id: int, + df: DataFrame, + recon_type: str, + status: bool, + ) -> DataFrame: + columns 
= df.columns + # Create a list of column names and their corresponding column values + map_args = [] + for column in columns: + map_args.extend([lit(column).alias(column + "_key"), col(column).cast("string").alias(column + "_value")]) + # Create a new DataFrame with a map column + df = df.limit(_SAMPLE_ROWS).select(create_map(*map_args).alias("data")) + df = ( + df.withColumn("recon_table_id", lit(recon_table_id)) + .withColumn("recon_type", lit(recon_type)) + .withColumn("status", lit(status)) + .withColumn("inserted_ts", lit(datetime.now())) + ) + return ( + df.groupBy("recon_table_id", "recon_type", "status", "inserted_ts") + .agg(collect_list("data").alias("data")) + .selectExpr("recon_table_id", "recon_type", "status", "data", "inserted_ts") + ) + + def _create_map_column_and_insert( + self, + recon_table_id: int, + df: DataFrame, + recon_type: str, + status: bool, + ) -> None: + df = self._create_map_column(recon_table_id, df, recon_type, status) + _write_df_to_delta(df, f"{self._db_prefix}.{_RECON_DETAILS_TABLE_NAME}") + + def _insert_into_details_table( + self, + recon_table_id: int, + reconcile_output: DataReconcileOutput, + schema_output: SchemaReconcileOutput, + ): + if reconcile_output.mismatch_count > 0 and reconcile_output.mismatch.mismatch_df: + self._create_map_column_and_insert( + recon_table_id, + reconcile_output.mismatch.mismatch_df, + "mismatch", + False, + ) + + if reconcile_output.missing_in_src_count > 0 and reconcile_output.missing_in_src: + self._create_map_column_and_insert( + recon_table_id, + reconcile_output.missing_in_src, + "missing_in_source", + False, + ) + + if reconcile_output.missing_in_tgt_count > 0 and reconcile_output.missing_in_tgt: + self._create_map_column_and_insert( + recon_table_id, + reconcile_output.missing_in_tgt, + "missing_in_target", + False, + ) + + if ( + reconcile_output.threshold_output.threshold_mismatch_count > 0 + and reconcile_output.threshold_output.threshold_df + ): + self._create_map_column_and_insert( + recon_table_id, + reconcile_output.threshold_output.threshold_df, + "threshold_mismatch", + False, + ) + + if schema_output.compare_df is not None: + self._create_map_column_and_insert( + recon_table_id, schema_output.compare_df, "schema", schema_output.is_valid + ) + + def _get_df( + self, + recon_table_id: int, + agg_data: DataReconcileOutput, + recon_type: str, + ): + + column_count = agg_data.mismatch_count + agg_df = agg_data.mismatch.mismatch_df + match recon_type: + case "missing_in_source": + column_count = agg_data.missing_in_src_count + agg_df = agg_data.missing_in_src + case "missing_in_target": + column_count = agg_data.missing_in_tgt_count + agg_df = agg_data.missing_in_tgt + + if column_count > 0 and agg_df: + return self._create_map_column( + recon_table_id, + agg_df, + recon_type, + False, + ) + return None + + @classmethod + def _union_dataframes(cls, df_list: list[DataFrame]) -> DataFrame: + return reduce(lambda agg_df, df: agg_df.unionByName(df), df_list) + + def _insert_aggregates_into_metrics_table( + self, + recon_table_id: int, + reconcile_agg_output_list: list[AggregateQueryOutput], + ) -> None: + + agg_metrics_df_list = [] + for agg_output in reconcile_agg_output_list: + agg_data = agg_output.reconcile_output + + status = False + if agg_data.exception in {None, ''}: + status = not ( + agg_data.mismatch_count > 0 + or agg_data.missing_in_src_count > 0 + or agg_data.missing_in_tgt_count > 0 + ) + + exception_msg = "" + if agg_data.exception is not None: + exception_msg = agg_data.exception.replace("'", 
'').replace('"', '') + + insertion_time = str(datetime.now()) + + # If there is any exception while running the Query, + # each rule is stored, with the Exception message in the metrics table + assert agg_output.rule, "Aggregate Rule must be present for storing the metrics" + rule_id = hash(f"{recon_table_id}_{agg_output.rule.column_from_rule}") + + agg_metrics_df = self.spark.sql( + f""" + select {recon_table_id} as recon_table_id, + {rule_id} as rule_id, + if('{exception_msg}' = '', named_struct( + 'missing_in_source', {agg_data.missing_in_src_count}, + 'missing_in_target', {agg_data.missing_in_tgt_count}, + 'mismatch', {agg_data.mismatch_count} + ), null) as recon_metrics, + named_struct( + 'status', {status}, + 'run_by_user', '{self.ws.current_user.me().user_name}', + 'exception_message', "{exception_msg}" + ) as run_metrics, + cast('{insertion_time}' as timestamp) as inserted_ts + """ + ) + agg_metrics_df_list.append(agg_metrics_df) + + agg_metrics_table_df = self._union_dataframes(agg_metrics_df_list) + _write_df_to_delta(agg_metrics_table_df, f"{self._db_prefix}.{_RECON_AGGREGATE_METRICS_TABLE_NAME}") + + def _insert_aggregates_into_details_table( + self, recon_table_id: int, reconcile_agg_output_list: list[AggregateQueryOutput] + ): + agg_details_df_list = [] + for agg_output in reconcile_agg_output_list: + agg_details_rule_df_list = [] + + mismatch_df = self._get_df(recon_table_id, agg_output.reconcile_output, "mismatch") + if mismatch_df and not mismatch_df.isEmpty(): + agg_details_rule_df_list.append(mismatch_df) + + missing_src_df = self._get_df(recon_table_id, agg_output.reconcile_output, "missing_in_source") + if missing_src_df and not missing_src_df.isEmpty(): + agg_details_rule_df_list.append(missing_src_df) + + missing_tgt_df = self._get_df(recon_table_id, agg_output.reconcile_output, "missing_in_target") + if missing_tgt_df and not missing_tgt_df.isEmpty(): + agg_details_rule_df_list.append(missing_tgt_df) + + if agg_details_rule_df_list: + agg_details_rule_df = self._union_dataframes(agg_details_rule_df_list) + if agg_output.rule: + rule_id = hash(f"{recon_table_id}_{agg_output.rule.column_from_rule}") + agg_details_rule_df = agg_details_rule_df.withColumn("rule_id", lit(rule_id)).select( + "recon_table_id", "rule_id", "recon_type", "data", "inserted_ts" + ) + agg_details_df_list.append(agg_details_rule_df) + else: + logger.warning("Aggregate Details Rules are empty") + + if agg_details_df_list: + agg_details_table_df = self._union_dataframes(agg_details_df_list) + _write_df_to_delta(agg_details_table_df, f"{self._db_prefix}.{_RECON_AGGREGATE_DETAILS_TABLE_NAME}") + + def start( + self, + data_reconcile_output: DataReconcileOutput, + schema_reconcile_output: SchemaReconcileOutput, + table_conf: Table, + recon_process_duration: ReconcileProcessDuration, + record_count: ReconcileRecordCount, + ) -> None: + recon_table_id = self._generate_recon_main_id(table_conf) + self._insert_into_main_table(recon_table_id, table_conf, recon_process_duration) + self._insert_into_metrics_table( + recon_table_id, data_reconcile_output, schema_reconcile_output, table_conf, record_count + ) + self._insert_into_details_table(recon_table_id, data_reconcile_output, schema_reconcile_output) + + def store_aggregates_metrics( + self, + table_conf: Table, + recon_process_duration: ReconcileProcessDuration, + reconcile_agg_output_list: list[AggregateQueryOutput], + ) -> None: + recon_table_id = self._generate_recon_main_id(table_conf) + self._insert_into_main_table(recon_table_id, table_conf, 
recon_process_duration, 'aggregates-reconcile') + self._insert_into_rules_table(recon_table_id, reconcile_agg_output_list) + self._insert_aggregates_into_metrics_table(recon_table_id, reconcile_agg_output_list) + self._insert_aggregates_into_details_table( + recon_table_id, + reconcile_agg_output_list, + ) + + def _insert_into_rules_table(self, recon_table_id: int, reconcile_agg_output_list: list[AggregateQueryOutput]): + + rule_df_list = [] + for agg_output in reconcile_agg_output_list: + if not agg_output.rule: + logger.error("Aggregate Rule must be present for storing the rules") + continue + rule_id = hash(f"{recon_table_id}_{agg_output.rule.column_from_rule}") + rule_query = agg_output.rule.get_rule_query(rule_id) + rule_df_list.append( + self.spark.sql(rule_query) + .withColumn("inserted_ts", lit(datetime.now())) + .select("rule_id", "rule_type", "rule_info", "inserted_ts") + ) + + if rule_df_list: + rules_table_df = self._union_dataframes(rule_df_list) + _write_df_to_delta(rules_table_df, f"{self._db_prefix}.{_RECON_AGGREGATE_RULES_TABLE_NAME}") diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py new file mode 100644 index 0000000000..be76044f29 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/recon_config.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable +from dataclasses import dataclass, field + +from pyspark.sql import DataFrame +from sqlglot import expressions as exp + +logger = logging.getLogger(__name__) + +_SUPPORTED_AGG_TYPES: set[str] = { + "min", + "max", + "count", + "sum", + "avg", + "mean", + "mode", + "stddev", + "variance", + "median", +} + + +class TableThresholdBoundsException(ValueError): + """Raise the error when the bounds for table threshold are invalid""" + + +class InvalidModelForTableThreshold(ValueError): + """Raise the error when the model for table threshold is invalid""" + + +@dataclass +class JdbcReaderOptions: + number_partitions: int + partition_column: str + lower_bound: str + upper_bound: str + fetch_size: int = 100 + + def __post_init__(self): + self.partition_column = self.partition_column.lower() + + +@dataclass +class ColumnMapping: + source_name: str + target_name: str + + def __post_init__(self): + self.source_name = self.source_name.lower() + self.target_name = self.target_name.lower() + + +@dataclass +class Transformation: + column_name: str + source: str | None = None + target: str | None = None + + def __post_init__(self): + self.column_name = self.column_name.lower() + + +@dataclass +class ColumnThresholds: + column_name: str + lower_bound: str + upper_bound: str + type: str + + def __post_init__(self): + self.column_name = self.column_name.lower() + self.type = self.type.lower() + + def get_mode(self): + return "percentage" if "%" in self.lower_bound or "%" in self.upper_bound else "absolute" + + def get_type(self): + if any(self.type in numeric_type.value.lower() for numeric_type in exp.DataType.NUMERIC_TYPES): + if self.get_mode() == "absolute": + return "number_absolute" + return "number_percentage" + + if any(self.type in numeric_type.value.lower() for numeric_type in exp.DataType.TEMPORAL_TYPES): + return "datetime" + return None + + +@dataclass +class TableThresholds: + lower_bound: str + upper_bound: str + model: str + + def __post_init__(self): + self.model = self.model.lower() + self.validate_threshold_bounds() + self.validate_threshold_model() + + def get_mode(self): + return 
"percentage" if "%" in self.lower_bound or "%" in self.upper_bound else "absolute" + + def validate_threshold_bounds(self): + lower_bound = int(self.lower_bound.replace("%", "")) + upper_bound = int(self.upper_bound.replace("%", "")) + if lower_bound < 0 or upper_bound < 0: + raise TableThresholdBoundsException("Threshold bounds for table cannot be negative.") + if lower_bound > upper_bound: + raise TableThresholdBoundsException("Lower bound cannot be greater than upper bound.") + + def validate_threshold_model(self): + if self.model not in ["mismatch"]: + raise InvalidModelForTableThreshold( + f"Invalid model for Table Threshold: expected 'mismatch', but got '{self.model}'." + ) + + +@dataclass +class Filters: + source: str | None = None + target: str | None = None + + +def to_lower_case(input_list: list[str]) -> list[str]: + return [element.lower() for element in input_list] + + +@dataclass +class Table: + source_name: str + target_name: str + aggregates: list[Aggregate] | None = None + join_columns: list[str] | None = None + jdbc_reader_options: JdbcReaderOptions | None = None + select_columns: list[str] | None = None + drop_columns: list[str] | None = None + column_mapping: list[ColumnMapping] | None = None + transformations: list[Transformation] | None = None + column_thresholds: list[ColumnThresholds] | None = None + filters: Filters | None = None + table_thresholds: list[TableThresholds] | None = None + + def __post_init__(self): + self.source_name = self.source_name.lower() + self.target_name = self.target_name.lower() + self.select_columns = to_lower_case(self.select_columns) if self.select_columns else None + self.drop_columns = to_lower_case(self.drop_columns) if self.drop_columns else None + self.join_columns = to_lower_case(self.join_columns) if self.join_columns else None + + @property + def to_src_col_map(self): + if self.column_mapping: + return {c.source_name: c.target_name for c in self.column_mapping} + return None + + @property + def to_tgt_col_map(self): + if self.column_mapping: + return {c.target_name: c.source_name for c in self.column_mapping} + return None + + def get_src_to_tgt_col_mapping_list(self, cols: list[str], layer: str) -> set[str]: + if layer == "source": + return set(cols) + if self.to_src_col_map: + return {self.to_src_col_map.get(col, col) for col in cols} + return set(cols) + + def get_layer_src_to_tgt_col_mapping(self, column_name: str, layer: str) -> str: + if layer == "source": + return column_name + if self.to_src_col_map: + return self.to_src_col_map.get(column_name, column_name) + return column_name + + def get_tgt_to_src_col_mapping_list(self, cols: list[str] | set[str]) -> set[str]: + if self.to_tgt_col_map: + return {self.to_tgt_col_map.get(col, col) for col in cols} + return set(cols) + + def get_layer_tgt_to_src_col_mapping(self, column_name: str, layer: str) -> str: + if layer == "source": + return column_name + if self.to_tgt_col_map: + return self.to_tgt_col_map.get(column_name, column_name) + return column_name + + def get_select_columns(self, schema: list[Schema], layer: str) -> set[str]: + if self.select_columns is None: + return {sch.column_name for sch in schema} + if self.to_src_col_map: + return self.get_src_to_tgt_col_mapping_list(self.select_columns, layer) + return set(self.select_columns) + + def get_threshold_columns(self, layer: str) -> set[str]: + if self.column_thresholds is None: + return set() + return {self.get_layer_src_to_tgt_col_mapping(thresh.column_name, layer) for thresh in self.column_thresholds} + + def 
get_join_columns(self, layer: str) -> set[str] | None: + if self.join_columns is None: + return None + return {self.get_layer_src_to_tgt_col_mapping(col, layer) for col in self.join_columns} + + def get_drop_columns(self, layer: str) -> set[str]: + if self.drop_columns is None: + return set() + return {self.get_layer_src_to_tgt_col_mapping(col, layer) for col in self.drop_columns} + + def get_transformation_dict(self, layer: str) -> dict[str, str]: + if self.transformations: + if layer == "source": + return { + trans.column_name: (trans.source if trans.source else trans.column_name) + for trans in self.transformations + } + return { + self.get_layer_src_to_tgt_col_mapping(trans.column_name, layer): ( + trans.target if trans.target else self.get_layer_src_to_tgt_col_mapping(trans.column_name, layer) + ) + for trans in self.transformations + } + return {} + + def get_partition_column(self, layer: str) -> set[str]: + if self.jdbc_reader_options and layer == "source": + return {self.jdbc_reader_options.partition_column} + return set() + + def get_filter(self, layer: str) -> str | None: + if self.filters is None: + return None + if layer == "source": + return self.filters.source + return self.filters.target + + +@dataclass +class Schema: + column_name: str + data_type: str + + +@dataclass +class MismatchOutput: + mismatch_df: DataFrame | None = None + mismatch_columns: list[str] | None = None + + +@dataclass +class ThresholdOutput: + threshold_df: DataFrame | None = None + threshold_mismatch_count: int = 0 + + +@dataclass +class DataReconcileOutput: + mismatch_count: int = 0 + missing_in_src_count: int = 0 + missing_in_tgt_count: int = 0 + mismatch: MismatchOutput = field(default_factory=MismatchOutput) + missing_in_src: DataFrame | None = None + missing_in_tgt: DataFrame | None = None + threshold_output: ThresholdOutput = field(default_factory=ThresholdOutput) + exception: str | None = None + + +@dataclass +class HashAlgoMapping: + source: Callable + target: Callable + + +@dataclass +class SchemaMatchResult: + source_column: str + source_datatype: str + databricks_column: str + databricks_datatype: str + is_valid: bool = True + + +@dataclass +class SchemaReconcileOutput: + is_valid: bool + compare_df: DataFrame | None = None + exception: str | None = None + + +@dataclass +class ReconcileProcessDuration: + start_ts: str + end_ts: str | None + + +@dataclass +class StatusOutput: + row: bool | None = None + column: bool | None = None + schema: bool | None = None + aggregate: bool | None = None + + +@dataclass +class ReconcileTableOutput: + target_table_name: str + source_table_name: str + status: StatusOutput = field(default_factory=StatusOutput) + exception_message: str | None = None + + +@dataclass +class ReconcileOutput: + recon_id: str + results: list[ReconcileTableOutput] + + +@dataclass +class ReconcileRecordCount: + source: int = 0 + target: int = 0 + + +@dataclass +class Aggregate: + agg_columns: list[str] + type: str + group_by_columns: list[str] | None = None + + def __post_init__(self): + self.agg_columns = to_lower_case(self.agg_columns) + self.type = self.type.lower() + self.group_by_columns = to_lower_case(self.group_by_columns) if self.group_by_columns else None + assert ( + self.type in _SUPPORTED_AGG_TYPES + ), f"Invalid aggregate type: {self.type}, only {_SUPPORTED_AGG_TYPES} are supported." 
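+ # Illustrative (hypothetical) values: Aggregate(agg_columns=["SaleAmount"], type="SUM", group_by_columns=["Region"]) is normalised here to type "sum", agg_columns ["saleamount"] and group_by_columns ["region"]; an unsupported type such as "percentile" fails the assertion above.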
+ + def get_agg_type(self): + return self.type + + @classmethod + def _join_columns(cls, columns: list[str]): + return "+__+".join(columns) + + @property + def group_by_columns_as_str(self): + return self._join_columns(self.group_by_columns) if self.group_by_columns else "NA" + + @property + def agg_columns_as_str(self): + return self._join_columns(self.agg_columns) + + +@dataclass +class AggregateRule: + agg_type: str + agg_column: str + group_by_columns: list[str] | None + group_by_columns_as_str: str + rule_type: str = "AGGREGATE" + + @property + def column_from_rule(self): + # creates rule_column. e.g., min_col1_grp1_grp2 + return f"{self.agg_type}_{self.agg_column}_{self.group_by_columns_as_str}" + + @property + def group_by_columns_as_table_column(self): + # If group_by_columns are not defined, store is as null + group_by_cols_as_table_col = "NULL" + if self.group_by_columns: + # Sort the columns, convert to lower case and create a string: , e.g., grp1, grp2 + formatted_cols = ", ".join([f"{col.lower()}" for col in sorted(self.group_by_columns)]) + group_by_cols_as_table_col = f"\"{formatted_cols}\"" + return group_by_cols_as_table_col + + def get_rule_query(self, rule_id): + rule_info = f""" map( 'agg_type', '{self.agg_type}', + 'agg_column', '{self.agg_column}', + 'group_by_columns', {self.group_by_columns_as_table_column} + ) + """ + return f" SELECT {rule_id} as rule_id, " f" '{self.rule_type}' as rule_type, " f" {rule_info} as rule_info " + + +@dataclass +class AggregateQueryRules: + layer: str + group_by_columns: list[str] | None + group_by_columns_as_str: str + query: str + rules: list[AggregateRule] + + +@dataclass +class AggregateQueryOutput: + rule: AggregateRule | None + reconcile_output: DataReconcileOutput diff --git a/src/databricks/labs/remorph/reconcile/runner.py b/src/databricks/labs/remorph/reconcile/runner.py new file mode 100644 index 0000000000..ae39d814ec --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/runner.py @@ -0,0 +1,97 @@ +import logging +import webbrowser + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installation import SerdeError +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.tui import Prompts +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, PermissionDenied + +from databricks.labs.remorph.config import ReconcileConfig, TableRecon +from databricks.labs.remorph.deployment.recon import RECON_JOB_NAME +from databricks.labs.remorph.reconcile.execute import RECONCILE_OPERATION_NAME + +logger = logging.getLogger(__name__) + +_RECON_README_URL = "https://github.com/databrickslabs/remorph/blob/main/docs/recon_configurations/README.md" + + +class ReconcileRunner: + def __init__( + self, + ws: WorkspaceClient, + installation: Installation, + install_state: InstallState, + prompts: Prompts, + ): + self._ws = ws + self._installation = installation + self._install_state = install_state + self._prompts = prompts + + def run(self, operation_name=RECONCILE_OPERATION_NAME): + reconcile_config = self._get_verified_recon_config() + job_id = self._get_recon_job_id(reconcile_config) + logger.info(f"Triggering the reconcile job with job_id: `{job_id}`") + wait = self._ws.jobs.run_now(job_id, job_parameters={"operation_name": operation_name}) + if not wait.run_id: + raise SystemExit(f"Job {job_id} execution failed. 
Please check the job logs for more details.") + + job_run_url = f"{self._ws.config.host}/jobs/{job_id}/runs/{wait.run_id}" + logger.info( + f"'{operation_name.upper()}' job started. Please check the job_url `{job_run_url}` for the current status." + ) + if self._prompts.confirm(f"Would you like to open the job run URL `{job_run_url}` in the browser?"): + webbrowser.open(job_run_url) + + def _get_verified_recon_config(self) -> ReconcileConfig: + try: + recon_config = self._installation.load(ReconcileConfig) + except NotFound as err: + raise SystemExit("Cannot find existing `reconcile` installation. Please try reinstalling.") from err + except (PermissionDenied, SerdeError, ValueError, AttributeError) as e: + install_dir = self._installation.install_folder() + raise SystemExit( + f"Existing `reconcile` installation at {install_dir} is corrupted. Please try reinstalling." + ) from e + + self._verify_recon_table_config(recon_config) + return recon_config + + def _verify_recon_table_config(self, recon_config): + source_catalog_or_schema = ( + recon_config.database_config.source_catalog + if recon_config.database_config.source_catalog + else recon_config.database_config.source_schema + ) + # Filename pattern for recon table config `recon_config___.json` + # Example: recon_config_snowflake_sample_data_all.json + filename = f"recon_config_{recon_config.data_source}_{source_catalog_or_schema}_{recon_config.report_type}.json" + try: + logger.debug(f"Loading recon table config `{filename}` from workspace.") + self._installation.load(TableRecon, filename=filename) + except NotFound as e: + err_msg = ( + "Cannot find recon table configuration in existing `reconcile` installation. " + f"Please provide the configuration file {filename} in the workspace." + ) + logger.error(f"{err_msg}. For more details, please refer to {_RECON_README_URL}") + raise SystemExit(err_msg) from e + except (PermissionDenied, SerdeError, ValueError, AttributeError) as e: + install_dir = self._installation.install_folder() + err_msg = ( + f"Cannot load corrupted recon table configuration from {install_dir}/{filename}. " + f"Please validate the file." + ) + logger.error(f"{err_msg}. For more details, please refer to {_RECON_README_URL}") + raise SystemExit(err_msg) from e + + def _get_recon_job_id(self, reconcile_config: ReconcileConfig) -> int: + if reconcile_config.job_id: + logger.debug("Reconcile job id found in the reconcile config.") + return int(reconcile_config.job_id) + if RECON_JOB_NAME in self._install_state.jobs: + logger.debug("Reconcile job id found in the install state.") + return int(self._install_state.jobs[RECON_JOB_NAME]) + raise SystemExit("Reconcile Job ID not found. 
Please try reinstalling.") diff --git a/src/databricks/labs/remorph/reconcile/schema_compare.py b/src/databricks/labs/remorph/reconcile/schema_compare.py new file mode 100644 index 0000000000..903b23c177 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/schema_compare.py @@ -0,0 +1,130 @@ +import logging +from dataclasses import asdict + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import BooleanType, StringType, StructField, StructType +from sqlglot import Dialect, parse_one + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.recon_config import ( + Schema, + SchemaMatchResult, + SchemaReconcileOutput, + Table, +) +from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks + +logger = logging.getLogger(__name__) + + +class SchemaCompare: + def __init__( + self, + spark: SparkSession, + ): + self.spark = spark + + # Define the schema for the schema compare DataFrame + _schema_compare_schema: StructType = StructType( + [ + StructField("source_column", StringType(), False), + StructField("source_datatype", StringType(), False), + StructField("databricks_column", StringType(), True), + StructField("databricks_datatype", StringType(), True), + StructField("is_valid", BooleanType(), False), + ] + ) + + @classmethod + def _build_master_schema( + cls, + source_schema: list[Schema], + databricks_schema: list[Schema], + table_conf: Table, + ) -> list[SchemaMatchResult]: + master_schema = source_schema + if table_conf.select_columns: + master_schema = [schema for schema in master_schema if schema.column_name in table_conf.select_columns] + if table_conf.drop_columns: + master_schema = [sschema for sschema in master_schema if sschema.column_name not in table_conf.drop_columns] + + target_column_map = table_conf.to_src_col_map or {} + master_schema_match_res = [ + SchemaMatchResult( + source_column=s.column_name, + databricks_column=target_column_map.get(s.column_name, s.column_name), + source_datatype=s.data_type, + databricks_datatype=next( + ( + tgt.data_type + for tgt in databricks_schema + if tgt.column_name == target_column_map.get(s.column_name, s.column_name) + ), + "", + ), + ) + for s in master_schema + ] + return master_schema_match_res + + def _create_dataframe(self, data: list, schema: StructType) -> DataFrame: + """ + :param data: Expectation is list of dataclass + :param schema: Target schema + :return: DataFrame + """ + data = [tuple(asdict(item).values()) for item in data] + df = self.spark.createDataFrame(data, schema) + + return df + + @classmethod + def _parse(cls, source: Dialect, column: str, data_type: str) -> str: + return ( + parse_one(f"create table dummy ({column} {data_type})", read=source) + .sql(dialect=get_dialect("databricks")) + .replace(", ", ",") + ) + + @classmethod + def _table_schema_status(cls, schema_compare_maps: list[SchemaMatchResult]) -> bool: + return bool(all(x.is_valid for x in schema_compare_maps)) + + @classmethod + def _validate_parsed_query(cls, master: SchemaMatchResult, parsed_query) -> None: + databricks_query = f"create table dummy ({master.source_column} {master.databricks_datatype})" + logger.info( + f""" + Source datatype: create table dummy ({master.source_column} {master.source_datatype}) + Parse datatype: {parsed_query} + Databricks datatype: {databricks_query} + """ + ) + if parsed_query.lower() != databricks_query.lower(): + master.is_valid = False + + def compare( + self, + source_schema: list[Schema], + databricks_schema: 
list[Schema], + source: Dialect, + table_conf: Table, + ) -> SchemaReconcileOutput: + """ + This method compares the source schema and the Databricks schema. It checks if the data types of the columns in the source schema + match with the corresponding columns in the Databricks schema by parsing using remorph transpile. + + Returns: + SchemaReconcileOutput: A dataclass object containing a boolean indicating the overall result of the comparison and a DataFrame with the comparison details. + """ + master_schema = self._build_master_schema(source_schema, databricks_schema, table_conf) + for master in master_schema: + if not isinstance(source, Databricks): + parsed_query = self._parse(source, master.source_column, master.source_datatype) + self._validate_parsed_query(master, parsed_query) + elif master.source_datatype.lower() != master.databricks_datatype.lower(): + master.is_valid = False + + df = self._create_dataframe(master_schema, self._schema_compare_schema) + final_result = self._table_schema_status(master_schema) + return SchemaReconcileOutput(final_result, df) diff --git a/src/databricks/labs/remorph/resources/__init__.py b/src/databricks/labs/remorph/resources/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/resources/reconcile/__init__.py b/src/databricks/labs/remorph/resources/reconcile/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/__init__.py b/src/databricks/labs/remorph/resources/reconcile/dashboards/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/00_0_aggregate_recon_header.md b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/00_0_aggregate_recon_header.md new file mode 100644 index 0000000000..043ef23543 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/00_0_aggregate_recon_header.md @@ -0,0 +1,6 @@ +# Aggregates Reconcile Table Metrics +### It provides the following information: + +* Mismatch +* Missing in Source +* Missing in Target diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_0_recon_id.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_0_recon_id.filter.yml new file mode 100644 index 0000000000..e8e335c044 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_0_recon_id.filter.yml @@ -0,0 +1,6 @@ +columns: +- recon_id +- dd_recon_id +type: MULTI_SELECT +title: Recon Id +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_1_executed_by.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_1_executed_by.filter.yml new file mode 100644 index 0000000000..0dc7ef29d0 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_1_executed_by.filter.yml @@ -0,0 +1,5 @@ +columns: +- executed_by +type: MULTI_SELECT +title: Executed by +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_2_started_at.filter.yml 
b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_2_started_at.filter.yml new file mode 100644 index 0000000000..574f94d8fc --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/01_2_started_at.filter.yml @@ -0,0 +1,5 @@ +columns: +- start_ts +title: Started At +type: DATE_RANGE_PICKER +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_0_source_type.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_0_source_type.filter.yml new file mode 100644 index 0000000000..00d8651245 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_0_source_type.filter.yml @@ -0,0 +1,5 @@ +columns: +- source_type +type: MULTI_SELECT +title: Source Type +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_1_source_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_1_source_table.filter.yml new file mode 100644 index 0000000000..b45faaf596 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_1_source_table.filter.yml @@ -0,0 +1,5 @@ +columns: +- source_table +type: MULTI_SELECT +title: Source Table Name +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_2_target_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_2_target_table.filter.yml new file mode 100644 index 0000000000..c3c15fca7a --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/02_2_target_table.filter.yml @@ -0,0 +1,5 @@ +columns: +- target_table +type: MULTI_SELECT +title: Target Table Name +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/04_0_aggregate_summary_table.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/04_0_aggregate_summary_table.sql new file mode 100644 index 0000000000..a35c699560 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/04_0_aggregate_summary_table.sql @@ -0,0 +1,46 @@ +/* --title 'Aggregates Summary Table' --width 6 --height 6 */ +SELECT + main.recon_id, + main.source_type, + main.source_table.`catalog` AS source_catalog, + main.source_table.`schema` AS source_schema, + main.source_table.table_name AS source_table_name, + IF( + ISNULL(source_catalog), + CONCAT_WS('.', source_schema, source_table_name), + CONCAT_WS( + '.', + source_catalog, + source_schema, + source_table_name + ) + ) AS source_table, + main.target_table.`catalog` AS target_catalog, + main.target_table.`schema` AS target_schema, + main.target_table.table_name AS target_table_name, + CONCAT_WS( + '.', + target_catalog, + target_schema, + target_table_name + ) AS target_table, + UPPER(rules.rule_info.agg_type) || CONCAT('(', rules.rule_info.agg_column, ')') AS aggregate_column, + rules.rule_info.group_by_columns, + metrics.run_metrics.status AS status, + metrics.run_metrics.exception_message AS exception, + metrics.recon_metrics.missing_in_source AS missing_in_source, + metrics.recon_metrics.missing_in_target AS missing_in_target, 
+ metrics.recon_metrics.mismatch AS mismatch, + metrics.run_metrics.run_by_user AS executed_by, + main.start_ts AS start_ts, + main.end_ts AS end_ts +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.aggregate_metrics metrics + INNER JOIN remorph.reconcile.aggregate_rules rules + ON main.recon_table_id = metrics.recon_table_id + AND rules.rule_id = metrics.rule_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/05_0_aggregate_recon_drilldown_header.md b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/05_0_aggregate_recon_drilldown_header.md new file mode 100644 index 0000000000..910569f1f5 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/05_0_aggregate_recon_drilldown_header.md @@ -0,0 +1,2 @@ +# Drill Down +### The Aggregates Reconcile details table contains all the sample records information of mismatches and missing entries. diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_0_recon_id.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_0_recon_id.filter.yml new file mode 100644 index 0000000000..d460d900c8 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_0_recon_id.filter.yml @@ -0,0 +1,5 @@ +columns: +- dd_recon_id +type: MULTI_SELECT +title: Recon Id +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_1_category.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_1_category.filter.yml new file mode 100644 index 0000000000..bf06cf5d3c --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_1_category.filter.yml @@ -0,0 +1,5 @@ +columns: +- dd_recon_type +type: MULTI_SELECT +title: Category +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_2_aggregate_type.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_2_aggregate_type.filter.yml new file mode 100644 index 0000000000..ddf05b601b --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/06_2_aggregate_type.filter.yml @@ -0,0 +1,5 @@ +columns: +- dd_aggregate_type +type: MULTI_SELECT +title: Aggregate Type +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/07_0_target_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/07_0_target_table.filter.yml new file mode 100644 index 0000000000..2eaf96af06 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/07_0_target_table.filter.yml @@ -0,0 +1,4 @@ +columns: +- dd_target_table +type: MULTI_SELECT +title: Target Table Name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/07_1_source_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/07_1_source_table.filter.yml new file mode 100644 index 
0000000000..09a67dbdbd --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/07_1_source_table.filter.yml @@ -0,0 +1,4 @@ +columns: +- dd_source_table +type: MULTI_SELECT +title: Source Table Name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/08_0_aggregate_details_table.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/08_0_aggregate_details_table.sql new file mode 100644 index 0000000000..ee3f70b1b4 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/08_0_aggregate_details_table.sql @@ -0,0 +1,92 @@ +/* --title 'Aggregates Reconciliation Details' --width 6 --height 6 */ +WITH details_view AS ( + SELECT + recon_table_id, + rule_id, + recon_type, + explode(data) AS agg_details + FROM + remorph.reconcile.aggregate_details +), + metrics_view AS ( + SELECT + recon_table_id, + rule_id, + recon_metrics, + run_metrics + FROM + remorph.reconcile.aggregate_metrics + ) +SELECT + recon_id AS dd_recon_id, + source_table AS dd_source_table, + target_table AS dd_target_table, + recon_type AS dd_recon_type, + aggregate_type AS dd_aggregate_type, + rule AS aggregate_column, + source_value, + target_value, + zip_with(rule_group_by_columns, group_by_column_values, (groupby, value) -> CONCAT_WS(':', TRIM(groupby), value)) AS group_by_columns, + COALESCE(status, 'false') AS status +FROM ( + SELECT + main.recon_id, + main.source_table.`catalog` AS source_catalog, + main.source_table.`schema` AS source_schema, + main.source_table.table_name AS source_table_name, + IF( + ISNULL(source_catalog), + CONCAT_WS('.', source_schema, source_table_name), + CONCAT_WS( + '.', + source_catalog, + source_schema, + source_table_name + ) + ) AS source_table, + main.target_table.`catalog` AS target_catalog, + main.target_table.`schema` AS target_schema, + main.target_table.table_name AS target_table_name, + CONCAT_WS( + '.', + target_catalog, + target_schema, + target_table_name + ) AS target_table, + dtl.recon_type, + rul.rule_info.agg_type AS aggregate_type, + UPPER(rul.rule_info.agg_type) || CONCAT('(', rul.rule_info.agg_column, ')') AS rule, + CONCAT_WS( + '_', + 'source', + rul.rule_info.agg_type, + rul.rule_info.agg_column + ) AS source_agg_column, + dtl.agg_details[source_agg_column] AS source_value, + CONCAT_WS( + '_', + 'target', + rul.rule_info.agg_type, + rul.rule_info.agg_column + ) AS target_agg_column, + dtl.agg_details[target_agg_column] AS target_value, + SPLIT(rul.rule_info.group_by_columns, ',') AS rule_group_by_columns, + TRANSFORM(rule_group_by_columns, colm -> + COALESCE(dtl.agg_details[CONCAT('source_group_by_', TRIM(colm))], + dtl.agg_details[CONCAT('target_group_by_', TRIM(colm))])) AS group_by_column_values, + CONCAT_WS( + '_', + 'match', + rul.rule_info.agg_type, + rul.rule_info.agg_column + ) AS status_column, + dtl.agg_details[status_column] AS status + FROM + metrics_view mtc + INNER JOIN remorph.reconcile.main main ON main.recon_table_id = mtc.recon_table_id + INNER JOIN details_view dtl ON mtc.recon_table_id = dtl.recon_table_id + INNER JOIN remorph.reconcile.aggregate_rules rul ON mtc.rule_id = dtl.rule_id + AND dtl.rule_id = rul.rule_id + ) +ORDER BY + recon_id diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/09_0_aggregate_missing_mismatch_header.md 
b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/09_0_aggregate_missing_mismatch_header.md new file mode 100644 index 0000000000..d1665f253b --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/09_0_aggregate_missing_mismatch_header.md @@ -0,0 +1 @@ +# Visualization of Missing and Mismatched Records diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/10_0_aggr_mismatched_records.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/10_0_aggr_mismatched_records.sql new file mode 100644 index 0000000000..26b5fe9af7 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/10_0_aggr_mismatched_records.sql @@ -0,0 +1,19 @@ +/* --title 'Mismatched Records' --width 6 */ +SELECT + main.recon_id, + CONCAT_WS( + '.', + main.target_table.`catalog`, + main.target_table.`schema`, + main.target_table.table_name + ) AS target_table, + main.start_ts, + metrics.recon_metrics.mismatch AS mismatch +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.aggregate_metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/11_0_aggr_missing_in_databricks.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/11_0_aggr_missing_in_databricks.sql new file mode 100644 index 0000000000..bcd0113d7f --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/11_0_aggr_missing_in_databricks.sql @@ -0,0 +1,19 @@ +/* --title 'Missing in Databricks' --width 3 */ +SELECT + main.recon_id, + CONCAT_WS( + '.', + main.target_table.`catalog`, + main.target_table.`schema`, + main.target_table.table_name + ) AS target_table, + main.start_ts, + metrics.recon_metrics.missing_in_target AS missing_in_target +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.aggregate_metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/11_1_aggr_missing_in_source.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/11_1_aggr_missing_in_source.sql new file mode 100644 index 0000000000..4bde21239d --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/11_1_aggr_missing_in_source.sql @@ -0,0 +1,19 @@ +/* --title 'Missing in Source' --width 3 */ +SELECT + main.recon_id, + CONCAT_WS( + '.', + main.target_table.`catalog`, + main.target_table.`schema`, + main.target_table.table_name + ) AS target_table, + main.start_ts, + metrics.recon_metrics.missing_in_source AS missing_in_source +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.aggregate_metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/dashboard.yml 
b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/dashboard.yml new file mode 100644 index 0000000000..bc24691eab --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/aggregate_reconciliation_metrics/dashboard.yml @@ -0,0 +1,365 @@ +display_name: "Aggregate Reconciliation Metrics" +tiles: + 04_0_aggregate_summary_table: + overrides: + spec: + withRowNumber: true + encodings: + columns: + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: recon_id + title: recon_id + type: string + cellFormat: + default: + foregroundColor: + rules: + - if: + column: status + fn: '=' + literal: 'true' + value: + foregroundColor: '#3BD973' + - if: + column: status + fn: '=' + literal: 'false' + value: + foregroundColor: '#E92828' + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_type + title: source_type + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_catalog + title: source_catalog + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_schema + title: source_schema + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_table_name + title: source_table_name + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_table + title: source_table + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_catalog + title: target_catalog + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_schema + title: target_schema + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_table_name + title: target_table_name + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_table + title: target_table + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: aggregate_column + title: aggregate_column + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: group_by_columns + title: group_by_columns + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: status + title: status + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: exception + title: exception + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: missing_in_source + title: missing_in_source + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: missing_in_target + title: missing_in_target + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: mismatch + title: mismatch + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: executed_by + title: executed_by + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: datetime + fieldName: start_ts + title: start_ts + type: datetime + dateTimeFormat: 'YYYY-MM-DD HH:mm:ss.SSS' + - booleanValues: + - 'false' + - 'true' + displayAs: datetime + fieldName: end_ts + title: end_ts + type: datetime + dateTimeFormat: 'YYYY-MM-DD HH:mm:ss.SSS' + 08_0_aggregate_details_table: + overrides: + spec: + withRowNumber: true + encodings: + columns: + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: dd_recon_id + title: recon_id + type: string + cellFormat: + default: + 
foregroundColor: + rules: + - if: + column: status + fn: '=' + literal: 'true' + value: + foregroundColor: '#3BD973' + - if: + column: status + fn: '=' + literal: 'false' + value: + foregroundColor: '#E92828' + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: dd_source_table + title: source_table + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: dd_target_table + title: target_table + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: dd_recon_type + title: recon_type + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: dd_aggregate_type + title: aggregate_type + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: aggregate_column + title: aggregate_column + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_value + title: source_value + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_value + title: target_value + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: group_by_columns + title: group_by_columns + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: status + title: status + type: string + 10_0_aggr_mismatched_records: + overrides: + queries: + - name: main_query + query: + datasetName: 10_0_aggr_mismatched_records + fields: + - name: target_table + expression: '`target_table`' + - name: hourly(start_ts) + expression: 'DATE_TRUNC("HOUR", `start_ts`)' + - name: mismatch + expression: '`mismatch`' + disaggregated: true + spec: + version: 3 + widgetType: area + encodings: + x: + fieldName: hourly(start_ts) + scale: + type: temporal + displayName: start_ts + 'y': + fieldName: mismatch + scale: + type: quantitative + displayName: mismatch + color: + fieldName: target_table + scale: + type: categorical + displayName: target_table + label: + show: false + 11_0_aggr_missing_in_databricks: + overrides: + queries: + - name: main_query + query: + datasetName: 11_0_aggr_missing_in_databricks + fields: + - name: target_table + expression: '`target_table`' + - name: hourly(start_ts) + expression: 'DATE_TRUNC("HOUR", `start_ts`)' + - name: missing_in_target + expression: '`missing_in_target`' + disaggregated: true + spec: + version: 3 + widgetType: area + encodings: + x: + fieldName: hourly(start_ts) + scale: + type: temporal + displayName: start_ts + 'y': + fieldName: missing_in_target + scale: + type: quantitative + displayName: missing_in_target + color: + fieldName: target_table + scale: + type: categorical + displayName: target_table + label: + show: false + 11_1_aggr_missing_in_source: + overrides: + queries: + - name: main_query + query: + datasetName: 11_1_aggr_missing_in_source + fields: + - name: target_table + expression: '`target_table`' + - name: hourly(start_ts) + expression: 'DATE_TRUNC("HOUR", `start_ts`)' + - name: missing_in_source + expression: '`missing_in_source`' + disaggregated: true + spec: + version: 3 + widgetType: area + encodings: + x: + fieldName: hourly(start_ts) + scale: + type: temporal + displayName: start_ts + 'y': + fieldName: missing_in_source + scale: + type: quantitative + displayName: missing_in_source + color: + fieldName: target_table + scale: + type: categorical + displayName: target_table + label: + show: false + + diff --git 
a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/00_0_recon_main.md b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/00_0_recon_main.md new file mode 100644 index 0000000000..aaa3cf8aa5 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/00_0_recon_main.md @@ -0,0 +1,3 @@ +# Main Reconciliation Table + +### This table provides comprehensive information on the report's status, including failure indications, schema matching status, and details on missing and mismatched records. diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_0_recon_id.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_0_recon_id.filter.yml new file mode 100644 index 0000000000..e8e335c044 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_0_recon_id.filter.yml @@ -0,0 +1,6 @@ +columns: +- recon_id +- dd_recon_id +type: MULTI_SELECT +title: Recon Id +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_1_report_type.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_1_report_type.filter.yml new file mode 100644 index 0000000000..86d005eba4 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_1_report_type.filter.yml @@ -0,0 +1,5 @@ +columns: +- report_type +type: MULTI_SELECT +title: Report Type +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_2_executed_by.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_2_executed_by.filter.yml new file mode 100644 index 0000000000..0dc7ef29d0 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/01_2_executed_by.filter.yml @@ -0,0 +1,5 @@ +columns: +- executed_by +type: MULTI_SELECT +title: Executed by +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_0_source_type.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_0_source_type.filter.yml new file mode 100644 index 0000000000..00d8651245 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_0_source_type.filter.yml @@ -0,0 +1,5 @@ +columns: +- source_type +type: MULTI_SELECT +title: Source Type +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_1_source_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_1_source_table.filter.yml new file mode 100644 index 0000000000..53b7510a1f --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_1_source_table.filter.yml @@ -0,0 +1,6 @@ +columns: +- source_table +- dd_source_table +type: MULTI_SELECT +title: Source Table Name +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_2_target_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_2_target_table.filter.yml new file mode 100644 index 0000000000..141cd17a35 --- /dev/null +++ 
b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/02_2_target_table.filter.yml @@ -0,0 +1,6 @@ +columns: +- target_table +- dd_target_table +type: MULTI_SELECT +title: Target Table Name +width: 2 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/03_0_started_at.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/03_0_started_at.filter.yml new file mode 100644 index 0000000000..28f11d147c --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/03_0_started_at.filter.yml @@ -0,0 +1,5 @@ +columns: +- start_ts +title: Started At +type: DATE_RANGE_PICKER +width: 6 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/05_0_summary_table.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/05_0_summary_table.sql new file mode 100644 index 0000000000..4183d241a9 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/05_0_summary_table.sql @@ -0,0 +1,38 @@ +/* --title 'Summary Table' --width 6 --height 6 */ +SELECT main.recon_id, + main.source_type, + main.report_type, + main.source_table.`catalog` AS source_catalog, + main.source_table.`schema` AS source_schema, + main.source_table.table_name AS source_table_name, + IF( + ISNULL(source_catalog), + CONCAT_WS('.', source_schema, source_table_name), + CONCAT_WS( + '.', + source_catalog, + source_schema, + source_table_name + ) + ) AS source_table, + main.target_table.`catalog` AS target_catalog, + main.target_table.`schema` AS target_schema, + main.target_table.table_name AS target_table_name, + CONCAT(main.target_table.catalog, '.', main.target_table.schema, '.', main.target_table.table_name) AS target_table, + metrics.run_metrics.status AS status, + metrics.run_metrics.exception_message AS exception, + metrics.recon_metrics.row_comparison.missing_in_source AS missing_in_source, + metrics.recon_metrics.row_comparison.missing_in_target AS missing_in_target, + metrics.recon_metrics.column_comparison.absolute_mismatch AS absolute_mismatch, + metrics.recon_metrics.column_comparison.threshold_mismatch AS threshold_mismatch, + metrics.recon_metrics.column_comparison.mismatch_columns AS mismatch_columns, + metrics.recon_metrics.schema_comparison AS schema_comparison, + metrics.run_metrics.run_by_user AS executed_by, + main.start_ts AS start_ts, + main.end_ts AS end_ts +FROM remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/06_0_schema_comparison_header.md b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/06_0_schema_comparison_header.md new file mode 100644 index 0000000000..ae616d0ef1 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/06_0_schema_comparison_header.md @@ -0,0 +1,3 @@ +# Schema Comparison Details + +### This table provides a detailed view of schema mismatches. 
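Note: as a companion to the schema-details tile added below, the following ad-hoc query is a minimal sketch of how the same schema-comparison data can be summarised outside the dashboard. It assumes the default remorph.reconcile catalog and schema used by the queries in this change, and that is_valid is stored as the strings 'true'/'false' inside the data map (as the dashboard cell formatting suggests); adjust both if your installation differs. The schema_rows alias is illustrative only.

-- Sketch: count invalid vs. compared columns per recon run
-- (assumes the default remorph.reconcile location and string-encoded is_valid values).
WITH schema_rows AS (
  SELECT recon_table_id, explode(data) AS schema_data
  FROM remorph.reconcile.details
  WHERE recon_type = 'schema'
)
SELECT
  main.recon_id,
  CONCAT_WS('.', main.target_table.catalog, main.target_table.schema, main.target_table.table_name) AS target_table,
  SUM(IF(schema_rows.schema_data['is_valid'] = 'false', 1, 0)) AS invalid_columns,
  COUNT(*) AS compared_columns
FROM remorph.reconcile.main main
  INNER JOIN schema_rows ON main.recon_table_id = schema_rows.recon_table_id
GROUP BY
  main.recon_id,
  main.target_table.catalog,
  main.target_table.schema,
  main.target_table.table_name
ORDER BY invalid_columns DESC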
diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/07_0_schema_details_table.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/07_0_schema_details_table.sql new file mode 100644 index 0000000000..9f7cc115bd --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/07_0_schema_details_table.sql @@ -0,0 +1,42 @@ +/* --title 'Schema Details' --width 6 */ +WITH tmp AS ( + SELECT + recon_table_id, + inserted_ts, + explode(data) AS schema_data + FROM + remorph.reconcile.details + WHERE + recon_type = 'schema' +) +SELECT + main.recon_id, + main.source_table.`catalog` AS source_catalog, + main.source_table.`schema` AS source_schema, + main.source_table.table_name AS source_table_name, + IF( + ISNULL(source_catalog), + CONCAT_WS('.', source_schema, source_table_name), + CONCAT_WS( + '.', + source_catalog, + source_schema, + source_table_name + ) + ) AS source_table, + main.target_table.`catalog` AS target_catalog, + main.target_table.`schema` AS target_schema, + main.target_table.table_name AS target_table_name, + CONCAT(main.target_table.catalog, '.', main.target_table.schema, '.', main.target_table.table_name) AS target_table, + schema_data['source_column'] AS source_column, + schema_data['source_datatype'] AS source_datatype, + schema_data['databricks_column'] AS databricks_column, + schema_data['databricks_datatype'] AS databricks_datatype, + schema_data['is_valid'] AS is_valid +FROM + remorph.reconcile.main main + INNER JOIN tmp ON main.recon_table_id = tmp.recon_table_id +ORDER BY + tmp.inserted_ts DESC, + main.recon_id, + main.target_table diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/08_0_drill_down_header.md b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/08_0_drill_down_header.md new file mode 100644 index 0000000000..d527f3624c --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/08_0_drill_down_header.md @@ -0,0 +1,3 @@ +# Drill Down + +### The details table contains all the sample records for mismatches and missing entries, providing users with exact details to pinpoint the issues. 
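Note: the drill-down tiles that follow pivot every sample record; for quick spot checks, the following ad-hoc query is a minimal sketch of the same drill down restricted to a single run. '<recon-id>' is a placeholder to replace with a real recon_id value, and the remorph.reconcile prefix assumes the default installation location.

-- Sketch: list the stored sample key/value records for one recon run,
-- mirroring the pivot query below but filtered to a single recon_id (placeholder value).
SELECT
  main.recon_id,
  tmp.recon_type,
  key,
  value
FROM (
  SELECT recon_table_id, recon_type, explode(data) AS data
  FROM remorph.reconcile.details
  WHERE recon_type != 'schema'
) tmp
  INNER JOIN remorph.reconcile.main main ON main.recon_table_id = tmp.recon_table_id
  LATERAL VIEW explode(data) exploded_data AS key, value
WHERE main.recon_id = '<recon-id>'
ORDER BY tmp.recon_type, key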
diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/09_0_recon_id.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/09_0_recon_id.filter.yml new file mode 100644 index 0000000000..9c4ca13dea --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/09_0_recon_id.filter.yml @@ -0,0 +1,4 @@ +columns: +- dd_recon_id +type: MULTI_SELECT +title: Recon Id diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/09_1_category.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/09_1_category.filter.yml new file mode 100644 index 0000000000..b6c1a9db27 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/09_1_category.filter.yml @@ -0,0 +1,4 @@ +columns: +- dd_recon_type +type: MULTI_SELECT +title: Category diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/10_0_target_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/10_0_target_table.filter.yml new file mode 100644 index 0000000000..2eaf96af06 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/10_0_target_table.filter.yml @@ -0,0 +1,4 @@ +columns: +- dd_target_table +type: MULTI_SELECT +title: Target Table Name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/10_1_source_table.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/10_1_source_table.filter.yml new file mode 100644 index 0000000000..09a67dbdbd --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/10_1_source_table.filter.yml @@ -0,0 +1,4 @@ +columns: +- dd_source_table +type: MULTI_SELECT +title: Source Table Name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/11_0_recon_details_pivot.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/11_0_recon_details_pivot.sql new file mode 100644 index 0000000000..6edad4d9af --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/11_0_recon_details_pivot.sql @@ -0,0 +1,40 @@ +/* --title 'Recon Details Drill Down' --height 6 --width 6 */ +WITH tmp AS ( + SELECT + recon_table_id, + inserted_ts, + recon_type, + explode(data) AS data, + row_number() OVER (PARTITION BY recon_table_id, recon_type ORDER BY recon_table_id) AS rn + FROM + remorph.reconcile.details + WHERE + recon_type != 'schema' +) +SELECT + main.recon_id AS dd_recon_id, + main.source_table.`catalog` AS source_catalog, + main.source_table.`schema` AS source_schema, + main.source_table.table_name AS source_table_name, + IF( + ISNULL(source_catalog), + CONCAT_WS('.', source_schema, source_table_name), + CONCAT_WS( + '.', + source_catalog, + source_schema, + source_table_name + ) + ) AS dd_source_table, + main.target_table.`catalog` AS target_catalog, + main.target_table.`schema` AS target_schema, + main.target_table.table_name AS target_table_name, + CONCAT(main.target_table.catalog, '.', main.target_table.schema, '.', main.target_table.table_name) AS dd_target_table, + recon_type AS dd_recon_type, + key, + value, + rn +FROM tmp + INNER JOIN remorph.reconcile.main main + ON main.recon_table_id = tmp.recon_table_id + LATERAL 
VIEW explode(data) exploded_data AS key, value diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/12_0_daily_data_validation_issue_header.md b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/12_0_daily_data_validation_issue_header.md new file mode 100644 index 0000000000..c58fd9528e --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/12_0_daily_data_validation_issue_header.md @@ -0,0 +1,3 @@ +# Daily Data Validation Issues Report + +### This summary report provides an overview of all data validation runs conducted on a specific day. It highlights whether each table has encountered any validation issues, without delving into the low-level details. This report aims to give a quick and clear status of data integrity across all tables for the day. diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/13_0_success_fail_.filter.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/13_0_success_fail_.filter.yml new file mode 100644 index 0000000000..48d8a00c4a --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/13_0_success_fail_.filter.yml @@ -0,0 +1,4 @@ +columns: +- start_date +type: DATE_RANGE_PICKER +width: 6 diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/14_0_failed_recon_ids.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/14_0_failed_recon_ids.sql new file mode 100644 index 0000000000..89228a8adb --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/14_0_failed_recon_ids.sql @@ -0,0 +1,15 @@ +/* --title 'Number of Distinct Recon IDs per Target Table Failed' --width 6 */ +SELECT + main.recon_id AS rec_id, + CONCAT(main.target_table.catalog, '.', main.target_table.schema, '.', main.target_table.table_name) AS t_table, + DATE(main.start_ts) AS start_date +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics +ON main.recon_table_id = metrics.recon_table_id +WHERE + metrics.run_metrics.status = FALSE +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_0_total_failed_runs.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_0_total_failed_runs.sql new file mode 100644 index 0000000000..71fa10faa9 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_0_total_failed_runs.sql @@ -0,0 +1,10 @@ +/* --title 'Total number of runs failed' --width 2 */ +SELECT + main.recon_id AS rec_id, + DATE(main.start_ts) AS start_date +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics +ON main.recon_table_id = metrics.recon_table_id +WHERE + metrics.run_metrics.status = FALSE diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_1_failed_targets.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_1_failed_targets.sql new file mode 100644 index 0000000000..ff16c43558 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_1_failed_targets.sql @@ -0,0 +1,10 @@ +/* --title 'Unique target tables failed' --width 2 */ +SELECT + 
CONCAT_WS('.', main.target_table.catalog, main.target_table.schema, main.target_table.table_name) AS t_table, + DATE(main.start_ts) AS start_date +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics +ON main.recon_table_id = metrics.recon_table_id +WHERE + metrics.run_metrics.status = FALSE diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_2_successful_targets.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_2_successful_targets.sql new file mode 100644 index 0000000000..2d2a6eb191 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/15_2_successful_targets.sql @@ -0,0 +1,10 @@ +/* --title 'Unique target tables successful' --width 2 */ +SELECT + CONCAT_WS('.', main.target_table.catalog, main.target_table.schema, main.target_table.table_name) AS t_table, + DATE(main.start_ts) AS start_date +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics +ON main.recon_table_id = metrics.recon_table_id +WHERE + metrics.run_metrics.status = TRUE diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/16_0_missing_mismatch_header.md b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/16_0_missing_mismatch_header.md new file mode 100644 index 0000000000..d1665f253b --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/16_0_missing_mismatch_header.md @@ -0,0 +1 @@ +# Visualization of Missing and Mismatched Records diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/17_0_mismatched_records.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/17_0_mismatched_records.sql new file mode 100644 index 0000000000..09eea3885f --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/17_0_mismatched_records.sql @@ -0,0 +1,14 @@ +/* --title 'Mismatched Records' --width 3 */ +SELECT + main.recon_id, + CONCAT_WS('.', main.target_table.catalog, main.target_table.schema, main.target_table.table_name) AS target_table, + metrics.recon_metrics.column_comparison.absolute_mismatch AS absolute_mismatch, + main.start_ts AS start_ts +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/17_1_threshold_mismatches.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/17_1_threshold_mismatches.sql new file mode 100644 index 0000000000..781fa2c100 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/17_1_threshold_mismatches.sql @@ -0,0 +1,14 @@ +/* --title 'Threshold Mismatches' --width 3 */ +SELECT + main.recon_id, + CONCAT_WS('.', main.target_table.catalog, main.target_table.schema, main.target_table.table_name) AS target_table, + metrics.recon_metrics.column_comparison.threshold_mismatch AS threshold_mismatch, + main.start_ts AS start_ts +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git 
a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/18_0_missing_in_databricks.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/18_0_missing_in_databricks.sql new file mode 100644 index 0000000000..1dbcf8073f --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/18_0_missing_in_databricks.sql @@ -0,0 +1,14 @@ +/* --title 'Missing in Databricks' --width 3 */ +SELECT + main.recon_id, + CONCAT(main.target_table.catalog, '.', main.target_table.schema, '.', main.target_table.table_name) AS target_table, + metrics.recon_metrics.row_comparison.missing_in_target AS missing_in_target, + main.start_ts AS start_ts +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/18_1_missing_in_source.sql b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/18_1_missing_in_source.sql new file mode 100644 index 0000000000..f6b392d990 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/18_1_missing_in_source.sql @@ -0,0 +1,14 @@ +/* --title 'Missing in Source' --width 3 */ +SELECT + main.recon_id, + CONCAT(main.target_table.catalog, '.', main.target_table.schema, '.', main.target_table.table_name) AS target_table, + metrics.recon_metrics.row_comparison.missing_in_source AS missing_in_source, + main.start_ts AS start_ts +FROM + remorph.reconcile.main main + INNER JOIN remorph.reconcile.metrics metrics + ON main.recon_table_id = metrics.recon_table_id +ORDER BY + metrics.inserted_ts DESC, + main.recon_id, + main.target_table.table_name diff --git a/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/dashboard.yml b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/dashboard.yml new file mode 100644 index 0000000000..699575c3cc --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/dashboards/reconciliation_metrics/dashboard.yml @@ -0,0 +1,545 @@ +display_name: "Reconciliation Metrics" +tiles: + 05_0_summary_table: + overrides: + spec: + withRowNumber: true + encodings: + columns: + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: recon_id + title: recon_id + type: string + cellFormat: + default: + foregroundColor: + rules: + - if: + column: status + fn: '=' + literal: 'true' + value: + foregroundColor: '#3BD973' + - if: + column: status + fn: '=' + literal: 'false' + value: + foregroundColor: '#E92828' + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_type + title: source_type + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: report_type + title: report_type + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_catalog + title: source_catalog + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_schema + title: source_schema + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_table_name + title: source_table_name + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_table + title: source_table + type: string + - booleanValues: 
+ - 'false' + - 'true' + displayAs: string + fieldName: target_catalog + title: target_catalog + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_schema + title: target_schema + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_table_name + title: target_table_name + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_table + title: target_table + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: status + title: status + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: exception + title: exception + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: missing_in_source + title: missing_in_source + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: missing_in_target + title: missing_in_target + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: absolute_mismatch + title: absolute_mismatch + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: threshold_mismatch + title: threshold_mismatch + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: mismatch_columns + title: mismatch_columns + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: schema_comparison + title: schema_comparison + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: executed_by + title: executed_by + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: datetime + fieldName: start_ts + title: start_ts + type: datetime + dateTimeFormat: 'YYYY-MM-DD HH:mm:ss.SSS' + - booleanValues: + - 'false' + - 'true' + displayAs: datetime + fieldName: end_ts + title: end_ts + type: datetime + dateTimeFormat: 'YYYY-MM-DD HH:mm:ss.SSS' + 07_0_schema_details_table: + overrides: + spec: + withRowNumber: true + encodings: + columns: + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: recon_id + title: recon_id + type: string + cellFormat: + default: + foregroundColor: + rules: + - if: + column: is_valid + fn: '=' + literal: 'false' + value: + foregroundColor: '#E92828' + - if: + column: is_valid + fn: '=' + literal: 'true' + value: + foregroundColor: '#3BD973' + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_catalog + title: source_catalog + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_schema + title: source_schema + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_table_name + title: source_table_name + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_table + title: source_table + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_catalog + title: target_catalog + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_schema + title: target_schema + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_table_name + title: target_table_name + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: target_table + title: target_table + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: 
source_column + title: source_column + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: source_datatype + title: source_datatype + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: databricks_column + title: databricks_column + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: databricks_datatype + title: databricks_datatype + type: string + - booleanValues: + - 'false' + - 'true' + displayAs: string + fieldName: is_valid + title: is_valid + type: string + 11_0_recon_details_pivot: + overrides: + spec: + version: 3 + widgetType: pivot + encodings: + rows: + - fieldName: dd_recon_id + displayName: recon_id + - fieldName: dd_source_table + displayName: source_table + - fieldName: dd_target_table + displayName: target_table + - fieldName: dd_recon_type + displayName: recon_type + - fieldName: rn + displayName: rn + columns: + - fieldName: key + displayName: key + cell: + fieldName: value + cellType: text + displayName: value + 14_0_failed_recon_ids: + overrides: + spec: + version: 3 + widgetType: bar + encodings: + x: + fieldName: t_table + scale: + type: categorical + sort: + by: y-reversed + displayName: Target table + 'y': + fieldName: countdistinct(rec_id) + scale: + type: quantitative + displayName: Count of Unique Recon Ids + label: + show: true + queries: + - name: main_query + query: + datasetName: 14_0_failed_recon_ids + fields: + - name: t_table + expression: '`t_table`' + - name: countdistinct(rec_id) + expression: COUNT(DISTINCT `rec_id`) + disaggregated: false + 15_0_total_failed_runs: + overrides: + spec: + version: 2 + widgetType: counter + encodings: + value: + fieldName: countdistinct(rec_id) + displayName: countdistinct(rec_id) + queries: + - name: main_query + query: + datasetName: 15_0_total_failed_runs + fields: + - name: countdistinct(rec_id) + expression: 'COUNT(DISTINCT `rec_id`)' + disaggregated: false + 15_1_failed_targets: + overrides: + spec: + version: 2 + widgetType: counter + encodings: + value: + fieldName: countdistinct(t_table) + displayName: countdistinct(t_table) + queries: + - name: main_query + query: + datasetName: 15_1_failed_targets + fields: + - name: countdistinct(t_table) + expression: 'COUNT(DISTINCT `t_table`)' + disaggregated: false + 15_2_successful_targets: + overrides: + spec: + version: 2 + widgetType: counter + encodings: + value: + fieldName: countdistinct(t_table) + displayName: countdistinct(t_table) + queries: + - name: main_query + query: + datasetName: 15_2_successful_targets + fields: + - name: countdistinct(t_table) + expression: 'COUNT(DISTINCT `t_table`)' + disaggregated: false + 17_0_mismatched_records: + overrides: + queries: + - name: main_query + query: + datasetName: 17_0_mismatched_records + fields: + - name: target_table + expression: '`target_table`' + - name: hourly(start_ts) + expression: 'DATE_TRUNC("HOUR", `start_ts`)' + - name: absolute_mismatch + expression: '`absolute_mismatch`' + disaggregated: true + spec: + version: 3 + widgetType: area + encodings: + x: + fieldName: hourly(start_ts) + scale: + type: temporal + displayName: start_ts + 'y': + fieldName: absolute_mismatch + scale: + type: quantitative + displayName: absolute_mismatch + color: + fieldName: target_table + scale: + type: categorical + displayName: target_table + label: + show: false + 17_1_threshold_mismatches: + overrides: + queries: + - name: main_query + query: + datasetName: 17_1_threshold_mismatches + fields: + - name: target_table + 
expression: '`target_table`' + - name: hourly(start_ts) + expression: 'DATE_TRUNC("HOUR", `start_ts`)' + - name: threshold_mismatch + expression: '`threshold_mismatch`' + disaggregated: true + spec: + version: 3 + widgetType: area + encodings: + x: + fieldName: hourly(start_ts) + scale: + type: temporal + displayName: start_ts + 'y': + fieldName: threshold_mismatch + scale: + type: quantitative + displayName: threshold_mismatch + color: + fieldName: target_table + scale: + type: categorical + displayName: target_table + label: + show: false + 18_0_missing_in_databricks: + overrides: + queries: + - name: main_query + query: + datasetName: 18_0_missing_in_databricks + fields: + - name: target_table + expression: '`target_table`' + - name: hourly(start_ts) + expression: 'DATE_TRUNC("HOUR", `start_ts`)' + - name: missing_in_target + expression: '`missing_in_target`' + disaggregated: true + spec: + version: 3 + widgetType: area + encodings: + x: + fieldName: hourly(start_ts) + scale: + type: temporal + displayName: start_ts + 'y': + fieldName: missing_in_target + scale: + type: quantitative + displayName: missing_in_target + color: + fieldName: target_table + scale: + type: categorical + displayName: target_table + label: + show: false + 18_1_missing_in_source: + overrides: + queries: + - name: main_query + query: + datasetName: 18_1_missing_in_source + fields: + - name: target_table + expression: '`target_table`' + - name: hourly(start_ts) + expression: 'DATE_TRUNC("HOUR", `start_ts`)' + - name: missing_in_source + expression: '`missing_in_source`' + disaggregated: true + spec: + version: 3 + widgetType: area + encodings: + x: + fieldName: hourly(start_ts) + scale: + type: temporal + displayName: start_ts + 'y': + fieldName: missing_in_source + scale: + type: quantitative + displayName: missing_in_source + color: + fieldName: target_table + scale: + type: categorical + displayName: target_table + label: + show: false diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/__init__.py b/src/databricks/labs/remorph/resources/reconcile/queries/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/installation/__init__.py b/src/databricks/labs/remorph/resources/reconcile/queries/installation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_details.sql b/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_details.sql new file mode 100644 index 0000000000..d1cc40ccfb --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_details.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS aggregate_details ( + recon_table_id BIGINT NOT NULL, + rule_id BIGINT NOT NULL, + recon_type STRING NOT NULL, + data ARRAY<MAP<STRING, STRING>> NOT NULL, + inserted_ts TIMESTAMP NOT NULL +); diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_metrics.sql b/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_metrics.sql new file mode 100644 index 0000000000..f3c1381c8e --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_metrics.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS aggregate_metrics ( + recon_table_id BIGINT NOT NULL, + rule_id BIGINT NOT NULL, + recon_metrics STRUCT< + missing_in_source: INTEGER, + missing_in_target: INTEGER, + mismatch: INTEGER + >, +
run_metrics STRUCT< + status: BOOLEAN NOT NULL, + run_by_user: STRING NOT NULL, + exception_message: STRING + > NOT NULL, + inserted_ts TIMESTAMP NOT NULL +); diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_rules.sql b/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_rules.sql new file mode 100644 index 0000000000..62d188f601 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/queries/installation/aggregate_rules.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS aggregate_rules ( + rule_id BIGINT NOT NULL, + rule_type STRING NOT NULL, + rule_info MAP<STRING, STRING> NOT NULL, + inserted_ts TIMESTAMP NOT NULL +); diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/installation/details.sql b/src/databricks/labs/remorph/resources/reconcile/queries/installation/details.sql new file mode 100644 index 0000000000..f27da317aa --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/queries/installation/details.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS details ( + recon_table_id BIGINT NOT NULL, + recon_type STRING NOT NULL, + status BOOLEAN NOT NULL, + data ARRAY<MAP<STRING, STRING>> NOT NULL, + inserted_ts TIMESTAMP NOT NULL +); diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/installation/main.sql b/src/databricks/labs/remorph/resources/reconcile/queries/installation/main.sql new file mode 100644 index 0000000000..192a7c33ce --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/queries/installation/main.sql @@ -0,0 +1,24 @@ +CREATE TABLE IF NOT EXISTS main ( + recon_table_id BIGINT NOT NULL, + recon_id STRING NOT NULL, + source_type STRING NOT NULL, + source_table STRUCT< + catalog: STRING, + schema: STRING NOT NULL, + table_name: STRING NOT NULL + > , + target_table STRUCT< + catalog: STRING NOT NULL, + schema: STRING NOT NULL, + table_name: STRING NOT NULL + > NOT NULL, + report_type STRING NOT NULL, + operation_name STRING NOT NULL, + start_ts TIMESTAMP, + end_ts TIMESTAMP +) +TBLPROPERTIES ( + 'delta.columnMapping.mode' = 'name', + 'delta.minReaderVersion' = '2', + 'delta.minWriterVersion' = '5' +); diff --git a/src/databricks/labs/remorph/resources/reconcile/queries/installation/metrics.sql b/src/databricks/labs/remorph/resources/reconcile/queries/installation/metrics.sql new file mode 100644 index 0000000000..33582e2d23 --- /dev/null +++ b/src/databricks/labs/remorph/resources/reconcile/queries/installation/metrics.sql @@ -0,0 +1,21 @@ +CREATE TABLE IF NOT EXISTS metrics ( + recon_table_id BIGINT NOT NULL, + recon_metrics STRUCT< + row_comparison: STRUCT< + missing_in_source: BIGINT, + missing_in_target: BIGINT + >, + column_comparison: STRUCT< + absolute_mismatch: BIGINT, + threshold_mismatch: BIGINT, + mismatch_columns: STRING + >, + schema_comparison: BOOLEAN + >, + run_metrics STRUCT< + status: BOOLEAN NOT NULL, + run_by_user: STRING NOT NULL, + exception_message: STRING + > NOT NULL, + inserted_ts TIMESTAMP NOT NULL +); diff --git a/src/databricks/labs/remorph/transpiler/__init__.py b/src/databricks/labs/remorph/transpiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/transpiler/execute.py b/src/databricks/labs/remorph/transpiler/execute.py new file mode 100644 index 0000000000..e3ae529c88 --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/execute.py @@ -0,0 +1,290 @@ +import logging +import os +from pathlib import Path + +from sqlglot.dialects.dialect import Dialect +from
databricks.labs.remorph.__about__ import __version__ +from databricks.labs.remorph.config import ( + TranspileConfig, + get_dialect, + TranspilationResult, + ValidationResult, +) +from databricks.labs.remorph.helpers import db_sql +from databricks.labs.remorph.helpers.execution_time import timeit +from databricks.labs.remorph.helpers.file_utils import ( + dir_walk, + is_sql_file, + make_dir, + remove_bom, +) +from databricks.labs.remorph.transpiler.transpile_status import ( + TranspileStatus, + ParserError, + ValidationError, +) +from databricks.labs.remorph.helpers.validation import Validator +from databricks.labs.remorph.transpiler.sqlglot import lca_utils +from databricks.labs.remorph.transpiler.sqlglot.sqlglot_engine import SqlglotEngine +from databricks.sdk import WorkspaceClient + +# pylint: disable=unspecified-encoding + +logger = logging.getLogger(__name__) + + +def _process_file( + config: TranspileConfig, + validator: Validator | None, + transpiler: SqlglotEngine, + input_file: str | Path, + output_file: str | Path, +): + logger.info(f"Started processing file: {input_file}") + validate_error_list = [] + no_of_sqls = 0 + + input_file = Path(input_file) + output_file = Path(output_file) + + with input_file.open("r") as f: + sql = remove_bom(f.read()) + + lca_error = lca_utils.check_for_unsupported_lca(get_dialect(config.source_dialect.lower()), sql, str(input_file)) + + if lca_error: + validate_error_list.append(lca_error) + + write_dialect = config.get_write_dialect() + + transpiler_result: TranspilationResult = _parse(transpiler, write_dialect, sql, input_file, []) + + with output_file.open("w") as w: + for output in transpiler_result.transpiled_sql: + if output: + no_of_sqls = no_of_sqls + 1 + if config.skip_validation: + w.write(output) + w.write("\n;\n") + elif validator: + validation_result: ValidationResult = _validation(validator, config, output) + w.write(validation_result.validated_sql) + if validation_result.exception_msg is not None: + validate_error_list.append(ValidationError(str(input_file), validation_result.exception_msg)) + else: + warning_message = ( + f"Skipped a query from file {input_file!s}. " + f"Check for unsupported operations related to STREAM, TASK, SESSION etc."
+ ) + logger.warning(warning_message) + + return no_of_sqls, transpiler_result.parse_error_list, validate_error_list + + +def _process_directory( + config: TranspileConfig, + validator: Validator | None, + transpiler: SqlglotEngine, + root: str | Path, + base_root: str, + files: list[str], +): + output_folder = config.output_folder + parse_error_list = [] + validate_error_list = [] + counter = 0 + + root = Path(root) + + for file in files: + logger.info(f"Processing file :{file}") + if is_sql_file(file): + if output_folder in {None, "None"}: + output_folder_base = f"{root.name}/transpiled" + else: + output_folder_base = f'{str(output_folder).rstrip("/")}/{base_root}' + + output_file_name = Path(output_folder_base) / Path(file).name + make_dir(output_folder_base) + + no_of_sqls, parse_error, validation_error = _process_file( + config, validator, transpiler, file, output_file_name + ) + counter = counter + no_of_sqls + parse_error_list.extend(parse_error) + validate_error_list.extend(validation_error) + else: + # Only SQL files are processed with extension .sql or .ddl + pass + + return counter, parse_error_list, validate_error_list + + +def _process_recursive_dirs( + config: TranspileConfig, input_sql_path: Path, validator: Validator | None, transpiler: SqlglotEngine +): + input_sql = input_sql_path + parse_error_list = [] + validate_error_list = [] + + file_list = [] + counter = 0 + for root, _, files in dir_walk(input_sql): + base_root = str(root).replace(str(input_sql), "") + folder = str(input_sql.resolve().joinpath(base_root)) + msg = f"Processing for sqls under this folder: {folder}" + logger.info(msg) + file_list.extend(files) + no_of_sqls, parse_error, validation_error = _process_directory( + config, validator, transpiler, root, base_root, files + ) + counter = counter + no_of_sqls + parse_error_list.extend(parse_error) + validate_error_list.extend(validation_error) + + error_log = parse_error_list + validate_error_list + + return TranspileStatus(file_list, counter, len(parse_error_list), len(validate_error_list), error_log) + + +@timeit +def transpile(workspace_client: WorkspaceClient, config: TranspileConfig): + """ + [Experimental] Transpiles the SQL queries from one dialect to another. + + :param config: The configuration for the morph operation. + :param workspace_client: The WorkspaceClient object. 
+ """ + if not config.input_source: + logger.error("Input SQL path is not provided.") + raise ValueError("Input SQL path is not provided.") + + input_sql = Path(config.input_source) + status = [] + result = TranspileStatus([], 0, 0, 0, []) + + read_dialect = config.get_read_dialect() + transpiler = SqlglotEngine(read_dialect) + validator = None + if not config.skip_validation: + sql_backend = db_sql.get_sql_backend(workspace_client) + logger.info(f"SQL Backend used for query validation: {type(sql_backend).__name__}") + validator = Validator(sql_backend) + + if input_sql.is_file(): + if is_sql_file(input_sql): + msg = f"Processing for sqls under this file: {input_sql}" + logger.info(msg) + if config.output_folder in {None, "None"}: + output_folder = input_sql.parent / "transpiled" + else: + output_folder = Path(str(config.output_folder).rstrip("/")) + + make_dir(output_folder) + output_file = output_folder / input_sql.name + no_of_sqls, parse_error, validation_error = _process_file( + config, validator, transpiler, input_sql, output_file + ) + error_log = parse_error + validation_error + result = TranspileStatus([str(input_sql)], no_of_sqls, len(parse_error), len(validation_error), error_log) + else: + msg = f"{input_sql} is not a SQL file." + logger.warning(msg) + elif input_sql.is_dir(): + result = _process_recursive_dirs(config, input_sql, validator, transpiler) + else: + msg = f"{input_sql} does not exist." + logger.error(msg) + raise FileNotFoundError(msg) + + error_list_count = result.parse_error_count + result.validate_error_count + if not config.skip_validation: + logger.info(f"No of Sql Failed while Validating: {result.validate_error_count}") + + error_log_file = "None" + if error_list_count > 0: + error_log_file = str(Path.cwd().joinpath(f"err_{os.getpid()}.lst")) + if result.error_log_list: + with Path(error_log_file).open("a") as e: + e.writelines(f"{err}\n" for err in result.error_log_list) + + status.append( + { + "total_files_processed": len(result.file_list), + "total_queries_processed": result.no_of_queries, + "no_of_sql_failed_while_parsing": result.parse_error_count, + "no_of_sql_failed_while_validating": result.validate_error_count, + "error_log_file": str(error_log_file), + } + ) + return status + + +def verify_workspace_client(workspace_client: WorkspaceClient) -> WorkspaceClient: + # pylint: disable=protected-access + """ + [Private] Verifies and updates the workspace client configuration. + + TODO: In future refactor this function so it can be used for reconcile module without cross access. 
+ """ + if workspace_client.config._product != "remorph": + workspace_client.config._product = "remorph" + if workspace_client.config._product_version != __version__: + workspace_client.config._product_version = __version__ + return workspace_client + + +def _parse( + transpiler: SqlglotEngine, + write_dialect: Dialect, + sql: str, + input_file: str | Path, + error_list: list[ParserError], +) -> TranspilationResult: + return transpiler.transpile(write_dialect, sql, str(input_file), error_list) + + +def _validation( + validator: Validator, + config: TranspileConfig, + sql: str, +) -> ValidationResult: + return validator.validate_format_result(config, sql) + + +@timeit +def transpile_sql( + workspace_client: WorkspaceClient, + config: TranspileConfig, + sql: str, +) -> tuple[TranspilationResult, ValidationResult | None]: + """[Experimental] Transpile a single SQL query from one dialect to another.""" + ws_client: WorkspaceClient = verify_workspace_client(workspace_client) + + read_dialect: Dialect = config.get_read_dialect() + write_dialect: Dialect = config.get_write_dialect() + transpiler: SqlglotEngine = SqlglotEngine(read_dialect) + + transpiler_result = _parse(transpiler, write_dialect, sql, "inline_sql", []) + + if not config.skip_validation: + sql_backend = db_sql.get_sql_backend(ws_client) + logger.info(f"SQL Backend used for query validation: {type(sql_backend).__name__}") + validator = Validator(sql_backend) + return transpiler_result, _validation(validator, config, transpiler_result.transpiled_sql[0]) + + return transpiler_result, None + + +@timeit +def transpile_column_exp( + workspace_client: WorkspaceClient, + config: TranspileConfig, + expressions: list[str], +) -> list[tuple[TranspilationResult, ValidationResult | None]]: + """[Experimental] Transpile a list of SQL expressions from one dialect to another.""" + config.skip_validation = True + result = [] + for sql in expressions: + result.append(transpile_sql(workspace_client, config, sql)) + return result diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/__init__.py b/src/databricks/labs/remorph/transpiler/sqlglot/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/generator/__init__.py b/src/databricks/labs/remorph/transpiler/sqlglot/generator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/generator/databricks.py b/src/databricks/labs/remorph/transpiler/sqlglot/generator/databricks.py new file mode 100644 index 0000000000..c465716fd5 --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/generator/databricks.py @@ -0,0 +1,770 @@ +import logging +import re + +from sqlglot import expressions as exp +from sqlglot.dialects import databricks as org_databricks +from sqlglot.dialects import hive +from sqlglot.dialects.dialect import if_sql +from sqlglot.dialects.dialect import rename_func +from sqlglot.errors import UnsupportedError +from sqlglot.helper import apply_index_offset, csv + +from databricks.labs.remorph.transpiler.sqlglot import lca_utils, local_expression + +# pylint: disable=too-many-public-methods + +logger = logging.getLogger(__name__) + +VALID_DATABRICKS_TYPES = { + "BIGINT", + "BINARY", + "BOOLEAN", + "DATE", + "DECIMAL", + "DOUBLE", + "FLOAT", + "INT", + "INTERVAL", + "VOID", + "SMALLINT", + "STRING", + "TIMESTAMP", + "TINYINT", + "ARRAY", + "MAP", + "STRUCT", +} + +PRECISION_CONST = 38 +SCALE_CONST = 0 + + +def timestamptrunc_sql(self, expression: 
exp.TimestampTrunc) -> str: + return self.func("DATE_TRUNC", exp.Literal.string(expression.text("unit").upper()), self.sql(expression.this)) + + +def _parm_sfx(self, expression: local_expression.Parameter) -> str: + this = self.sql(expression, "this") + this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}" + suffix = self.sql(expression, "suffix") + PARAMETER_TOKEN = "$" # noqa: N806 pylint: disable=invalid-name + return f"{PARAMETER_TOKEN}{this}{suffix}" + + +def _lateral_bracket_sql(self, expression: local_expression.Bracket) -> str: + """Overwrites `sqlglot/generator.py` `bracket_sql()` function + to convert `[COL_NAME]` to `.COL_NAME`. + Example: c[val] ==> c.val + """ + expressions = apply_index_offset(expression.this, expression.expressions, self.dialect.INDEX_OFFSET) + expressions = [self.sql(e.alias_or_name.strip("'")) for e in expressions] + # If expression contains space in between encode it in backticks(``): + # e.g. ref."ID Number" -> ref.`ID Number`. + expressions_sql = ", ".join(f"`{e}`" if " " in e else e for e in expressions) + return f"{self.sql(expression, 'this')}:{expressions_sql}" + + +def _format_create_sql(self, expression: exp.Create) -> str: + expression = expression.copy() + + # Remove modifiers in order to simplify the schema. For example, this removes things like "IF NOT EXISTS" + # from "CREATE TABLE foo IF NOT EXISTS". + args_to_delete = ["temporary", "transient", "external", "exists", "unique", "materialized", "properties"] + for arg_to_delete in args_to_delete: + if expression.args.get(arg_to_delete): + del expression.args[arg_to_delete] + + return self.create_sql(expression) + + +def _curr_time(): + return "date_format(current_timestamp(), 'HH:mm:ss')" + + +def _select_contains_index(expression: exp.Select) -> bool: + for expr in expression.expressions: + column = expr.unalias() if isinstance(expr, exp.Alias) else expr + if column.name == "index": + return True + return False + + +def _has_parse_json(expression): + if expression.find(exp.ParseJSON): + return True + _select = expression.find_ancestor(exp.Select) + if _select: + _from = _select.find(exp.From) + if _from: + _parse_json = _from.find(exp.ParseJSON) + if _parse_json: + return True + return False + + +def _generate_function_str(select_contains_index, has_parse_json, generator_expr, alias, is_outer, alias_str): + if select_contains_index: + generator_function_str = f"POSEXPLODE({generator_expr})" + alias_str = f"{' ' + alias.name if isinstance(alias, exp.TableAlias) else ''} AS index, value" + elif has_parse_json and is_outer: + generator_function_str = f"VARIANT_EXPLODE_OUTER({generator_expr})" + elif has_parse_json: + generator_function_str = f"VARIANT_EXPLODE({generator_expr})" + else: + generator_function_str = f"VIEW EXPLODE({generator_expr})" + + return generator_function_str, alias_str + + +def _generate_lateral_statement(self, select_contains_index, has_parse_json, generator_function_str, alias_str): + if select_contains_index: + lateral_statement = self.sql(f"LATERAL VIEW OUTER {generator_function_str}{alias_str}") + elif has_parse_json: + lateral_statement = self.sql(f", LATERAL {generator_function_str}{alias_str}") + else: + lateral_statement = self.sql(f" LATERAL {generator_function_str}{alias_str}") + + return lateral_statement + + +def _lateral_view(self: org_databricks.Databricks.Generator, expression: exp.Lateral) -> str: + has_parse_json = _has_parse_json(expression) + this = expression.args['this'] + alias = expression.args['alias'] + alias_str = f" AS 
{alias.name}" if isinstance(alias, exp.TableAlias) else "" + generator_function_str = self.sql(this) + is_outer = False + select_contains_index = False + + if isinstance(this, exp.Explode): + explode_expr = this + parent_select = explode_expr.parent_select + select_contains_index = _select_contains_index(parent_select) if parent_select else False + generator_expr = "" + if isinstance(explode_expr.this, exp.Kwarg): + generator_expr = self.sql(explode_expr.this, 'expression') + if not isinstance(explode_expr.this.expression, exp.ParseJSON): + generator_expr = generator_expr.replace("{", "").replace("}", "") + for expr in explode_expr.expressions: + node = str(expr.this).upper() + if node == "PATH": + generator_expr += "." + self.sql(expr, 'expression').replace("'", "") + if node == "OUTER": + is_outer = True + + if not generator_expr: + generator_expr = expression.this.this + + generator_function_str, alias_str = _generate_function_str( + select_contains_index, has_parse_json, generator_expr, alias, is_outer, alias_str + ) + + alias_cols = alias.columns if alias else [] + if len(alias_cols) <= 2: + alias_str = f" As {', '.join([item.this for item in alias_cols])}" + + lateral_statement = _generate_lateral_statement( + self, select_contains_index, has_parse_json, generator_function_str, alias_str + ) + return lateral_statement + + +# [TODO] Add more datatype coverage https://docs.databricks.com/sql/language-manual/sql-ref-datatypes.html +def _datatype_map(self, expression) -> str: + if expression.this in [exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.CHAR]: + return "STRING" + if expression.this in [exp.DataType.Type.TIMESTAMP, exp.DataType.Type.TIMESTAMPLTZ]: + return "TIMESTAMP" + if expression.this == exp.DataType.Type.BINARY: + return "BINARY" + if expression.this == exp.DataType.Type.NCHAR: + return "STRING" + return self.datatype_sql(expression) + + +def try_to_date(self, expression: local_expression.TryToDate): + func = "TRY_TO_TIMESTAMP" + time_format = self.sql(expression, "format") + if not time_format: + time_format = hive.Hive.DATE_FORMAT + + ts_result = self.func(func, expression.this, time_format) + return exp.Date(this=ts_result) + + +def try_to_number(self, expression: local_expression.TryToNumber): + func = "TRY_TO_NUMBER" + precision = self.sql(expression, "precision") + scale = self.sql(expression, "scale") + + if not precision: + precision = 38 + + if not scale: + scale = 0 + + func_expr = self.func(func, expression.this) + if expression.expression: + func_expr = self.func(func, expression.this, expression.expression) + else: + func_expr = expression.this + + return f"CAST({func_expr} AS DECIMAL({precision}, {scale}))" + + +def _to_boolean(self: org_databricks.Databricks.Generator, expression: local_expression.ToBoolean) -> str: + this = self.sql(expression, "this") + logger.debug(f"Converting {this} to Boolean") + raise_error = self.sql(expression, "raise_error") + raise_error_str = "RAISE_ERROR('Invalid parameter type for TO_BOOLEAN')" if bool(int(raise_error)) else "NULL" + transformed = f""" + CASE + WHEN {this} IS NULL THEN NULL + WHEN TYPEOF({this}) = 'boolean' THEN BOOLEAN({this}) + WHEN TYPEOF({this}) = 'string' THEN + CASE + WHEN LOWER({this}) IN ('true', 't', 'yes', 'y', 'on', '1') THEN TRUE + WHEN LOWER({this}) IN ('false', 'f', 'no', 'n', 'off', '0') THEN FALSE + ELSE RAISE_ERROR('Boolean value of x is not recognized by TO_BOOLEAN') + END + WHEN TRY_CAST({this} AS DOUBLE) IS NOT NULL THEN + CASE + WHEN ISNAN(CAST({this} AS DOUBLE)) OR 
CAST({this} AS DOUBLE) = DOUBLE('infinity') THEN + RAISE_ERROR('Invalid parameter type for TO_BOOLEAN') + ELSE CAST({this} AS DOUBLE) != 0.0 + END + ELSE {raise_error_str} + END + """ + return transformed + + +def _is_integer(self: org_databricks.Databricks.Generator, expression: local_expression.IsInteger) -> str: + this = self.sql(expression, "this") + transformed = f""" + CASE + WHEN {this} IS NULL THEN NULL + WHEN {this} RLIKE '^-?[0-9]+$' AND TRY_CAST({this} AS INT) IS NOT NULL THEN TRUE + ELSE FALSE + END + """ + return transformed + + +def _parse_json_extract_path_text( + self: org_databricks.Databricks.Generator, expression: local_expression.JsonExtractPathText +) -> str: + this = self.sql(expression, "this") + path_name = expression.args["path_name"] + if path_name.is_string: + path = f"{self.dialect.QUOTE_START}$.{expression.text('path_name')}{self.dialect.QUOTE_END}" + else: + path = f"CONCAT('$.', {self.sql(expression, 'path_name')})" + return f"GET_JSON_OBJECT({this}, {path})" + + +def _array_construct_compact( + self: org_databricks.Databricks.Generator, expression: local_expression.ArrayConstructCompact +) -> str: + exclude = "ARRAY(NULL)" + array_expr = f"ARRAY({self.expressions(expression, flat=True)})" + return f"ARRAY_EXCEPT({array_expr}, {exclude})" + + +def _array_slice(self: org_databricks.Databricks.Generator, expression: local_expression.ArraySlice) -> str: + from_expr = self.sql(expression, "from") + # In Databricks: array indices start at 1 in function `slice(array, start, length)` + parsed_from_expr = 1 if from_expr == "0" else from_expr + + to_expr = self.sql(expression, "to") + # Convert string expression to number and check if it is negative number + if int(to_expr) < 0: + err_message = "In Databricks: function `slice` length must be greater than or equal to 0" + raise UnsupportedError(err_message) + + func = "SLICE" + func_expr = self.func(func, expression.this, exp.Literal.number(parsed_from_expr), expression.args["to"]) + return func_expr + + +def _to_command(self, expr: exp.Command): + this_sql = self.sql(expr, 'this') + expression = self.sql(expr.expression, 'this') + prefix = f"-- {this_sql}" + if this_sql == "!": + return f"{prefix}{expression}" + return f"{prefix} {expression}" + + +def _parse_json(self, expression: exp.ParseJSON) -> str: + return self.func("PARSE_JSON", expression.this, expression.expression) + + +def _to_number(self, expression: local_expression.ToNumber): + func = "TO_NUMBER" + precision = self.sql(expression, "precision") + scale = self.sql(expression, "scale") + + func_expr = expression.this + # if format is provided, else it will be vanilla cast to decimal + if expression.expression: + func_expr = self.func(func, expression.this, expression.expression) + if precision: + return f"CAST({func_expr} AS DECIMAL({precision}, {scale}))" + return func_expr + if not precision: + precision = 38 + if not scale: + scale = 0 + if not expression.expression and not precision: + exception_msg = f"""Error Parsing expression {expression}: + * `format`: is required in Databricks [mandatory] + * `precision` and `scale`: are considered as (38, 0) if not specified. 
+ """ + raise UnsupportedError(exception_msg) + + precision = PRECISION_CONST if not precision else precision + scale = SCALE_CONST if not scale else scale + return f"CAST({func_expr} AS DECIMAL({precision}, {scale}))" + + +def _uuid(self: org_databricks.Databricks.Generator, expression: local_expression.UUID) -> str: + namespace = self.sql(expression, "this") + name = self.sql(expression, "name") + + if namespace and name: + logger.warning("UUID version 5 is not supported currently. Needs manual intervention.") + return f"UUID({namespace}, {name})" + + return "UUID()" + + +def _parse_date_trunc(self: org_databricks.Databricks.Generator, expression: local_expression.DateTrunc) -> str: + if not expression.args.get("unit"): + error_message = f"Required keyword: 'unit' missing for {exp.DateTrunc}" + raise UnsupportedError(error_message) + return self.func("TRUNC", expression.this, expression.args.get("unit")) + + +def _get_within_group_params( + expr: exp.ArrayAgg | exp.GroupConcat, + within_group: exp.WithinGroup, +) -> local_expression.WithinGroupParams: + has_distinct = isinstance(expr.this, exp.Distinct) + agg_col = expr.this.expressions[0] if has_distinct else expr.this + order_clause = within_group.expression + order_cols = [] + for e in order_clause.expressions: + desc = e.args.get("desc") + is_order_a = not desc or exp.false() == desc + order_cols.append((e.this, is_order_a)) + return local_expression.WithinGroupParams( + agg_col=agg_col, + order_cols=order_cols, + ) + + +def _create_named_struct_for_cmp(wg_params: local_expression.WithinGroupParams) -> exp.Expression: + agg_col = wg_params.agg_col + order_kv = [] + for i, (col, _) in enumerate(wg_params.order_cols): + order_kv.extend([exp.Literal(this=f"sort_by_{i}", is_string=True), col]) + + named_struct_func = exp.Anonymous( + this="named_struct", + expressions=[ + exp.Literal(this="value", is_string=True), + agg_col, + *order_kv, + ], + ) + return named_struct_func + + +def _current_date(self, expression: exp.CurrentDate) -> str: + zone = self.sql(expression, "this") + return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE()" + + +def _not_sql(self, expression: exp.Not) -> str: + if isinstance(expression.this, exp.Is): + return f"{self.sql(expression.this, 'this')} IS NOT {self.sql(expression.this, 'expression')}" + return f"NOT {self.sql(expression, 'this')}" + + +def to_array(self, expression: exp.ToArray) -> str: + return f"IF({self.sql(expression.this)} IS NULL, NULL, {self.func('ARRAY', expression.this)})" + + +class Databricks(org_databricks.Databricks): # + # Instantiate Databricks Dialect + databricks = org_databricks.Databricks() + NULL_ORDERING = "nulls_are_small" + + class Generator(org_databricks.Databricks.Generator): + INVERSE_TIME_MAPPING: dict[str, str] = { + **{v: k for k, v in org_databricks.Databricks.TIME_MAPPING.items()}, + "%-d": "dd", + } + + COLLATE_IS_FUNC = True + # [TODO]: Variant needs to be transformed better, for now parsing to string was deemed as the choice. 
+ TYPE_MAPPING = { + **org_databricks.Databricks.Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "TINYINT", + exp.DataType.Type.SMALLINT: "SMALLINT", + exp.DataType.Type.INT: "INT", + exp.DataType.Type.BIGINT: "BIGINT", + exp.DataType.Type.DATETIME: "TIMESTAMP", + exp.DataType.Type.VARCHAR: "STRING", + exp.DataType.Type.VARIANT: "VARIANT", + exp.DataType.Type.FLOAT: "DOUBLE", + exp.DataType.Type.OBJECT: "STRING", + exp.DataType.Type.GEOGRAPHY: "STRING", + } + + TRANSFORMS = { + **org_databricks.Databricks.Generator.TRANSFORMS, + exp.Create: _format_create_sql, + exp.DataType: _datatype_map, + exp.CurrentTime: _curr_time(), + exp.Lateral: _lateral_view, + exp.FromBase64: rename_func("UNBASE64"), + exp.AutoIncrementColumnConstraint: lambda *_: "GENERATED ALWAYS AS IDENTITY", + local_expression.Parameter: _parm_sfx, + local_expression.ToBoolean: _to_boolean, + local_expression.Bracket: _lateral_bracket_sql, + local_expression.MakeDate: rename_func("MAKE_DATE"), + local_expression.TryToDate: try_to_date, + local_expression.TryToNumber: try_to_number, + local_expression.IsInteger: _is_integer, + local_expression.JsonExtractPathText: _parse_json_extract_path_text, + local_expression.BitOr: rename_func("BIT_OR"), + local_expression.ArrayConstructCompact: _array_construct_compact, + local_expression.ArrayIntersection: rename_func("ARRAY_INTERSECT"), + local_expression.ArraySlice: _array_slice, + local_expression.ObjectKeys: rename_func("JSON_OBJECT_KEYS"), + exp.ParseJSON: _parse_json, + local_expression.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), + local_expression.ToDouble: rename_func("DOUBLE"), + exp.Rand: rename_func("RANDOM"), + local_expression.ToVariant: rename_func("TO_JSON"), + local_expression.ToObject: rename_func("TO_JSON"), + exp.ToBase64: rename_func("BASE64"), + local_expression.ToNumber: _to_number, + local_expression.UUID: _uuid, + local_expression.DateTrunc: _parse_date_trunc, + exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), + exp.TimestampTrunc: timestamptrunc_sql, + exp.Mod: rename_func("MOD"), + exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), + exp.If: if_sql(false_value="NULL"), + exp.Command: _to_command, + exp.CurrentDate: _current_date, + exp.Not: _not_sql, + local_expression.ToArray: to_array, + local_expression.ArrayExists: rename_func("EXISTS"), + } + + def preprocess(self, expression: exp.Expression) -> exp.Expression: + fixed_ast = expression.transform(lca_utils.unalias_lca_in_select, copy=False) + return super().preprocess(fixed_ast) + + def format_time(self, expression: exp.Expression, inverse_time_mapping=None, inverse_time_trie=None): + return super().format_time(expression, self.INVERSE_TIME_MAPPING) + + def join_sql(self, expression: exp.Join) -> str: + """Overwrites `join_sql()` in `sqlglot/generator.py` + Added logic to handle Lateral View + """ + op_list = [ + expression.method, + "GLOBAL" if expression.args.get("global") else None, + expression.side, + expression.kind, + expression.hint if self.JOIN_HINTS else None, + ] + + op_sql = " ".join(op for op in op_list if op) + on_sql = self.sql(expression, "on") + using = expression.args.get("using") + + if not on_sql and using: + on_sql = csv(*(self.sql(column) for column in using)) + + this_sql = self.sql(expression, "this") + + if on_sql: + on_sql = self.indent(on_sql, skip_first=True) + space = self.seg(" " * self.pad) if self.pretty else " " + if using: + on_sql = f"{space}USING ({on_sql})" + else: + on_sql = f"{space}ON {on_sql}" + # Added the below elif block to handle Lateral 
View clause + elif not op_sql and isinstance(expression.this, exp.Lateral): + return f"\n {this_sql}" + elif not op_sql: + return f", {this_sql}" + + op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" + return f"{self.seg(op_sql)} {this_sql}{on_sql}" + + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: + sql = self.func("ARRAY_AGG", expression.this) + within_group = expression.parent if isinstance(expression.parent, exp.WithinGroup) else None + if not within_group: + return sql + + wg_params = _get_within_group_params(expression, within_group) + if len(wg_params.order_cols) == 1: + order_col, is_order_asc = wg_params.order_cols[0] + if wg_params.agg_col == order_col: + return f"SORT_ARRAY({sql}{'' if is_order_asc else ', FALSE'})" + + named_struct_func = _create_named_struct_for_cmp(wg_params) + comparisons = [] + for i, (_, is_order_asc) in enumerate(wg_params.order_cols): + comparisons.append( + f"WHEN left.sort_by_{i} < right.sort_by_{i} THEN {'-1' if is_order_asc else '1'} " + f"WHEN left.sort_by_{i} > right.sort_by_{i} THEN {'1' if is_order_asc else '-1'}" + ) + + array_sort = self.func( + "ARRAY_SORT", + self.func("ARRAY_AGG", named_struct_func), + f"""(left, right) -> CASE + {' '.join(comparisons)} + ELSE 0 + END""", + ) + return self.func("TRANSFORM", array_sort, "s -> s.value") + + def groupconcat_sql(self, expr: exp.GroupConcat) -> str: + arr_agg = exp.ArrayAgg(this=expr.this) + within_group = expr.parent.copy() if isinstance(expr.parent, exp.WithinGroup) else None + if within_group: + arr_agg.parent = within_group + + return self.func( + "ARRAY_JOIN", + arr_agg, + expr.args.get("separator") or exp.Literal(this="", is_string=True), + ) + + def withingroup_sql(self, expression: exp.WithinGroup) -> str: + agg_expr = expression.this + if isinstance(agg_expr, (exp.ArrayAgg, exp.GroupConcat)): + return self.sql(agg_expr) + + return super().withingroup_sql(expression) + + def split_sql(self, expression: local_expression.Split) -> str: + """ + :param expression: local_expression.Split expression to be parsed + :return: Converted expression (SPLIT) compatible with Databricks + """ + delimiter = " " + # To handle default delimiter + if expression.expression: + delimiter = expression.expression.name + + # Parsing logic to handle String and Table columns + if expression.name and isinstance(expression.name, str): + expr_name = f"'{expression.name}'" + else: + expr_name = expression.args["this"] + return f"""SPLIT({expr_name},'[{delimiter}]')""" + + def delete_sql(self, expression: exp.Delete) -> str: + this = self.sql(expression, "this") + using = self.sql(expression, "using") + where = self.sql(expression, "where") + returning = self.sql(expression, "returning") + limit = self.sql(expression, "limit") + tables = self.expressions(expression, key="tables") + tables = f" {tables}" if tables else "" + + if using: + using = f" USING {using}" if using else "" + where = where.replace("WHERE", "ON") + else: + this = f"FROM {this}" if this else "" + + if self.RETURNING_END: + expression_sql = f" {this}{using}{where}{returning}{limit}" + else: + expression_sql = f"{returning}{this}{where}{limit}" + + if using: + return self.prepend_ctes(expression, f"MERGE INTO {tables}{expression_sql} WHEN MATCHED THEN DELETE;") + + return self.prepend_ctes(expression, f"DELETE{tables}{expression_sql};") + + def converttimezone_sql(self, expression: exp.ConvertTimezone): + func = "CONVERT_TIMEZONE" + expr = expression.args["tgtTZ"] + if len(expression.args) == 3 and expression.args.get("this"): + expr = 
expression.args["this"] + + result = self.func(func, expression.args["srcTZ"], expr) + if len(expression.args) == 3: + result = self.func(func, expression.args["srcTZ"], expression.args["tgtTZ"], expr) + + return result + + def strtok_sql(self, expression: local_expression.StrTok) -> str: + """ + :param expression: local_expression.StrTok expression to be parsed + :return: Converted expression (SPLIT_PART) compatible with Databricks + """ + # To handle default delimiter + if expression.expression: + delimiter = expression.expression.name + else: + delimiter = " " + + # Handle String and Table columns + if expression.name and isinstance(expression.name, str): + expr_name = f"'{expression.name}'" + else: + expr_name = expression.args["this"] + + # Handle Partition Number + if len(expression.args) == 3 and expression.args.get("partNum"): + part_num = expression.args["partNum"] + else: + part_num = 1 + + return f"SPLIT_PART({expr_name}, '{delimiter}', {part_num})" + + def splitpart_sql(self, expression: local_expression.SplitPart) -> str: + """ + :param expression: local_expression.SplitPart expression to be parsed + :return: Converted expression (SPLIT_PART) compatible with Databricks + """ + expr_name = self.sql(expression.this) + delimiter = self.sql(expression.expression) + part_num = self.sql(expression.args["partNum"]) + return f"SPLIT_PART({expr_name}, {delimiter}, {part_num})" + + def transaction_sql(self, expression: exp.Transaction) -> str: + """ + Skip begin command + :param expression: + :return: Empty string for unsupported operation + """ + return "" + + def rollback_sql(self, expression: exp.Rollback) -> str: + """ + Skip rollback command + :param expression: + :return: Empty string for unsupported operation + """ + return "" + + def commit_sql(self, expression: exp.Commit) -> str: + """ + Skip commit command + :param expression: + :return: Empty string for unsupported operation + """ + return "" + + def command_sql(self, expression: exp.Command) -> str: + """ + Skip any session, stream, task related commands + :param expression: + :return: Empty string for unsupported operations or objects + """ + filtered_commands = [ + "CREATE", + "ALTER", + "DESCRIBE", + "DROP", + "SHOW", + "EXECUTE", + ] + ignored_objects = [ + "STREAM", + "TASK", + "STREAMS", + "TASKS", + "SESSION", + ] + + command = self.sql(expression, "this").upper() + expr = expression.text("expression").strip() + obj = re.split(r"\s+", expr, maxsplit=2)[0].upper() if expr else "" + if command in filtered_commands and obj in ignored_objects: + return "" + return f"{command} {expr}" + + def currenttimestamp_sql(self, _: exp.CurrentTimestamp) -> str: + return self.func("CURRENT_TIMESTAMP") + + def update_sql(self, expression: exp.Update) -> str: + this = self.sql(expression, "this") + set_sql = self.expressions(expression, flat=True) + from_sql = self.sql(expression, "from") + where_sql = self.sql(expression, "where") + returning = self.sql(expression, "returning") + order = self.sql(expression, "order") + limit = self.sql(expression, "limit") + + if from_sql: + from_sql = from_sql.replace("FROM", "USING", 1) + where_sql = where_sql.replace("WHERE", "ON") + + if self.RETURNING_END: + expression_sql = f"{from_sql}{where_sql}{returning}" + else: + expression_sql = f"{returning}{from_sql}{where_sql}" + + if from_sql: + sql = f"MERGE INTO {this}{expression_sql} WHEN MATCHED THEN UPDATE SET {set_sql}{order}{limit}" + else: + sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}" + + return 
self.prepend_ctes(expression, sql) + + def struct_sql(self, expression: exp.Struct) -> str: + expression.set( + "expressions", + [ + ( + exp.alias_( + e.expression, e.name if hasattr(e.this, "is_string") and e.this.is_string else e.this + ) + if isinstance(e, exp.PropertyEQ) + else e + ) + for e in expression.expressions + ], + ) + + return self.function_fallback_sql(expression) + + def anonymous_sql(self: org_databricks.Databricks.Generator, expression: exp.Anonymous) -> str: + if expression.this == "EDITDISTANCE": + return self.func("LEVENSHTEIN", *expression.expressions) + if expression.this == "TO_TIMESTAMP": + return self.sql( + exp.Cast(this=expression.expressions[0], to=exp.DataType(this=exp.DataType.Type.TIMESTAMP)) + ) + + return self.func(self.sql(expression, "this"), *expression.expressions) + + def order_sql(self, expression: exp.Order, flat: bool = False) -> str: + if isinstance(expression.parent, exp.Window): + for ordered_expression in expression.expressions: + if isinstance(ordered_expression, exp.Ordered) and ordered_expression.args.get('desc') is None: + ordered_expression.args['desc'] = False + return super().order_sql(expression, flat) + + def add_column_sql(self, expression: exp.Alter) -> str: + # Final output contains ADD COLUMN before each column + # This function will handle this issue and return the final output + columns = self.expressions(expression, key="actions", flat=True) + return f"ADD COLUMN {columns}" diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/lca_utils.py b/src/databricks/labs/remorph/transpiler/sqlglot/lca_utils.py new file mode 100644 index 0000000000..9a4981dccf --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/lca_utils.py @@ -0,0 +1,137 @@ +import logging +from collections.abc import Iterable + +from sqlglot import expressions as exp +from sqlglot import parse +from sqlglot.dialects.dialect import DialectType +from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError +from sqlglot.expressions import Expression, Select +from sqlglot.optimizer.scope import Scope, build_scope + +from databricks.labs.remorph.transpiler.transpile_status import ValidationError +from databricks.labs.remorph.transpiler.sqlglot.local_expression import AliasInfo + +logger = logging.getLogger(__name__) + + +def check_for_unsupported_lca( + dialect: DialectType, + sql: str, + filename: str, +) -> ValidationError | None: + """ + Check for presence of unsupported lateral column aliases in window expressions and where clauses + :return: An error if found + """ + + try: + all_parsed_expressions: Iterable[Expression | None] = parse(sql, read=dialect, error_level=ErrorLevel.RAISE) + root_expressions: Iterable[Expression] = [pe for pe in all_parsed_expressions if pe is not None] + except (ParseError, TokenError, UnsupportedError) as e: + logger.warning(f"Error while preprocessing {filename}: {e}") + return None + + aliases_in_where = set() + aliases_in_window = set() + + for expr in root_expressions: + for select in expr.find_all(exp.Select, bfs=False): + alias_info = _find_aliases_in_select(select) + aliases_in_where.update(_find_invalid_lca_in_where(select, alias_info)) + aliases_in_window.update(_find_invalid_lca_in_window(select, alias_info)) + + if not (aliases_in_where or aliases_in_window): + return None + + err_messages = [f"Unsupported operation found in file {filename}. 
Needs manual review of transpiled query."] + if aliases_in_where: + err_messages.append(f"Lateral column aliases `{', '.join(aliases_in_where)}` found in where clause.") + + if aliases_in_window: + err_messages.append(f"Lateral column aliases `{', '.join(aliases_in_window)}` found in window expressions.") + + return ValidationError(filename, " ".join(err_messages)) + + +def unalias_lca_in_select(expr: exp.Expression) -> exp.Expression: + if not isinstance(expr, exp.Select): + return expr + + root_select: Scope | None = build_scope(expr) + if not root_select: + return expr + + # We won't search inside nested selects, they will be visited separately + nested_selects = {*root_select.derived_tables, *root_select.subqueries} + alias_info = _find_aliases_in_select(expr) + where_ast: Expression | None = expr.args.get("where") + if where_ast: + for column in where_ast.walk(prune=lambda n: n in nested_selects): + _replace_aliases(column, alias_info) + for window in _find_windows_in_select(expr): + for column in window.walk(): + _replace_aliases(column, alias_info) + return expr + + +def _replace_aliases(column: Expression, alias_info: dict[str, AliasInfo]): + if ( + isinstance(column, exp.Column) + and column.name in alias_info + and not alias_info[column.name].is_same_name_as_column + ): + unaliased_expr = alias_info[column.name].expression + column.replace(unaliased_expr) + for col in unaliased_expr.walk(): + _replace_aliases(col, alias_info) + + +def _find_windows_in_select(select: Select) -> list[exp.Window]: + window_expressions = [] + for expr in select.expressions: + window_expr = expr.find(exp.Window) + if window_expr: + window_expressions.append(window_expr) + return window_expressions + + +def _find_aliases_in_select(select_expr: Select) -> dict[str, AliasInfo]: + aliases = {} + for expr in select_expr.expressions: + if isinstance(expr, exp.Alias): + alias_name = expr.output_name + is_same_name_as_column = False + for column in expr.find_all(exp.Column): + if column.name == alias_name: + is_same_name_as_column = True + break + aliases[alias_name] = AliasInfo(alias_name, expr.unalias().copy(), is_same_name_as_column) + return aliases + + +def _find_invalid_lca_in_where( + select_expr: Select, + aliases: dict[str, AliasInfo], +) -> set[str]: + aliases_in_where = set() + where_ast: Expression | None = select_expr.args.get("where") + if where_ast: + for column in where_ast.find_all(exp.Column): + if column.name in aliases and not aliases[column.name].is_same_name_as_column: + aliases_in_where.add(column.name) + + return aliases_in_where + + +def _find_invalid_lca_in_window( + select_expr: Select, + aliases: dict[str, AliasInfo], +) -> set[str]: + aliases_in_window = set() + windows = _find_windows_in_select(select_expr) + for window in windows: + for column in window.find_all(exp.Column): + if column.name in aliases and not aliases[column.name].is_same_name_as_column: + aliases_in_window.add(column.name) + + return aliases_in_window diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/local_expression.py b/src/databricks/labs/remorph/transpiler/sqlglot/local_expression.py new file mode 100644 index 0000000000..e61e97cbdd --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/local_expression.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass + +from sqlglot import expressions as exp +from sqlglot.expressions import AggFunc, Condition, Expression, Func + + +class NthValue(AggFunc): + arg_types = {"this": True, "offset": False} + + +class Parameter(Expression): + 
arg_types = {"this": True, "wrapped": False, "suffix": False} + + +class Collate(Func): + arg_types = {"this": True, "expressions": True} + + +class Bracket(Condition): + arg_types = {"this": True, "expressions": True} + + +class Split(Func): + """ + Redefined Split(sqlglot/expression) class with expression: False to handle default delimiter + Please refer the test case `test_strtok_to_array` -> `select STRTOK_TO_ARRAY('my text is divided')` + """ + + arg_types = {"this": True, "expression": False, "limit": False} + + +class MakeDate(Func): + arg_types = {"this": True, "expression": False, "zone": False} + + +class ConvertTimeZone(Func): + arg_types = {"srcTZ": True, "tgtTZ": True, "this": False} + + +class TryToDate(Func): + arg_types = {"this": True, "format": False} + + +class TryToTimestamp(Func): + arg_types = {"this": True, "format": False} + + +class SplitPart(Func): + arg_types = {"this": True, "expression": False, "partNum": False} + + +class StrTok(Func): + arg_types = {"this": True, "expression": False, "partNum": False} + + +class TryToNumber(Func): + arg_types = {"this": True, "expression": False, "precision": False, "scale": False} + + _sql_names = ["TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"] + + +class DateFormat(Func): + arg_types = {"this": True, "expression": False} + + +class IsInteger(Func): + pass + + +class JsonExtractPathText(Func): + arg_types = {"this": True, "path_name": True} + + +class BitOr(AggFunc): + pass + + +class ArrayConstructCompact(Func): + arg_types = {"expressions": False} + + is_var_len_args = True + + +class ArrayIntersection(Func): + arg_types = {"this": True, "expression": True} + + +class ArraySlice(Func): + arg_types = {"this": True, "from": True, "to": True} + + +class ObjectKeys(Func): + arg_types = {"this": True} + + +class ToBoolean(Func): + arg_types = {"this": True, "raise_error": False} + + +class ToDouble(Func): + pass + + +class ToObject(Func): + pass + + +class ToNumber(Func): + arg_types = {"this": True, "expression": False, "precision": False, "scale": False} + + _sql_names = ["TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"] + + +class TimestampFromParts(Func): + arg_types = { + "this": True, + "expression": True, + "day": True, + "hour": True, + "min": True, + "sec": True, + "nanosec": False, + "Zone": False, + } + + +class ToVariant(Func): + pass + + +class UUID(Func): + arg_types = {"this": False, "name": False} + + +class DateTrunc(Func): + arg_types = {"unit": False, "this": True, "zone": False} + + +class Median(Func): + arg_types = {"this": True} + + +class CumeDist(Func): + arg_types = {"this": False} + + +class DenseRank(Func): + arg_types = {"this": False} + + +class Rank(Func): + arg_types = {"this": False} + + +class PercentRank(Func): + arg_types = {"this": False} + + +class Ntile(Func): + arg_types = {"this": True, "is_string": False} + + +class ToArray(Func): + arg_types = {"this": True, "expression": False} + + +@dataclass +class WithinGroupParams: + agg_col: exp.Column + order_cols: list[tuple[exp.Column, bool]] # List of (column, is ascending) + + +@dataclass +class AliasInfo: + name: str + expression: exp.Expression + is_same_name_as_column: bool + + +class MapKeys(Func): + arg_types = {"this": True} + + +class ArrayExists(Func): + arg_types = {"this": True, "expression": True} + + +class Locate(Func): + arg_types = {"substring": True, "this": True, "position": False} + + +class NamedStruct(Func): + arg_types = {"expressions": True} + + +class GetJsonObject(Func): + arg_types = {"this": True, "path": True} diff 
--git a/src/databricks/labs/remorph/transpiler/sqlglot/parsers/__init__.py b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/parsers/bigquery.py b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/bigquery.py new file mode 100644 index 0000000000..97db07c5e7 --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/bigquery.py @@ -0,0 +1,94 @@ +import logging +from sqlglot import exp +from sqlglot.helper import seq_get +from sqlglot.dialects.bigquery import BigQuery as bigquery + +from databricks.labs.remorph.transpiler.sqlglot import local_expression + +logger = logging.getLogger(__name__) + +bigquery_to_databricks = { + "%A": "EEEE", # Full weekday name + "%a": "EEE", # Abbreviated weekday name + "%B": "MMMM", # Full month name + "%b": "MMM", # Abbreviated month name + "%C": "yy", # Century + "%c": "EEE MMM dd HH:mm:ss yyyy", # Date and time representation + "%D": "MM/dd/yy", # Date in mm/dd/yy + "%d": "dd", # Day of the month (2 digits) + "%e": "d", # Day of month (single digit without leading zero) + "%F": "yyyy-MM-dd", # ISO 8601 date + "%H": "HH", # 24-hour clock hour + "%h": "MMM", # Abbreviated month name (duplicate of %b in BigQuery) + "%I": "hh", # 12-hour clock hour + "%j": "DDD", # Day of year + "%k": "H", # 24-hour clock without leading zero + "%l": "h", # 12-hour clock without leading zero + "%M": "mm", # Minute + "%m": "MM", # Month (2 digits) + "%P": "a", # am/pm in lowercase + "%p": "a", # AM/PM in uppercase + "%Q": "q", # Quarter + "%R": "HH:mm", # Time in HH:mm + "%S": "ss", # Second + "%s": "epoch", # Seconds since epoch (special handling required) + "%T": "HH:mm:ss", # Time in HH:mm:ss + "%U": "ww", # Week number of year (Sunday start) + "%u": "e", # ISO weekday (Monday start) + "%V": "ww", # ISO week number of year + "%W": "ww", # Week number (Monday start) + "%w": "e", # Weekday (Sunday start) + "%X": "HH:mm:ss", # Time representation + "%x": "MM/dd/yy", # Date representation + "%Y": "yyyy", # Year with century + "%y": "yy", # Year without century + "%Z": "z", # Time zone name + "%z": "xxx", # Time zone offset + "%%": "%", # Literal percent + "%Ez": "xxx", # RFC 3339 numeric time zone + "%E*S": "ss.SSSSSS", # Full fractional seconds + "%E4Y": "yyyy", # Four-character years +} + + +def _parse_format_date(args: list): + format_element = str(args[0].this) + if format_element == "%s": + return exp.StrToUnix(this=seq_get(args, 1)) + if format_element == "%V": + return exp.Extract(this=exp.Var(this="W"), expression=seq_get(args, 1)) + if format_element == "%u": + return exp.Extract(this=exp.Var(this="DAYOFWEEK_ISO"), expression=seq_get(args, 1)) + if format_element == "%w": + return exp.Sub( + this=exp.Extract(this=exp.Var(this="DAYOFWEEK"), expression=seq_get(args, 1)), + expression=exp.Literal(this='1', is_string=False), + ) + if format_element == "%C": + return exp.Round( + this=exp.Div( + this=exp.Extract(this=exp.Var(this="YEAR"), expression=seq_get(args, 1)), + expression=exp.Literal(this='100', is_string=False), + ) + ) + databricks_datetime_pattern = ( + bigquery_to_databricks.get(format_element) if format_element in bigquery_to_databricks else format_element + ) + return local_expression.DateFormat( + this=seq_get(args, 1), expression=exp.Literal(this=databricks_datetime_pattern, is_string=True) + ) + + +class BigQuery(bigquery): + class Parser(bigquery.Parser): + VALUES_FOLLOWED_BY_PAREN = False + + FUNCTIONS = { + 
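+            # Only FORMAT_DATE is overridden on top of the inherited functions: _parse_format_date
+            # rewrites a few elements directly (%s -> StrToUnix, %V/%u/%w/%C -> Extract-based
+            # expressions) and maps the remaining single format elements through
+            # bigquery_to_databricks, e.g. FORMAT_DATE('%Y', d) becomes a DateFormat node
+            # with the Databricks pattern 'yyyy'.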
**bigquery.Parser.FUNCTIONS, + "FORMAT_DATE": _parse_format_date, + } + + class Tokenizer(bigquery.Tokenizer): + KEYWORDS = { + **bigquery.Tokenizer.KEYWORDS, + } diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/parsers/oracle.py b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/oracle.py new file mode 100644 index 0000000000..8922f17ca8 --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/oracle.py @@ -0,0 +1,23 @@ +from sqlglot.dialects.oracle import Oracle as Orc +from sqlglot.tokens import TokenType + + +class Oracle(Orc): + # Instantiate Oracle Dialect + oracle = Orc() + + class Tokenizer(Orc.Tokenizer): + KEYWORDS = { + **Orc.Tokenizer.KEYWORDS, + 'LONG': TokenType.TEXT, + 'NCLOB': TokenType.TEXT, + 'ROWID': TokenType.TEXT, + 'UROWID': TokenType.TEXT, + 'ANYTYPE': TokenType.TEXT, + 'ANYDATA': TokenType.TEXT, + 'ANYDATASET': TokenType.TEXT, + 'XMLTYPE': TokenType.TEXT, + 'SDO_GEOMETRY': TokenType.TEXT, + 'SDO_TOPO_GEOMETRY': TokenType.TEXT, + 'SDO_GEORASTER': TokenType.TEXT, + } diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/parsers/presto.py b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/presto.py new file mode 100644 index 0000000000..a0f1e19801 --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/presto.py @@ -0,0 +1,202 @@ +import logging +from sqlglot.dialects.presto import Presto as presto +from sqlglot import exp +from sqlglot.helper import seq_get +from sqlglot.errors import ParseError +from sqlglot.tokens import TokenType + +from databricks.labs.remorph.transpiler.sqlglot import local_expression + + +logger = logging.getLogger(__name__) + + +def _build_approx_percentile(args: list) -> exp.Expression: + if len(args) == 4: + arg3 = seq_get(args, 3) + try: + number = float(arg3.this) if arg3 is not None else 0 + return exp.ApproxQuantile( + this=seq_get(args, 0), + weight=seq_get(args, 1), + quantile=seq_get(args, 2), + accuracy=exp.Literal(this=f'{int((1/number) * 100)} ', is_string=False), + ) + except ValueError as exc: + raise ParseError(f"Expected a string representation of a number for argument 2, but got {arg3}") from exc + if len(args) == 3: + arg2 = seq_get(args, 2) + try: + number = float(arg2.this) if arg2 is not None else 0 + return exp.ApproxQuantile( + this=seq_get(args, 0), + quantile=seq_get(args, 1), + accuracy=exp.Literal(this=f'{int((1/number) * 100)}', is_string=False), + ) + except ValueError as exc: + raise ParseError(f"Expected a string representation of a number for argument 2, but got {arg2}") from exc + return exp.ApproxQuantile.from_arg_list(args) + + +def _build_any_keys_match(args: list) -> local_expression.ArrayExists: + return local_expression.ArrayExists( + this=local_expression.MapKeys(this=seq_get(args, 0)), expression=seq_get(args, 1) + ) + + +def _build_str_position(args: list) -> local_expression.Locate: + # TODO the 3rd param in presto strpos and databricks locate has different implementation. + # For now we haven't implemented the logic same as presto for 3rd param. + # Users should be vigilant when using 3 param function in presto strpos. + if len(args) == 3: + msg = ( + "*Warning:: The third parameter in Presto's `strpos` function and Databricks' `locate` function " + "have different implementations. Please exercise caution when using the three-parameter version " + "of the `strpos` function in Presto." 
+ ) + logger.warning(msg) + return local_expression.Locate(substring=seq_get(args, 1), this=seq_get(args, 0), position=seq_get(args, 2)) + return local_expression.Locate(substring=seq_get(args, 1), this=seq_get(args, 0)) + + +def _build_array_average(args: list) -> exp.Reduce: + return exp.Reduce( + this=exp.ArrayFilter( + this=seq_get(args, 0), + expression=exp.Lambda( + this=exp.Not(this=exp.Is(this=exp.Identifier(this="x", quoted=False), expression=exp.Null())), + expressions=[exp.Identifier(this="x", quoted=False)], + ), + ), + initial=local_expression.NamedStruct( + expressions=[ + exp.Literal(this="sum", is_string=True), + exp.Cast(this=exp.Literal(this="0", is_string=False), to=exp.DataType(this="DOUBLE")), + exp.Literal(this="cnt", is_string=True), + exp.Literal(this="0", is_string=False), + ], + ), + merge=exp.Lambda( + this=local_expression.NamedStruct( + expressions=[ + exp.Literal(this="sum", is_string=True), + exp.Add( + this=exp.Dot( + this=exp.Identifier(this="acc", quoted=False), + expression=exp.Identifier(this="sum", quoted=False), + ), + expression=exp.Identifier(this="x", quoted=False), + ), + exp.Literal(this="cnt", is_string=True), + exp.Add( + this=exp.Dot( + this=exp.Identifier(this="acc", quoted=False), + expression=exp.Identifier(this="cnt", quoted=False), + ), + expression=exp.Literal(this="1", is_string=False), + ), + ], + ), + expressions=[exp.Identifier(this="acc", quoted=False), exp.Identifier(this="x", quoted=False)], + ), + finish=exp.Lambda( + this=exp.Anonymous( + this="try_divide", + expressions=[ + exp.Dot( + this=exp.Identifier(this="acc", quoted=False), + expression=exp.Identifier(this="sum", quoted=False), + ), + exp.Dot( + this=exp.Identifier(this="acc", quoted=False), + expression=exp.Identifier(this="cnt", quoted=False), + ), + ], + ), + expressions=[exp.Identifier(this="acc", quoted=False)], + ), + ) + + +def _build_json_size(args: list): + return exp.Case( + ifs=[ + exp.If( + this=exp.Like( + this=local_expression.GetJsonObject( + this=exp.Column(this=seq_get(args, 0)), + path=exp.Column(this=seq_get(args, 1)), + ), + expression=exp.Literal(this="{%", is_string=True), + ), + true=exp.ArraySize( + this=exp.Anonymous( + this="from_json", + expressions=[ + local_expression.GetJsonObject( + this=exp.Column(this=seq_get(args, 0)), + path=exp.Column(this=seq_get(args, 1)), + ), + exp.Literal(this="map", is_string=True), + ], + ) + ), + ), + exp.If( + this=exp.Like( + this=local_expression.GetJsonObject( + this=exp.Column(this=seq_get(args, 0)), + path=exp.Column(this=seq_get(args, 1)), + ), + expression=exp.Literal(this="[%", is_string=True), + ), + true=exp.ArraySize( + this=exp.Anonymous( + this="from_json", + expressions=[ + local_expression.GetJsonObject( + this=exp.Column(this=seq_get(args, 0)), + path=exp.Column(this=seq_get(args, 1)), + ), + exp.Literal(this="array", is_string=True), + ], + ) + ), + ), + exp.If( + this=exp.Not( + this=exp.Is( + this=local_expression.GetJsonObject( + this=exp.Column(this=seq_get(args, 0)), + path=exp.Column(this=seq_get(args, 1)), + ), + expression=exp.Null(), + ) + ), + true=exp.Literal(this="0", is_string=False), + ), + ], + default=exp.Null(), + ) + + +class Presto(presto): + + class Parser(presto.Parser): + VALUES_FOLLOWED_BY_PAREN = False + + FUNCTIONS = { + **presto.Parser.FUNCTIONS, + "APPROX_PERCENTILE": _build_approx_percentile, + "STRPOS": _build_str_position, + "ANY_KEYS_MATCH": _build_any_keys_match, + "ARRAY_AVERAGE": _build_array_average, + "JSON_SIZE": _build_json_size, + "FORMAT_DATETIME": 
local_expression.DateFormat.from_arg_list, + } + + class Tokenizer(presto.Tokenizer): + KEYWORDS = { + **presto.Tokenizer.KEYWORDS, + "JSON": TokenType.TEXT, + } diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/parsers/snowflake.py b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/snowflake.py new file mode 100644 index 0000000000..8a1ddce5d6 --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/parsers/snowflake.py @@ -0,0 +1,506 @@ +import logging +import re + +from sqlglot import expressions as exp +from sqlglot.dialects.dialect import build_date_delta as parse_date_delta, build_formatted_time +from sqlglot.dialects.snowflake import Snowflake as SqlglotSnowflake +from sqlglot.errors import ParseError +from sqlglot.helper import is_int, seq_get +from sqlglot.optimizer.simplify import simplify_literals +from sqlglot.parser import build_var_map as parse_var_map +from sqlglot.tokens import Token, TokenType +from sqlglot.trie import new_trie + +from databricks.labs.remorph.transpiler.sqlglot import local_expression + +logger = logging.getLogger(__name__) +# pylint: disable=protected-access +""" SF Supported Date and Time Parts: + https://docs.snowflake.com/en/sql-reference/functions-date-time#label-supported-date-time-parts + Covers DATEADD, DATEDIFF, DATE_TRUNC, LAST_DAY +""" +DATE_DELTA_INTERVAL = { + "years": "year", + "year": "year", + "yrs": "year", + "yr": "year", + "yyyy": "year", + "yyy": "year", + "yy": "year", + "y": "year", + "quarters": "quarter", + "quarter": "quarter", + "qtrs": "quarter", + "qtr": "quarter", + "q": "quarter", + "months": "month", + "month": "month", + "mons": "month", + "mon": "month", + "mm": "month", + "weekofyear": "week", + "week": "week", + "woy": "week", + "wy": "week", + "wk": "week", + "w": "week", + "dayofmonth": "day", + "days": "day", + "day": "day", + "dd": "day", + "d": "day", +} + +rank_functions = ( + local_expression.CumeDist, + exp.FirstValue, + exp.LastValue, + local_expression.NthValue, + local_expression.Ntile, +) + + +def _parse_to_timestamp(args: list) -> exp.StrToTime | exp.UnixToTime | exp.TimeStrToTime: + if len(args) == 2: + first_arg, second_arg = args + if second_arg.is_string: + # case: [ , ] + return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args) + return exp.UnixToTime(this=first_arg, scale=second_arg) + + # The first argument might be an expression like 40 * 365 * 86400, so we try to + # reduce it using `simplify_literals` first and then check if it's a Literal. + first_arg = seq_get(args, 0) + if not isinstance(simplify_literals(first_arg, root=True), exp.Literal): + # case: or other expressions such as columns + return exp.TimeStrToTime.from_arg_list(args) + + if first_arg.is_string: + if is_int(first_arg.this): + # case: + return exp.UnixToTime.from_arg_list(args) + + # case: + return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args) + + # case: + return exp.UnixToTime.from_arg_list(args) + + +def _parse_date_add(args: list) -> exp.DateAdd: + return exp.DateAdd(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) + + +def _parse_split_part(args: list) -> local_expression.SplitPart: + if len(args) != 3: + err_msg = f"Error Parsing args `{args}`. Number of args must be 3, given {len(args)}" + raise ParseError(err_msg) + part_num_literal = seq_get(args, 2) + part_num_if = None + if isinstance(part_num_literal, exp.Literal): + # In Snowflake if the partNumber is 0, it is treated as 1. 
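+        # A literal 0 is rewritten to 1 up front; any other literal part number is wrapped in
+        # IF(part = 0, 1, part) so the same semantics are preserved in the generated SQL.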
+ # Please refer to https://docs.snowflake.com/en/sql-reference/functions/split_part + if part_num_literal.is_int and int(part_num_literal.name) == 0: + part_num_literal = exp.Literal.number(1) + else: + cond = exp.EQ(this=part_num_literal, expression=exp.Literal.number(0)) + part_num_if = exp.If(this=cond, true=exp.Literal.number(1), false=part_num_literal) + + part_num = part_num_if if part_num_if is not None else part_num_literal + return local_expression.SplitPart(this=seq_get(args, 0), expression=seq_get(args, 1), partNum=part_num) + + +def _div0_to_if(args: list) -> exp.If: + cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) + true = exp.Literal.number(0) + false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) + return exp.If(this=cond, true=true, false=false) + + +def _div0null_to_if(args: list) -> exp.If: + cond = exp.Or( + this=exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)), + expression=exp.Is(this=seq_get(args, 1), expression=exp.Null()), + ) + true = exp.Literal.number(0) + false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) + return exp.If(this=cond, true=true, false=false) + + +def _parse_json_extract_path_text(args: list) -> local_expression.JsonExtractPathText: + if len(args) != 2: + err_message = f"Error Parsing args `{args}`. Number of args must be 2, given {len(args)}" + raise ParseError(err_message) + return local_expression.JsonExtractPathText(this=seq_get(args, 0), path_name=seq_get(args, 1)) + + +def _parse_array_contains(args: list) -> exp.ArrayContains: + if len(args) != 2: + err_message = f"Error Parsing args `{args}`. Number of args must be 2, given {len(args)}" + raise ParseError(err_message) + return exp.ArrayContains(this=seq_get(args, 1), expression=seq_get(args, 0)) + + +def _parse_dayname(args: list) -> local_expression.DateFormat: + """ + * E, EE, EEE, returns short day name (Mon) + * EEEE, returns full day name (Monday) + :param args: node expression + :return: DateFormat with `E` format + """ + if len(args) != 1: + err_message = f"Error Parsing args `{args}`. Number of args must be 1, given {len(args)}" + raise ParseError(err_message) + return local_expression.DateFormat(this=seq_get(args, 0), expression=exp.Literal.string("E")) + + +def _parse_trytonumber(args: list) -> local_expression.TryToNumber: + if len(args) == 1: + msg = f"""*Warning:: Parsing args `{args}`: + * `format` is missing + * assuming defaults `precision`[38] and `scale`[0] + """ + logger.warning(msg) + elif len(args) == 3: + msg = f"""Error Parsing args `{args}`: + * `format` is required + * `precision` and `scale` both are required [if specified] + """ + raise ParseError(msg) + + if len(args) == 4: + return local_expression.TryToNumber( + this=seq_get(args, 0), expression=seq_get(args, 1), precision=seq_get(args, 2), scale=seq_get(args, 3) + ) + + return local_expression.TryToNumber(this=seq_get(args, 0), expression=seq_get(args, 1)) + + +def _parse_monthname(args: list) -> local_expression.DateFormat: + if len(args) != 1: + err_message = f"Error Parsing args `{args}`. 
Number of args must be 1, given {len(args)}" + raise ParseError(err_message) + return local_expression.DateFormat(this=seq_get(args, 0), expression=exp.Literal.string("MMM")) + + +def _parse_object_construct(args: list) -> exp.StarMap | exp.Struct: + expression = parse_var_map(args) + + if isinstance(expression, exp.StarMap): + return exp.Struct(expressions=[expression.this]) + + return exp.Struct( + expressions=[ + exp.PropertyEQ(this=k.this, expression=v) for k, v in zip(expression.keys, expression.values, strict=False) + ] + ) + + +def _parse_to_boolean(args: list, *, error=False) -> local_expression.ToBoolean: + this_arg = seq_get(args, 0) + return local_expression.ToBoolean(this=this_arg, raise_error=exp.Literal.number(1 if error else 0)) + + +def _parse_tonumber(args: list) -> local_expression.ToNumber: + if len(args) > 4: + error_msg = f"""Error Parsing args args: + * Number of args cannot be more than `4`, given `{len(args)}` + """ + raise ParseError(error_msg) + + match len(args): + case 1: + msg = ( + "Precision and Scale are not specified, assuming defaults `precision`[38] and `scale`[0]. " + "If Format is not specified, it will be inferred as simple cast as decimal" + ) + logger.warning(msg) + return local_expression.ToNumber(this=seq_get(args, 0)) + case 3: + msg = "If Format is not specified, it will be inferred as simple cast as decimal" + logger.warning(msg) + return local_expression.ToNumber(this=seq_get(args, 0), precision=seq_get(args, 1), scale=seq_get(args, 2)) + case 4: + return local_expression.ToNumber( + this=seq_get(args, 0), expression=seq_get(args, 1), precision=seq_get(args, 2), scale=seq_get(args, 3) + ) + + return local_expression.ToNumber(this=seq_get(args, 0), expression=seq_get(args, 1)) + + +def contains_expression(expr, target_type): + if isinstance(expr, target_type): + return True + if hasattr(expr, 'this') and contains_expression(expr.this, target_type): + return True + if hasattr(expr, 'expressions'): + for sub_expr in expr.expressions: + if contains_expression(sub_expr, target_type): + return True + return False + + +def _parse_sha2(args: list) -> exp.SHA2: + if len(args) == 1: + return exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)) + return exp.SHA2(this=seq_get(args, 0), length=seq_get(args, 1)) + + +class Snowflake(SqlglotSnowflake): + # Instantiate Snowflake Dialect + snowflake = SqlglotSnowflake() + + class Tokenizer(SqlglotSnowflake.Tokenizer): + + COMMENTS = ["--", "//", ("/*", "*/")] + STRING_ESCAPES = ["\\", "'"] + + CUSTOM_TOKEN_MAP = { + r"(?i)CREATE\s+OR\s+REPLACE\s+PROCEDURE": TokenType.PROCEDURE, + r"(?i)var\s+\w+\s+=\s+\w+?": TokenType.VAR, + } + + SINGLE_TOKENS = { + **SqlglotSnowflake.Tokenizer.SINGLE_TOKENS, + "&": TokenType.PARAMETER, # https://docs.snowflake.com/en/user-guide/snowsql-use#substituting-variables-in-a-session + "!": TokenType.COMMAND, + } + + KEYWORDS = {**SqlglotSnowflake.Tokenizer.KEYWORDS} + # DEC is not a reserved keyword in Snowflake it can be used as table alias + KEYWORDS.pop("DEC") + + @classmethod + def update_keywords(cls, new_key_word_dict): + cls.KEYWORDS = new_key_word_dict | cls.KEYWORDS + + @classmethod + def merge_trie(cls, parent_trie, curr_trie): + merged_trie = {} + logger.debug(f"The Parent Trie is {parent_trie}") + logger.debug(f"The Input Trie is {curr_trie}") + for key in set(parent_trie.keys()) | set(curr_trie.keys()): # Get all unique keys from both tries + if key in parent_trie and key in curr_trie: # If the key is in both tries, merge the subtries + if 
isinstance(parent_trie[key], dict) and isinstance(curr_trie[key], dict): + logger.debug(f"New trie inside the key is {curr_trie}") + logger.debug(f"Parent trie inside the key is {parent_trie}") + merged_trie[key] = cls.merge_trie(parent_trie[key], curr_trie[key]) + logger.debug(f"Merged Trie is {merged_trie}") + elif isinstance(parent_trie[key], dict): + merged_trie[key] = parent_trie[key] + else: + merged_trie[key] = curr_trie[key] + elif key in parent_trie: # If the key is only in trie1, add it to the merged trie + merged_trie[key] = parent_trie[key] + else: # If the key is only in trie2, add it to the merged trie + merged_trie[key] = curr_trie[key] + return merged_trie + + @classmethod + def update_keyword_trie( + cls, + curr_trie, + parent_trie=None, + ): + if parent_trie is None: + parent_trie = cls._KEYWORD_TRIE + cls.KEYWORD_TRIE = cls.merge_trie(parent_trie, curr_trie) + + def match_strings_token_dict(self, string, pattern_dict): + result_dict = {} + for pattern in pattern_dict: + matches = re.finditer(pattern, string, re.MULTILINE | re.IGNORECASE | re.DOTALL) + for _, match in enumerate(matches, start=1): + result_dict[match.group().upper()] = pattern_dict[pattern] + return result_dict + + def match_strings_list(self, string, pattern_dict): + result = [] + for pattern in pattern_dict: + matches = re.finditer(pattern, string, re.MULTILINE | re.IGNORECASE | re.DOTALL) + for _, match in enumerate(matches, start=1): + result.append(match.group().upper()) + return result + + def tokenize(self, sql: str) -> list[Token]: + """Returns a list of tokens corresponding to the SQL string `sql`.""" + self.reset() + self.sql = sql + # Update Keywords + ref_dict = self.match_strings_token_dict(sql, self.CUSTOM_TOKEN_MAP) + self.update_keywords(ref_dict) + # Update Keyword Trie + custom_trie = new_trie(self.match_strings_list(sql, self.CUSTOM_TOKEN_MAP)) + logger.debug( + f"The New Trie after adding the REF, VAR and IF ELSE blocks " + f"based on {self.CUSTOM_TOKEN_MAP}, is \n\n {custom_trie}" + ) + self.update_keyword_trie(custom_trie) + logger.debug(f"Updated New Trie is {custom_trie}") + # Parent Code + self.size = len(sql) + try: + self._scan() + except Exception as e: + start = self._current - 50 + end = self._current + 50 + start = start if start > 0 else 0 + end = end if end < self.size else self.size - 1 + context = self.sql[start:end] + msg = f"Error tokenizing '{context}'" + raise ParseError(msg) from e + return self.tokens + + class Parser(SqlglotSnowflake.Parser): + FUNCTIONS = { + **SqlglotSnowflake.Parser.FUNCTIONS, + "ARRAY_AGG": exp.ArrayAgg.from_arg_list, + "STRTOK_TO_ARRAY": local_expression.Split.from_arg_list, + "DATE_FROM_PARTS": local_expression.MakeDate.from_arg_list, + "CONVERT_TIMEZONE": local_expression.ConvertTimeZone.from_arg_list, + "TRY_TO_DATE": local_expression.TryToDate.from_arg_list, + "TRY_TO_TIMESTAMP": local_expression.TryToTimestamp.from_arg_list, + "STRTOK": local_expression.StrTok.from_arg_list, + "SPLIT_PART": _parse_split_part, + "TIMESTAMPADD": _parse_date_add, + "TRY_TO_DECIMAL": _parse_trytonumber, + "TRY_TO_NUMBER": _parse_trytonumber, + "TRY_TO_NUMERIC": _parse_trytonumber, + "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "IS_INTEGER": local_expression.IsInteger.from_arg_list, + "DIV0": _div0_to_if, + "DIV0NULL": _div0null_to_if, + "JSON_EXTRACT_PATH_TEXT": _parse_json_extract_path_text, + "BITOR_AGG": local_expression.BitOr.from_arg_list, + 
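+            # Note: Snowflake's ARRAY_CONTAINS takes (value, array) while Databricks expects
+            # (array, value), so _parse_array_contains swaps the arguments; e.g.
+            # ARRAY_CONTAINS(x, arr) is parsed with arr as the first argument.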
"ARRAY_CONTAINS": _parse_array_contains, + "DAYNAME": _parse_dayname, + "BASE64_ENCODE": exp.ToBase64.from_arg_list, + "BASE64_DECODE_STRING": exp.FromBase64.from_arg_list, + "TRY_BASE64_DECODE_STRING": exp.FromBase64.from_arg_list, + "ARRAY_CONSTRUCT_COMPACT": local_expression.ArrayConstructCompact.from_arg_list, + "ARRAY_INTERSECTION": local_expression.ArrayIntersection.from_arg_list, + "ARRAY_SLICE": local_expression.ArraySlice.from_arg_list, + "MONTHNAME": _parse_monthname, + "MONTH_NAME": _parse_monthname, + "OBJECT_CONSTRUCT": _parse_object_construct, + "OBJECT_KEYS": local_expression.ObjectKeys.from_arg_list, + "TRY_PARSE_JSON": exp.ParseJSON.from_arg_list, + "TIMEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "TIMESTAMPDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "TIMEADD": _parse_date_add, + "TO_BOOLEAN": lambda args: _parse_to_boolean(args, error=True), + "TO_DECIMAL": _parse_tonumber, + "TO_DOUBLE": local_expression.ToDouble.from_arg_list, + "TO_NUMBER": _parse_tonumber, + "TO_NUMERIC": _parse_tonumber, + "TO_OBJECT": local_expression.ToObject.from_arg_list, + "TO_TIME": _parse_to_timestamp, + "TIMESTAMP_FROM_PARTS": local_expression.TimestampFromParts.from_arg_list, + "TO_VARIANT": local_expression.ToVariant.from_arg_list, + "TRY_TO_BOOLEAN": lambda args: _parse_to_boolean(args, error=False), + "UUID_STRING": local_expression.UUID.from_arg_list, + "SYSDATE": exp.CurrentTimestamp.from_arg_list, + "TRUNC": lambda args: local_expression.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "NTH_VALUE": local_expression.NthValue.from_arg_list, + "MEDIAN": local_expression.Median.from_arg_list, + "CUME_DIST": local_expression.CumeDist.from_arg_list, + "DENSE_RANK": local_expression.DenseRank.from_arg_list, + "RANK": local_expression.Rank.from_arg_list, + "PERCENT_RANK": local_expression.PercentRank.from_arg_list, + "NTILE": local_expression.Ntile.from_arg_list, + "TO_ARRAY": local_expression.ToArray.from_arg_list, + "SHA2": _parse_sha2, + } + + FUNCTION_PARSERS = { + **SqlglotSnowflake.Parser.FUNCTION_PARSERS, + "LISTAGG": lambda self: self._parse_list_agg(), + } + + PLACEHOLDER_PARSERS = { + **SqlglotSnowflake.Parser.PLACEHOLDER_PARSERS, + TokenType.PARAMETER: lambda self: self._parse_parameter(), + } + + FUNC_TOKENS = {*SqlglotSnowflake.Parser.FUNC_TOKENS, TokenType.COLLATE} + + COLUMN_OPERATORS = { + **SqlglotSnowflake.Parser.COLUMN_OPERATORS, + } + + TIMESTAMPS: set[TokenType] = SqlglotSnowflake.Parser.TIMESTAMPS.copy() - {TokenType.TIME} + + RANGE_PARSERS = { + **SqlglotSnowflake.Parser.RANGE_PARSERS, + } + + ALTER_PARSERS = {**SqlglotSnowflake.Parser.ALTER_PARSERS} + + def _parse_list_agg(self) -> exp.GroupConcat: + if self._match(TokenType.DISTINCT): + args: list[exp.Expression] = [self.expression(exp.Distinct, expressions=[self._parse_conjunction()])] + if self._match(TokenType.COMMA): + args.extend(self._parse_csv(self._parse_conjunction)) + else: + args = self._parse_csv(self._parse_conjunction) + + return self.expression(exp.GroupConcat, this=args[0], separator=seq_get(args, 1)) + + def _parse_types( + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True + ) -> exp.Expression | None: + this = super()._parse_types(check_func=check_func, schema=schema, allow_identifiers=allow_identifiers) + # https://docs.snowflake.com/en/sql-reference/data-types-numeric Numeric datatype alias + if ( + isinstance(this, exp.DataType) + and 
this.is_type("numeric", "decimal", "number", "integer", "int", "smallint", "bigint") + and not this.expressions + ): + return exp.DataType.build("DECIMAL(38,0)") + return this + + def _parse_parameter(self): + wrapped = self._match(TokenType.L_BRACE) + this = self._parse_var() or self._parse_identifier() or self._parse_primary() + self._match(TokenType.R_BRACE) + suffix: exp.Expression | None = None + if not self._match(TokenType.SPACE) or self._match(TokenType.DOT): + suffix = self._parse_var() or self._parse_identifier() or self._parse_primary() + + return self.expression(local_expression.Parameter, this=this, wrapped=wrapped, suffix=suffix) + + def _parse_window(self, this: exp.Expression | None, alias: bool = False) -> exp.Expression | None: + window = super()._parse_window(this=this, alias=alias) + # Adding default window frame for the rank-related functions in snowflake + if window and contains_expression(window.this, rank_functions) and window.args.get('spec') is None: + window.args['spec'] = self.expression( + exp.WindowSpec, + kind="ROWS", + start="UNBOUNDED", + start_side="PRECEDING", + end="UNBOUNDED", + end_side="FOLLOWING", + ) + return window + + def _parse_alter_table_add(self) -> list[exp.Expression]: + index = self._index - 1 + if self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False): + return self._parse_csv( + lambda: self.expression(exp.AddConstraint, expressions=self._parse_csv(self._parse_constraint)) + ) + + self._retreat(index) + if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"): + return self._parse_wrapped_csv(self._parse_field_def, optional=True) + + if self._match_text_seq("ADD", "COLUMN"): + schema = self._parse_schema() + if schema: + return [schema] + # return self._parse_csv in case of COLUMNS are not enclosed in brackets () + return self._parse_csv(self._parse_field_def) + + return self._parse_wrapped_csv(self._parse_add_column, optional=True) diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py b/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py new file mode 100644 index 0000000000..99e886bbec --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py @@ -0,0 +1,52 @@ +from sqlglot import expressions as exp, parse, transpile +from sqlglot.dialects.dialect import Dialect +from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError +from sqlglot.expressions import Expression + +from databricks.labs.remorph.config import TranspilationResult +from databricks.labs.remorph.helpers.file_utils import refactor_hexadecimal_chars +from databricks.labs.remorph.transpiler.transpile_status import ParserError + + +class SqlglotEngine: + def __init__(self, read_dialect: Dialect): + self.read_dialect = read_dialect + + def transpile( + self, write_dialect: Dialect, sql: str, file_name: str, error_list: list[ParserError] + ) -> TranspilationResult: + try: + transpiled_sql = transpile(sql, read=self.read_dialect, write=write_dialect, pretty=True, error_level=None) + except (ParseError, TokenError, UnsupportedError) as e: + transpiled_sql = [""] + error_list.append(ParserError(file_name, refactor_hexadecimal_chars(str(e)))) + + return TranspilationResult(transpiled_sql, error_list) + + def parse(self, sql: str, file_name: str) -> tuple[list[Expression | None] | None, ParserError | None]: + expression = None + error = None + try: + expression = parse(sql, read=self.read_dialect, error_level=ErrorLevel.IMMEDIATE) + except (ParseError, TokenError, 
UnsupportedError) as e: + error = ParserError(file_name, str(e)) + + return expression, error + + def parse_sql_content(self, sql, file_name): + parsed_expression, _ = self.parse(sql, file_name) + if parsed_expression is not None: + for expr in parsed_expression: + child = str(file_name) + if expr is not None: + for create in expr.find_all(exp.Create, exp.Insert, exp.Merge, bfs=False): + child = self._find_root_tables(create) + + for select in expr.find_all(exp.Select, exp.Join, exp.With, bfs=False): + yield self._find_root_tables(select), child + + @staticmethod + def _find_root_tables(expression) -> str | None: + for table in expression.find_all(exp.Table, bfs=False): + return table.name + return None diff --git a/src/databricks/labs/remorph/transpiler/transpile_status.py b/src/databricks/labs/remorph/transpiler/transpile_status.py new file mode 100644 index 0000000000..b22c72366d --- /dev/null +++ b/src/databricks/labs/remorph/transpiler/transpile_status.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + + +@dataclass +class ParserError: + file_name: str + exception: str + + +@dataclass +class ValidationError: + file_name: str + exception: str + + +@dataclass +class TranspileStatus: + file_list: list[str] + no_of_queries: int + parse_error_count: int + validate_error_count: int + error_log_list: list[ParserError | ValidationError] | None diff --git a/src/databricks/labs/remorph/uninstall.py b/src/databricks/labs/remorph/uninstall.py new file mode 100644 index 0000000000..5939d67e7e --- /dev/null +++ b/src/databricks/labs/remorph/uninstall.py @@ -0,0 +1,28 @@ +import logging + +from databricks.labs.blueprint.entrypoint import is_in_debug +from databricks.sdk import WorkspaceClient + +from databricks.labs.remorph.__about__ import __version__ +from databricks.labs.remorph.contexts.application import ApplicationContext + +logger = logging.getLogger("databricks.labs.remorph.install") + + +def run(context: ApplicationContext): + context.workspace_installation.uninstall(context.remorph_config) + + +if __name__ == "__main__": + logger.setLevel("INFO") + if is_in_debug(): + logging.getLogger("databricks").setLevel(logging.DEBUG) + + run( + ApplicationContext( + WorkspaceClient( + product="remorph", + product_version=__version__, + ) + ) + ) diff --git a/src/databricks/labs/remorph/upgrades/v0.4.0_add_main_table_operation_name_column.py b/src/databricks/labs/remorph/upgrades/v0.4.0_add_main_table_operation_name_column.py new file mode 100644 index 0000000000..7c6b2bf990 --- /dev/null +++ b/src/databricks/labs/remorph/upgrades/v0.4.0_add_main_table_operation_name_column.py @@ -0,0 +1,79 @@ +# pylint: disable=invalid-name +import logging + + +from databricks.labs.blueprint.installation import Installation +from databricks.sdk import WorkspaceClient + +from databricks.labs.remorph.contexts.application import ApplicationContext +from databricks.labs.remorph.deployment.recon import RECON_JOB_NAME +from databricks.labs.remorph.helpers import db_sql + +from databricks.labs.remorph.deployment.upgrade_common import ( + current_table_columns, + installed_table_columns, + recreate_table_sql, +) + +logger = logging.getLogger(__name__) + + +def _check_table_mismatch( + installed_table, + current_table, +) -> bool: + current_table = [x for x in current_table if x != "operation_name"] + # Compare the current main table columns with the installed main table columns + if "operation_name" in installed_table and len(sorted(installed_table)) != len(sorted(current_table)): + return True + return False + + +def 
_upgrade_reconcile_metadata_main_table( + installation: Installation, + ws: WorkspaceClient, + app_context: ApplicationContext, +): + """ + Add operation_name column to the main table as part of the upgrade process. + - Compare the current main table columns with the installed main table columns. If there is any mismatch: + * Verify all the current main table columns are present in the installed main table and then use CTAS to recreate the main table + * If any of the current main table columns are missing in the installed main table, prompt the user to recreate the main table: + - If the user confirms, recreate the main table using the main DDL file, else log an error message and exit + :param installation: + :param ws: + :param app_context: + """ + reconcile_config = app_context.recon_config + assert reconcile_config, "Reconcile config must be present to upgrade the reconcile metadata main table" + table_name = "main" + table_identifier = ( + f"{reconcile_config.metadata_config.catalog}.{reconcile_config.metadata_config.schema}.{table_name}" + ) + installed_columns = installed_table_columns(ws, table_identifier) + current_columns = current_table_columns(table_name, table_identifier) + sql: str | None = f"ALTER TABLE {table_identifier} ADD COLUMN operation_name STRING AFTER report_type" + if _check_table_mismatch(installed_columns, current_columns): + logger.info("Recreating main table") + sql = recreate_table_sql(table_identifier, installed_columns, current_columns, app_context.prompts) + if sql: + logger.debug(f"Executing SQL to upgrade main table: \n{sql}") + db_sql.get_sql_backend(ws).execute(sql) + installation.save(reconcile_config) + logger.debug("Upgraded Reconcile main table") + + +def _upgrade_reconcile_workflow(app_context: ApplicationContext): + if app_context.recon_config: + logger.info("Upgrading reconcile workflow") + wheels = app_context.product_info.wheels(app_context.workspace_client) + wheel_path = f"/Workspace{wheels.upload_to_wsfs()}" + app_context.job_deployment.deploy_recon_job(RECON_JOB_NAME, app_context.recon_config, wheel_path) + logger.debug("Upgraded reconcile workflow") + + +def upgrade(installation: Installation, ws: WorkspaceClient): + app_context = ApplicationContext(ws) + if app_context.recon_config is not None: + _upgrade_reconcile_metadata_main_table(installation, ws, app_context) + _upgrade_reconcile_workflow(app_context) diff --git a/src/databricks/labs/remorph/upgrades/v0.6.0_alter_metrics_datatype.py b/src/databricks/labs/remorph/upgrades/v0.6.0_alter_metrics_datatype.py new file mode 100644 index 0000000000..671c116e8d --- /dev/null +++ b/src/databricks/labs/remorph/upgrades/v0.6.0_alter_metrics_datatype.py @@ -0,0 +1,51 @@ +# pylint: disable=invalid-name +import logging + +from databricks.labs.blueprint.installation import Installation +from databricks.sdk import WorkspaceClient + +from databricks.labs.remorph.contexts.application import ApplicationContext +from databricks.labs.remorph.deployment.upgrade_common import ( + current_table_columns, + installed_table_columns, + check_table_mismatch, + recreate_table_sql, +) +from databricks.labs.remorph.helpers import db_sql + +logger = logging.getLogger(__name__) + + +def _upgrade_reconcile_metadata_metrics_table( + installation: Installation, ws: WorkspaceClient, app_context: ApplicationContext +): + reconcile_config = app_context.recon_config + assert reconcile_config, "Reconcile config must be present to upgrade the reconcile metadata main table" + table_name = "metrics" + table_identifier = ( + 
f"{reconcile_config.metadata_config.catalog}.{reconcile_config.metadata_config.schema}.{table_name}" + ) + installed_columns = installed_table_columns(ws, table_identifier) + current_columns = current_table_columns(table_name, table_identifier) + sqls: list | None = [ + f"ALTER TABLE {table_identifier} SET TBLPROPERTIES ('delta.enableTypeWidening' = 'true')", + f"ALTER TABLE {table_identifier} ALTER COLUMN recon_metrics.row_comparison.missing_in_source TYPE BIGINT", + f"ALTER TABLE {table_identifier} ALTER COLUMN recon_metrics.row_comparison.missing_in_target TYPE BIGINT", + f"ALTER TABLE {table_identifier} ALTER COLUMN recon_metrics.column_comparison.absolute_mismatch TYPE BIGINT", + f"ALTER TABLE {table_identifier} ALTER COLUMN recon_metrics.column_comparison.threshold_mismatch TYPE BIGINT", + ] + if check_table_mismatch(installed_columns, current_columns): + logger.info("Recreating main table") + sqls = [recreate_table_sql(table_identifier, installed_columns, current_columns, app_context.prompts)] + if sqls: + for sql in sqls: + logger.debug(f"Executing SQL to upgrade metrics table: \n{sql}") + db_sql.get_sql_backend(ws).execute(sql) + installation.save(reconcile_config) + logger.debug("Upgraded Reconcile metrics table") + + +def upgrade(installation: Installation, ws: WorkspaceClient): + app_context = ApplicationContext(ws) + if app_context.recon_config is not None: + _upgrade_reconcile_metadata_metrics_table(installation, ws, app_context) diff --git a/tests/__init__.py b/tests/__init__.py index 2482f42782..b1941733de 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -# ... \ No newline at end of file +# ... diff --git a/tests/resources/dashboards/queries/00_description.md b/tests/resources/dashboards/queries/00_description.md new file mode 100644 index 0000000000..06fe3f7a41 --- /dev/null +++ b/tests/resources/dashboards/queries/00_description.md @@ -0,0 +1 @@ +#Reconcile Dashboard Queries Test diff --git a/tests/resources/dashboards/queries/01_queries.sql b/tests/resources/dashboards/queries/01_queries.sql new file mode 100644 index 0000000000..6326979d6b --- /dev/null +++ b/tests/resources/dashboards/queries/01_queries.sql @@ -0,0 +1,8 @@ +SELECT +main.recon_id, +main.source_type, +main.report_type, +main.source_table.`catalog` as source_catalog, +main.source_table.`schema` as source_schema, +main.source_table.table_name as source_table_name +FROM remorph.reconcile.main main diff --git a/tests/resources/dashboards/queries/dashboard.yml b/tests/resources/dashboards/queries/dashboard.yml new file mode 100644 index 0000000000..2bb507f55f --- /dev/null +++ b/tests/resources/dashboards/queries/dashboard.yml @@ -0,0 +1 @@ +display_name: "Reconciliation Metrics Test" diff --git a/tests/resources/datasets/customer.sql b/tests/resources/datasets/customer.sql new file mode 100644 index 0000000000..c378bb5cd2 --- /dev/null +++ b/tests/resources/datasets/customer.sql @@ -0,0 +1,115 @@ +---------------------------------------------------------------- +CREATE TABLE customer ( + c_custkey BIGINT, + c_name VARCHAR(500), + c_address VARCHAR(500), + c_nationkey BIGINT, + c_phone VARCHAR(500), + c_acctbal DECIMAL(18,2), + c_mktsegment VARCHAR(500), + c_comment VARCHAR(500) + ); + +---------------------------------------------------------------- + +INSERT INTO customer (c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment) VALUES + (3295930, 'Customer#003295930', 'V0E9sOQGMCpNW', 6, '16-588-250-3730', 4319.11, 'HOUSEHOLD', 'pinto beans. 
quickly express deposits are slyly final accounts. fu'), + (3371083, 'Customer#003371083', 'X7Dh9jIV,OFXaim1 Y', 6, '16-850-721-5939', 362.92, 'AUTOMOBILE', 'riously final Tiresias affix carefully above the slyly final packages. ironic, fin'), + (9882436, 'Customer#009882436', 'coHXRsx1FXd', 7, '17-304-701-9401', 2153.59, 'HOUSEHOLD', 'ickly fluffily special gifts. carefully final theodolites run slyly after'), + (90817804, 'Customer#090817804', 'k5HudfszaRZibxdeLwt', 16, '26-186-501-5193', 8003.26, 'FURNITURE', 'even, ironic dependencies. bold deposits boost packages. pending instructions are across th'), + (29928037, 'Customer#029928037', '2c Qy6ygNwkiBjJUtiGshrBZGj 1t47Xu6y', 15, '25-295-364-7866', 8817.38, 'FURNITURE', 'slyly pending requests cajole across the dependencies. fluffy, unusual requests haggle. ironic foxes amon'), + (66957875, 'Customer#066957875', 'Gz9Xe cQMl1DN8hu', 24, '34-734-299-8721', 4175.54, 'AUTOMOBILE', 'regular accounts; blithely bold courts cajole furiously. fluffy theodolites detect carefully among the final de'), + (56612041, 'Customer#056612041', 'MzprTl9,kz,JEisQlJdBqs', 8, '18-786-542-3567', 288.87, 'BUILDING', 'es according to the deposits '), + (70006357, 'Customer#070006357', 'wNgBSZsW2x,qMa8zw', 12, '22-986-714-5298', 4259.25, 'AUTOMOBILE', ' deposits. blithely unusual pinto beans use furiously blithely ironic packages. fluffily express '), + (81762487, 'Customer#081762487', 'gdWxUbki3XCcKZ2AomsEiJfVjer,R ', 14, '24-304-592-6742', 4780.45, 'MACHINERY', 'regular theodolites snooze slyly. pinto beans wake. regular foxes cajole against the package'), + (16250920, 'Customer#016250920', 'm0y1xSh79pONxmD Zx4Pu7PRhDTjLhdJSQAYu3X', 0, '10-716-976-3184', 9612.95, 'MACHINERY', 'ions maintain platelets; special theodolites about the slyly express requests nag against the slyly final deposit'), + (27236141, 'Customer#027236141', 'hOXVVYFVPe,gyrF9 p3Ad0xJvJjWsLQM', 19, '29-138-757-5625', -872.77, 'FURNITURE', 'odolites lose quickly alongside of the i'), + (26393408, 'Customer#026393408', 'pLOZA XA9hxJ,bJASfcgbL8r0Ziunm', 1, '11-769-490-3598', 8699.31, 'FURNITURE', ' the quickly even accounts. blithely bold deposits among'), + (92748641, 'Customer#092748641', '7E7Ua1YT5uL8sKsCEV', 19, '29-532-424-6629', 6717.57, 'AUTOMOBILE', 'e of the regular decoys. daring reques'), + (43999894, 'Customer#043999894', ',6yCYx5oJof9dSVjPMsS3osuBb4HI4MCS', 12, '22-785-570-2761', 4739.59, 'FURNITURE', 'express foxes. blithe, final accounts haggle silently close deposits. blithely regula'), + (44077571, 'Customer#044077571', 'VYLBAUge5CDQP c', 21, '31-143-742-6463', 3112.61, 'AUTOMOBILE', 'ideas are beyond the pending deposits. bold ideas nag after the'), + (44484776, 'Customer#044484776', '8kT0a8iG4IL8y2CLZ', 7, '17-181-177-6657', 8679.17, 'BUILDING', 'he furiously final instructions! accounts nag. furiously final instruct'), + (44667518, 'Customer#044667518', 'YmjjqysI03BCtuKha8PDw3Y', 8, '18-241-558-3799', 8709.51, 'AUTOMOBILE', 'ndencies. blithely regular excuses among the blithely stealthy accounts shal'), + (41860129, 'Customer#041860129', 'jncIr8PT9UyfZBjFevzvKMZtwFEOgxfiY', 15, '25-943-885-3681', -862.01, 'HOUSEHOLD', ' deposits nod slowly alongside of the furiously special pinto beans. quickly fina'), + (22251439, 'Customer#022251439', 'wL0RPk9Z0kY8Z,asuvnArkMaXBRE', 13, '23-888-146-1548', 3853.54, 'FURNITURE', 'y along the express, even theodolites. 
carefully final requests affix slyly sp'), + (64339465, 'Customer#064339465', 'viqTMXIG,k2cEMv7vlSAFMrGBmlaboINE6rK', 11, '21-966-285-7538', 2099.69, 'MACHINERY', ' slyly carefully ironic requests. even deposits sleep quickly regular a'), + (64823140, 'Customer#064823140', 'i4b3lzSdfPdLQ28', 24, '34-322-280-3318', 6838.87, 'AUTOMOBILE', 'ways quick dependencies? furiously even ideas into the quickly express ideas believe deposits. blithely'), + (82493404, 'Customer#082493404', 'izWtgaCF1 qRdtMI48wKAzaDt234HuUYRq5g8Oo', 10, '20-709-137-3338', 1003.20, 'FURNITURE', 'final packages after the fluffily expr'), + (82568518, 'Customer#082568518', 'tSDFsUkSh1L0bo', 14, '24-199-866-4074', 1182.76, 'BUILDING', 'es sleep after the regular foxes. furiously unusual accounts hinder fluffily '), + (61723415, 'Customer#061723415', 'oppYIk5komPMY1bJuw6MbwwW', 22, '32-926-228-3091', 5850.52, 'MACHINERY', 'uriously regular accounts are blithely after the final dependencies. pending, express foxes sleep slyly. furio'), + (60394615, 'Customer#060394615', 'f69i3Ag2jhK4eRtAcXE94p', 24, '34-877-318-9214', 5450.81, 'BUILDING', 'eas sleep blithely furious theodolites. carefully unusual requests sleep quickly. carefull'), + (60479711, 'Customer#060479711', 'ee3e3yZ6LuxeeTGyf3wU6pxOfvg16z', 11, '21-198-628-4822', 4588.09, 'MACHINERY', 'kages. unusual accounts against the slyly special deposits sleep blithe, ironic foxes'), + (61000036, 'Customer#061000036', 'yQXWvXMcATmPd,LCuH5EM p', 24, '34-121-630-8556', 1905.97, 'HOUSEHOLD', 's. fluffily final dependencies believe furiously after the final, regular'), + (88909205, 'Customer#088909205', 'aPSLB1SuhH9,0,yZo4si1x', 4, '14-221-326-7431', 6450.00, 'AUTOMOBILE', 'rnis cajole slyly after the excuses. carefully even accounts was blithely regular ide'), + (14115415, 'Customer#014115415', 'IO70hIan2SKpUhdBtjUmyeY', 5, '15-668-222-7805', 9292.04, 'HOUSEHOLD', 'odolites according to the furiously pending deposits cajol'), + (86114407, 'Customer#086114407', 'pGTUWREajuurKiWlmkARceM22VJ0cA27rkQ', 7, '17-635-861-8255', 2441.56, 'HOUSEHOLD', 'e blithely special accounts dazzle alongside of the quickly final attainments. slyly unusual deposits affix careful'), + (27997306, 'Customer#027997306', 'Y7Xj7TKai5', 8, '18-737-462-6661', 543.80, 'BUILDING', 'above the furiously special multipliers. qu'), + (16618916, 'Customer#016618916', 'dYnZmt965mG8rhkwV0IG X,aVH2SuuLD', 17, '27-899-805-8823', 8615.15, 'MACHINERY', 'e deposits are slyly unusual theodolites. express foxes are slyly blithely idle foxes. blithely '), + (52969150, 'Customer#052969150', 'm7rolG7yT0JGa8ZUQsIiQJaMtMmcUqKhIr', 24, '34-630-959-3766', 8043.08, 'FURNITURE', 'the carefully quiet theodolites? carefully special accounts promise slyly above the final foxes. final, regular s'), + (111727408, 'Customer#111727408', 'iXWJQC2h,T5Z1kJMJkheRv3', 10, '20-169-685-8612', 4930.17, 'AUTOMOBILE', ' requests promise slyly ironic foxes. furiously pending asymptotes boost silent, final'), + (141049697, 'Customer#141049697', 'A81tY2 Mkf4yEzoePHKUqT2ytQECyn', 13, '23-940-686-2364', 9847.66, 'AUTOMOBILE', 'ven requests. pending, special'), + (115251748, 'Customer#115251748', 'rSAECwxXmrlf0LB0wT', 10, '20-247-765-7572', 1297.63, 'AUTOMOBILE', 'eodolites doubt slyly. quickly '), + (107777776, 'Customer#107777776', 'Hg15J3J4eD3AnU7K30vF', 14, '24-468-371-7737', 8860.15, 'AUTOMOBILE', ' asymptotes. 
regular, unusual asymptotes above the bold packages doze sl'), + (107810401, 'Customer#107810401', 'V3M2UNyRx5dYzVXBqtTRtuNP WpqOy', 12, '22-835-708-1117', 955.39, 'BUILDING', 's. carefully enticing packages nag carefully. blithely final depo'), + (124827856, 'Customer#124827856', '00Sz4L2t0ATjXaq', 7, '17-103-808-4641', 6843.57, 'AUTOMOBILE', ' are. pending dolphins about the bold packages lose around the furiously pending ideas. eve'), + (124830292, 'Customer#124830292', 'xwvdT9FjdvVqsAv', 2, '12-173-784-6421', 8713.61, 'AUTOMOBILE', 'deas about the quickly silent deposits s'), + (130056940, 'Customer#130056940', 'qmnqrvSUEDIV', 9, '19-343-620-7833', 9910.32, 'BUILDING', ' fluffily unusual foxes cajole fluffily carefully special packages. quickly express requests cajole alon'), + (116068132, 'Customer#116068132', 'jaFu9KjPhi', 11, '21-737-617-9865', 8970.32, 'BUILDING', 'egular dependencies cajole flu'), + (144597874, 'Customer#144597874', '7yeq5MwruFqPRat6dU1eRTpljZdr1qT6LiHpLqnL', 7, '17-189-348-6322', 4634.89, 'HOUSEHOLD', 'st slyly express packages. carefully express e'), + (135420649, 'Customer#135420649', 'z6CbgfKB9p', 23, '33-405-746-7262', 6241.04, 'HOUSEHOLD', 's among the carefully ironic excuses affix fluffily according to the pendi'), + (104479618, 'Customer#104479618', 'XMBU G81SjmR8blzOQ0', 6, '16-706-347-7510', 2270.04, 'MACHINERY', 'courts x-ray. fluffily regular ideas wake. requests according to th'), + (104726965, 'Customer#104726965', 't9bU20uruPnVy,', 13, '23-422-125-7062', 4590.46, 'BUILDING', 'ular pinto beans. quickly ironic accounts detect quickly along the care'), + (110216159, 'Customer#110216159', 'Kq5a9QTGkSdhmhpbqXJqYJUUtWbmD', 16, '26-394-525-7408', 4016.75, 'MACHINERY', ' believe furiously busily special ideas. final instructions above the ruthless, silent deposits promise ag'), + (106455751, 'Customer#106455751', 'UgkcoMjsta1W8QUYR1VMhL5ZYUT3pXmUgcpvI,D', 13, '23-930-469-4536', -399.76, 'BUILDING', 's play fluffily slyly regular accounts. furiously unusual accounts wake packages. even, ironi'), + (129199537, 'Customer#129199537', 'OzCeVXmCRk3eSpijo8P5VFw4Y2sxvGa', 9, '19-958-467-4138', 9531.66, 'FURNITURE', 'ously regular requests wake furiously around the slowly silent theodolites. furiously regular pac'), + (103749508, 'Customer#103749508', 'HMl2mr0D 0lL46iDXIaU4r9DItev2aGK8mjjBriq', 21, '31-482-445-7542', 6729.78, 'FURNITURE', 'carefully unusual packages wake about the unusual theodolites. ironic accounts hang ironic courts-- '), + (117950323, 'Customer#117950323', 'BlUNrAwiKiISCA025h9oaWXzEMdgkw VbJ', 11, '21-158-125-3573', 5482.46, 'AUTOMOBILE', ' platelets. carefully regular '), + (301816, 'Customer#000301816', '1gKYaw6j3rIMed,wrszGg3SZD', 7, '17-494-705-8319', -183.82, 'BUILDING', 'odolites thrash blithely deposits! '), + (715720, 'Customer#000715720', '8z5w3Md9hvqTV', 24, '34-202-501-6354', 5637.29, 'HOUSEHOLD', ' shall lose regular, slow accounts. packages lose of the regular, pending packages. pending accounts around the pla'), + (778084, 'Customer#000778084', 'WXkJ IJ702WDQIItTcxwN3VYJy', 4, '14-849-423-9056', 2507.93, 'FURNITURE', 'lyly pending foxes cajole furiously along the blithely idle theodolites. final, ironic ideas cajole'), + (1775348, 'Customer#001775348', 'JO6MaQ7RNAfRZybyMiBfr', 2, '12-558-100-4368', 7458.42, 'BUILDING', 'e blithely stealthy foxes. 
final f'), + (2289380, 'Customer#002289380', 'fK3S7GlARbWwq7GPST3Di46clRSadr8SbyvZf', 20, '30-775-281-7181', -574.97, 'AUTOMOBILE', 'latelets cajole about the slyly final pinto be'), + (2474150, 'Customer#002474150', 'UParMy35Tlj,wSU0a', 7, '17-139-110-9865', 4059.63, 'AUTOMOBILE', 'y regular packages hinder carefully quickly reg'), + (6197447, 'Customer#006197447', '3WyURSVY12p2FGiZimH6VubjDKwwH4', 22, '32-866-948-1468', 2488.75, 'MACHINERY', 'ajole slyly slyly silent excuses. even foxes among the furiously silent instructions sleep quickly even '), + (7092508, 'Customer#007092508', 'o1OyRWJG4H5vJT1DqxSnKDj', 17, '27-870-285-1913', 7420.91, 'MACHINERY', 'ously regular platelets sleep blithely bold, unusual'), + (32945209, 'Customer#032945209', 'xCFb,s9K gMRlpvZVvVSdvR', 5, '15-753-106-1327', 3025.70, 'FURNITURE', 'xes. furiously regular packages behind the slyly silent accounts wake after the sauternes-- final pa'), + (33029540, 'Customer#033029540', 'n 01ywmbXRWq1', 7, '17-912-141-3150', 8863.91, 'BUILDING', 'pecial platelets integrate after the fluffily special '), + (40023562, 'Customer#040023562', 'BRhiaF3ke64kaQo5mpCLvrPLU7t', 13, '23-988-211-5885', -39.70, 'MACHINERY', 'y even courts wake according to the furiously '), + (46071896, 'Customer#046071896', ',g3OPrUwuamL3hjE5bM ', 9, '19-103-374-3653', 5270.32, 'FURNITURE', 'ges nag slowly pending theodolites; blithely even ideas sleep slyly across the instructions. furiously stealth'), + (36899983, 'Customer#036899983', 'gqk7ABDT0AOrGob92fRYM1UA', 13, '23-366-569-7779', 7793.04, 'BUILDING', ' nag quickly according to the bold, regular deposits. furiously even dependencies nag above the ironic packages.'), + (36963334, 'Customer#036963334', 'zvkyeSVBOmIECMafWRzHYPVHXqTCcTRzdYVhH', 3, '13-835-218-1189', 8150.59, 'MACHINERY', 'heodolites integrate carefully. slyly special dependencies are quickly. quiet accounts us'), + (55622011, 'Customer#055622011', 'iIkeh1NYv8tKlp2,BsscD00', 2, '12-574-192-7162', 5004.48, 'HOUSEHOLD', 's. accounts use carefully past the ironic'), + (32112589, 'Customer#032112589', '5aIL5hHIyUPj1dbogTi8msiOlRZfLfD', 5, '15-940-412-2429', -758.26, 'HOUSEHOLD', 'arls. unusual accounts sleep bravely even deposits. bold, even pinto beans nag furiously express packages.'), + (32510872, 'Customer#032510872', 'yzRCrNCECZETdzu', 21, '31-167-589-1416', 1187.24, 'BUILDING', 'ges. even, ironic deposits grow blithely furiously unusual pinto beans. always final packages haggle carefully'), + (30351406, 'Customer#030351406', 'NP6FSnAeBiTtO,Pg3', 12, '22-472-889-1931', 7303.56, 'MACHINERY', 'ly final dugouts. regular, pendin'), + (71133277, 'Customer#071133277', 'o,P3,IQJg6Z5nXjq2vkgJkIyOvyrHvkQ', 9, '19-755-321-3999', 2690.52, 'BUILDING', 'packages run blithely express decoys. pending pinto beans poach among the even deposits. carefully final r'), + (39132812, 'Customer#039132812', 'GwRGx1ZXS2dW,', 16, '26-669-297-6373', 6644.79, 'FURNITURE', 'jole according to the express foxes. ironic, even packages use blithely express somas. q'), + (39134293, 'Customer#039134293', 'wnamG9FwX8fIm a', 12, '22-780-208-6338', 9277.52, 'BUILDING', 'sly slyly ironic foxes. carefully even '), + (28546706, 'Customer#028546706', 'r4rKzvkAIVLa zo2K,JzXTkzaCzL23Jplt9QbhvR', 19, '29-687-142-7693', -423.30, 'AUTOMOBILE', ' final theodolites wake furiously special requests. silent requests'), + (29099422, 'Customer#029099422', '0pkJyXiLHMfIXZvaU7Tb', 7, '17-571-125-7961', 5838.93, 'HOUSEHOLD', 'l foxes unwind blithely final pinto beans. 
slyly'), + (75985876, 'Customer#075985876', 'YxZKeQM2epK5FWR3C7FHRpsPDn46f LS', 10, '20-697-111-2636', 7147.46, 'BUILDING', ' pains according to the enticing ideas cajole carefull'), + (43184669, 'Customer#043184669', 'QsL9GfPvif', 11, '21-635-565-2030', 2325.84, 'BUILDING', 'press packages. quickly ironic packages among the pending pinto beans wake furiously regular requests--'), + (60109081, 'Customer#060109081', '3tWw5R67ta4xkQyjyOuwWFNKUQrzzQAy', 13, '23-486-178-3898', 9746.68, 'AUTOMOBILE', 'posits. final requests hang after the accounts! regular, quiet requests sleep slyl'), + (77599394, 'Customer#077599394', ',OdEQageFqcyM46v67HfvgPX8F8h,qnZPHsC', 3, '13-285-640-5162', 9859.54, 'BUILDING', 'ilent accounts about the regular, final deposits doze furiously unusual pinto beans. pending, entic'), + (78001627, 'Customer#078001627', ' pY4oKC0oZ1x0bjmJqPKVJYlJ8t', 13, '23-517-390-2161', 3226.40, 'MACHINERY', 'lent, ironic requests. stealthily regul'), + (84486485, 'Customer#084486485', ',msEaXcmm9tLtwbGwrBV9Mw224L4GxeHDHRtYO', 17, '27-351-972-1334', 1904.45, 'MACHINERY', ' even courts haggle carefully according to the special packages. slyly bold packages'), + (79060019, 'Customer#079060019', '7Yz9Mwn216ztCkVo3Z6XFbDh', 21, '31-289-485-3665', 1589.74, 'HOUSEHOLD', ' the carefully regular gifts engage foxes. carefully final sauternes nag fu'), + (21059945, 'Customer#021059945', 'ec192DMIne p1RQO tEYh8Fljs', 1, '11-378-171-1951', 1825.06, 'BUILDING', 'deas. regular, unusual pains sleep slyly above the furiously unusual theodolites. furiously even '), + (87758125, 'Customer#087758125', 'gQBgoWCB5M4lWxRJRlz7WNg0iB5', 17, '27-181-783-9233', 6341.38, 'BUILDING', 'l packages. blithely express asymptotes cajole sl'), + (50497948, 'Customer#050497948', ' 7H3 k2OR4SoykiwNF', 22, '32-641-297-7568', 1473.52, 'AUTOMOBILE', ' pending requests wake fluffily according to the bold packages? blithely ironic packages'), + (73955563, 'Customer#073955563', 'JxjeqfBTEyEIYdAKKp9Rsb68vlvye', 8, '18-617-763-6215', 5478.07, 'BUILDING', 'thely regular accounts at the regular asymptotes kindle slyly carefully special grouches. blithely un'), + (18983180, 'Customer#018983180', 'Mg7cuwJ27,og,sicMzCVZS7yFLe', 9, '19-107-604-1926', 216.33, 'FURNITURE', 'ilent ideas haggle slyly carefully pending foxes-- slyly silent accounts promise. caref'), + (136776016, 'Customer#136776016', 'ccGKDjLaS,jdhEL5pVAATUe97vHZD9ohiD', 13, '23-805-453-6615', 2120.84, 'HOUSEHOLD', 'uests need to are blithely. slyly express dependencies nag blithely along the blithely special foxes. final acco'), + (127465336, 'Customer#127465336', 'lZJOREBKPbYIa181T', 12, '22-703-969-7962', 2996.40, 'HOUSEHOLD', ' accounts. furious excuses sleep carefully slyly d'), + (127587628, 'Customer#127587628', 'MMoH,fVPuxyVCoBb2Wp', 1, '11-186-390-2505', 5787.26, 'BUILDING', 'accounts after the furiously regular asymptotes promise furio'), + (133544900, 'Customer#133544900', 'mNcLlslQI OLd', 18, '28-104-667-6381', 5921.95, 'AUTOMOBILE', 'ructions. final deposits hagg'), + (102534112, 'Customer#102534112', 'ncZrveEuEFrH', 9, '19-499-358-4803', 7724.23, 'HOUSEHOLD', ' regular packages use furiously quickly'), + (138267985, 'Customer#138267985', '9DsYMHRLD,MoCQEpZnVZ1l1JjurNu', 6, '16-148-197-5648', 2281.47, 'FURNITURE', 'cally even deposits doubt blithely after the regular, regular ideas. doggedly re'), + (113008447, 'Customer#113008447', '9Mpt7cWOpVyYUigS9cO', 22, '32-829-953-3050', 9674.50, 'AUTOMOBILE', 'nal platelets sleep accounts. 
carefully regular packages wake thinly-- furiously final requests'), + (122590799, 'Customer#122590799', 'e2NhKW4yFB,IUqGTdoGAQOw yDz7 ', 10, '20-412-224-4135', 8065.26, 'MACHINERY', 'y even foxes cajole stealthily pending packages. regular platelets wake furiously across the bold, eve'), + (122692442, 'Customer#122692442', 'xoPRicyeRUILOyfIiS6ZzzgyXWM E8q', 16, '26-256-231-3820', 8416.29, 'MACHINERY', 'deas wake blithely above the regular reque'), + (105154997, 'Customer#105154997', 'WTOHyzeBJA8gOTKQ3frIUEOfbxbX5 aM', 13, '23-113-745-6522', 9272.38, 'MACHINERY', ' final requests thrash slyly. blithely pending foxes snoo'), + (146808497, 'Customer#146808497', '4m6gQ6r9NmOHB8bDoX', 22, '32-456-538-3620', 7759.23, 'HOUSEHOLD', 'efully carefully bold foxes. pending courts'), + (147002917, 'Customer#147002917', '8DdRwZczPXZzw8c512veavAaSZDoVXtGJs8Iq', 20, '30-329-176-9254', 5200.75, 'BUILDING', 'y furious foxes believe fluffily carefully regular dependencies. asymptotes sleep slyly c'), + (119401594, 'Customer#119401594', 'rm0DquXXO,VEG9 V17CFuV', 5, '15-145-487-9477', 960.66, 'MACHINERY', 'layers. closely special dependencies sleep carefully'), + (123313904, 'Customer#123313904', 'cG3sxDt19f', 13, '23-862-755-7543', 7240.55, 'HOUSEHOLD', 'p slyly blithely final packages? bold deposits integrate. doggedly even foxes above the accounts n'); diff --git a/tests/resources/datasets/lineitem.sql b/tests/resources/datasets/lineitem.sql new file mode 100644 index 0000000000..91f8a9170f --- /dev/null +++ b/tests/resources/datasets/lineitem.sql @@ -0,0 +1,173 @@ +-------------------------------------------------------------------------------- +CREATE TABLE lineitem ( + l_orderkey BIGINT, + l_partkey BIGINT, + l_suppkey BIGINT, + l_linenumber INT, + l_quantity DECIMAL(18,2), + l_extendedprice DECIMAL(18,2), + l_discount DECIMAL(18,2), + l_tax DECIMAL(18,2), + l_returnflag VARCHAR(500), + l_linestatus VARCHAR(500), + l_shipdate DATE, + l_commitdate DATE, + l_receiptdate DATE, + l_shipinstruct VARCHAR(500), + l_shipmode VARCHAR(500), + l_comment VARCHAR(500) + ); + +-------------------------------------------------------------------------------- + +INSERT INTO lineitem (l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment) VALUES +(1, 155189345, 7689361, 1, 17.00, 24252.03, 0.04, 0.02, 'N', 'O', '1996-03-13', '1996-02-12', '1996-03-22', 'DELIVER IN PERSON', 'TRUCK', 'egular courts above the'), + (1, 67309080, 7309081, 2, 36.00, 39085.92, 0.09, 0.06, 'N', 'O', '1996-04-12', '1996-02-28', '1996-04-20', 'TAKE BACK RETURN', 'MAIL', 'ly final dependencies: slyly bold '), + (1, 63699776, 3699777, 3, 8.00, 14180.72, 0.10, 0.02, 'N', 'O', '1996-01-29', '1996-03-05', '1996-01-31', 'TAKE BACK RETURN', 'REG AIR', 'riously. regular, express dep'), + (1, 2131495, 4631496, 4, 28.00, 42738.92, 0.09, 0.06, 'N', 'O', '1996-04-21', '1996-03-30', '1996-05-16', 'NONE', 'AIR', 'lites. fluffily even de'), + (1, 24026634, 1526641, 5, 24.00, 37426.32, 0.10, 0.04, 'N', 'O', '1996-03-30', '1996-03-14', '1996-04-01', 'NONE', 'FOB', ' pending foxes. slyly re'), + (1, 15634450, 634453, 6, 32.00, 44277.44, 0.07, 0.02, 'N', 'O', '1996-01-30', '1996-02-07', '1996-02-03', 'DELIVER IN PERSON', 'MAIL', 'arefully slyly ex'), + (2, 106169722, 1169743, 1, 38.00, 67883.96, 0.00, 0.05, 'N', 'O', '1997-01-28', '1997-01-14', '1997-02-02', 'TAKE BACK RETURN', 'RAIL', 'ven requests. 
deposits breach a'), + (3, 4296962, 1796963, 1, 45.00, 88143.75, 0.06, 0.00, 'R', 'F', '1994-02-02', '1994-01-04', '1994-02-23', 'NONE', 'AIR', 'ongside of the furiously brave acco'), + (3, 19035429, 6535433, 2, 49.00, 66810.03, 0.10, 0.00, 'R', 'F', '1993-11-09', '1993-12-20', '1993-11-24', 'TAKE BACK RETURN', 'RAIL', ' unusual accounts. eve'), + (3, 128448229, 3448254, 3, 27.00, 31611.60, 0.06, 0.07, 'A', 'F', '1994-01-16', '1993-11-22', '1994-01-23', 'DELIVER IN PERSON', 'SHIP', 'nal foxes wake. '), + (3, 29379610, 1879613, 4, 2.00, 3376.30, 0.01, 0.06, 'A', 'F', '1993-12-04', '1994-01-07', '1994-01-01', 'NONE', 'TRUCK', 'y. fluffily pending d'), + (3, 183094077, 594132, 5, 28.00, 29733.76, 0.04, 0.00, 'R', 'F', '1993-12-14', '1994-01-10', '1994-01-01', 'TAKE BACK RETURN', 'FOB', 'ages nag slyly pending'), + (3, 62142591, 9642610, 6, 26.00, 42392.74, 0.10, 0.02, 'A', 'F', '1993-10-29', '1993-12-18', '1993-11-04', 'TAKE BACK RETURN', 'RAIL', 'ges sleep after the caref'), + (4, 88034684, 5534709, 1, 30.00, 48428.40, 0.03, 0.08, 'N', 'O', '1996-01-10', '1995-12-14', '1996-01-18', 'DELIVER IN PERSON', 'REG AIR', '- quickly regular packages sleep. idly'), + (5, 108569283, 8569284, 1, 15.00, 20202.90, 0.02, 0.04, 'R', 'F', '1994-10-31', '1994-08-31', '1994-11-20', 'NONE', 'AIR', 'ts wake furiously '), + (5, 123926789, 3926790, 2, 26.00, 47049.34, 0.07, 0.08, 'R', 'F', '1994-10-16', '1994-09-25', '1994-10-19', 'NONE', 'FOB', 'sts use slyly quickly special instruc'), + (5, 37530180, 30184, 3, 50.00, 60415.50, 0.08, 0.03, 'A', 'F', '1994-08-08', '1994-10-13', '1994-08-26', 'DELIVER IN PERSON', 'AIR', 'eodolites. fluffily unusual'), + (6, 139635455, 2135469, 1, 37.00, 51188.39, 0.08, 0.03, 'A', 'F', '1992-04-27', '1992-05-15', '1992-05-02', 'TAKE BACK RETURN', 'TRUCK', 'p furiously special foxes'), + (7, 182051839, 9551894, 1, 12.00, 21380.76, 0.07, 0.03, 'N', 'O', '1996-05-07', '1996-03-13', '1996-06-03', 'TAKE BACK RETURN', 'FOB', 'ss pinto beans wake against th'), + (7, 145242743, 7742758, 2, 9.00, 15106.32, 0.08, 0.08, 'N', 'O', '1996-02-01', '1996-03-02', '1996-02-19', 'TAKE BACK RETURN', 'SHIP', 'es. instructions'), + (7, 94779739, 9779758, 3, 46.00, 83444.00, 0.10, 0.07, 'N', 'O', '1996-01-15', '1996-03-27', '1996-02-03', 'COLLECT COD', 'MAIL', ' unusual reques'), + (7, 163072047, 3072048, 4, 28.00, 28304.92, 0.03, 0.04, 'N', 'O', '1996-03-21', '1996-04-08', '1996-04-20', 'NONE', 'FOB', '. slyly special requests haggl'), + (7, 151893810, 9393856, 5, 38.00, 68256.36, 0.08, 0.01, 'N', 'O', '1996-02-11', '1996-02-24', '1996-02-18', 'DELIVER IN PERSON', 'TRUCK', 'ns haggle carefully ironic deposits. bl'), + (7, 79250148, 1750156, 6, 35.00, 38296.30, 0.06, 0.03, 'N', 'O', '1996-01-16', '1996-02-23', '1996-01-22', 'TAKE BACK RETURN', 'FOB', 'jole. excuses wake carefully alongside of '), + (7, 157237027, 2237058, 7, 5.00, 4780.80, 0.04, 0.02, 'N', 'O', '1996-02-10', '1996-03-26', '1996-02-13', 'NONE', 'FOB', 'ithely regula'), + (32, 82703512, 7703529, 1, 28.00, 42318.64, 0.05, 0.08, 'N', 'O', '1995-10-23', '1995-08-27', '1995-10-26', 'TAKE BACK RETURN', 'TRUCK', 'sleep quickly. req'), + (32, 197920162, 420182, 2, 32.00, 37512.64, 0.02, 0.00, 'N', 'O', '1995-08-14', '1995-10-07', '1995-08-27', 'COLLECT COD', 'AIR', 'lithely regular deposits. 
fluffily '), + (32, 44160335, 6660340, 3, 2.00, 2786.26, 0.09, 0.02, 'N', 'O', '1995-08-07', '1995-10-07', '1995-08-23', 'DELIVER IN PERSON', 'AIR', ' express accounts wake according to the'), + (32, 2742061, 7742062, 4, 4.00, 4411.72, 0.09, 0.03, 'N', 'O', '1995-08-04', '1995-10-01', '1995-09-03', 'NONE', 'REG AIR', 'e slyly final pac'), + (32, 85810176, 8310185, 5, 44.00, 47602.72, 0.05, 0.06, 'N', 'O', '1995-08-28', '1995-08-20', '1995-09-14', 'DELIVER IN PERSON', 'AIR', 'symptotes nag according to the ironic depo'), + (32, 11614679, 4114681, 6, 6.00, 9558.54, 0.04, 0.03, 'N', 'O', '1995-07-21', '1995-09-23', '1995-07-25', 'COLLECT COD', 'RAIL', ' gifts cajole carefully.'), + (33, 61335189, 8835208, 1, 31.00, 37854.72, 0.09, 0.04, 'A', 'F', '1993-10-29', '1993-12-19', '1993-11-08', 'COLLECT COD', 'TRUCK', 'ng to the furiously ironic package'), + (33, 60518681, 5518694, 2, 32.00, 54293.12, 0.02, 0.05, 'A', 'F', '1993-12-09', '1994-01-04', '1993-12-28', 'COLLECT COD', 'MAIL', 'gular theodolites'), + (33, 137468550, 9968564, 3, 5.00, 7558.40, 0.05, 0.03, 'A', 'F', '1993-12-09', '1993-12-25', '1993-12-23', 'TAKE BACK RETURN', 'AIR', '. stealthily bold exc'), + (33, 33917488, 3917489, 4, 41.00, 61655.39, 0.09, 0.00, 'R', 'F', '1993-11-09', '1994-01-24', '1993-11-11', 'TAKE BACK RETURN', 'MAIL', 'unusual packages doubt caref'), + (34, 88361363, 861372, 1, 13.00, 18459.35, 0.00, 0.07, 'N', 'O', '1998-10-23', '1998-09-14', '1998-11-06', 'NONE', 'REG AIR', 'nic accounts. deposits are alon'), + (34, 89413313, 1913322, 2, 22.00, 26880.48, 0.08, 0.06, 'N', 'O', '1998-10-09', '1998-10-16', '1998-10-12', 'NONE', 'FOB', 'thely slyly p'), + (34, 169543048, 4543081, 3, 6.00, 6495.42, 0.02, 0.06, 'N', 'O', '1998-10-30', '1998-09-20', '1998-11-05', 'NONE', 'FOB', 'ar foxes sleep '), + (35, 449928, 2949929, 1, 24.00, 45069.60, 0.02, 0.00, 'N', 'O', '1996-02-21', '1996-01-03', '1996-03-18', 'TAKE BACK RETURN', 'FOB', ', regular tithe'), + (35, 161939722, 4439739, 2, 34.00, 59623.42, 0.06, 0.08, 'N', 'O', '1996-01-22', '1996-01-06', '1996-01-27', 'DELIVER IN PERSON', 'RAIL', 's are carefully against the f'), + (35, 120895174, 8395211, 3, 7.00, 8141.91, 0.06, 0.04, 'N', 'O', '1996-01-19', '1995-12-22', '1996-01-29', 'NONE', 'MAIL', ' the carefully regular '), + (35, 85174030, 7674039, 4, 25.00, 27494.50, 0.06, 0.05, 'N', 'O', '1995-11-26', '1995-12-25', '1995-12-21', 'DELIVER IN PERSON', 'SHIP', ' quickly unti'), + (35, 119916152, 4916175, 5, 34.00, 39513.44, 0.08, 0.06, 'N', 'O', '1995-11-08', '1996-01-15', '1995-11-26', 'COLLECT COD', 'MAIL', '. silent, unusual deposits boost'), + (35, 30761725, 3261729, 6, 28.00, 49985.32, 0.03, 0.02, 'N', 'O', '1996-02-01', '1995-12-24', '1996-02-28', 'COLLECT COD', 'RAIL', 'ly alongside of '), + (36, 119766448, 9766449, 1, 42.00, 63355.32, 0.09, 0.00, 'N', 'O', '1996-02-03', '1996-01-21', '1996-02-23', 'COLLECT COD', 'SHIP', ' careful courts. special '), + (37, 22629071, 5129074, 1, 40.00, 39957.60, 0.09, 0.03, 'A', 'F', '1992-07-21', '1992-08-01', '1992-08-15', 'NONE', 'REG AIR', 'luffily regular requests. slyly final acco'), + (37, 126781276, 1781301, 2, 39.00, 52686.66, 0.05, 0.02, 'A', 'F', '1992-07-02', '1992-08-18', '1992-07-28', 'TAKE BACK RETURN', 'RAIL', 'the final requests. 
ca'), + (37, 12902948, 5402950, 3, 43.00, 83862.90, 0.05, 0.08, 'A', 'F', '1992-07-10', '1992-07-06', '1992-08-02', 'DELIVER IN PERSON', 'TRUCK', 'iously ste'), + (38, 175838801, 838836, 1, 44.00, 76164.44, 0.04, 0.02, 'N', 'O', '1996-09-29', '1996-11-17', '1996-09-30', 'COLLECT COD', 'MAIL', 's. blithely unusual theodolites am'), + (39, 2319664, 9819665, 1, 44.00, 74076.20, 0.09, 0.06, 'N', 'O', '1996-11-14', '1996-12-15', '1996-12-12', 'COLLECT COD', 'RAIL', 'eodolites. careful'), + (39, 186581058, 4081113, 2, 26.00, 29372.98, 0.08, 0.04, 'N', 'O', '1996-11-04', '1996-10-20', '1996-11-20', 'NONE', 'FOB', 'ckages across the slyly silent'), + (39, 67830106, 5330125, 3, 46.00, 47504.66, 0.06, 0.08, 'N', 'O', '1996-09-26', '1996-12-19', '1996-10-26', 'DELIVER IN PERSON', 'AIR', 'he carefully e'), + (39, 20589905, 3089908, 4, 32.00, 63804.16, 0.07, 0.05, 'N', 'O', '1996-10-02', '1996-12-19', '1996-10-14', 'COLLECT COD', 'MAIL', 'heodolites sleep silently pending foxes. ac'), + (39, 54518616, 9518627, 5, 43.00, 70171.27, 0.01, 0.01, 'N', 'O', '1996-10-17', '1996-11-14', '1996-10-26', 'COLLECT COD', 'MAIL', 'yly regular i'), + (39, 94367239, 6867249, 6, 40.00, 52060.80, 0.06, 0.05, 'N', 'O', '1996-12-08', '1996-10-22', '1997-01-01', 'COLLECT COD', 'AIR', 'quickly ironic fox'), + (64, 85950874, 5950875, 1, 21.00, 40332.18, 0.05, 0.02, 'R', 'F', '1994-09-30', '1994-09-18', '1994-10-26', 'DELIVER IN PERSON', 'REG AIR', 'ch slyly final, thin platelets.'), + (65, 59693808, 4693819, 1, 26.00, 46769.32, 0.03, 0.03, 'A', 'F', '1995-04-20', '1995-04-25', '1995-05-13', 'NONE', 'TRUCK', 'pending deposits nag even packages. ca'), + (65, 73814565, 8814580, 2, 22.00, 32469.14, 0.00, 0.05, 'N', 'O', '1995-07-17', '1995-06-04', '1995-07-19', 'COLLECT COD', 'FOB', ' ideas. special, r'), + (65, 1387319, 3887320, 3, 21.00, 29531.25, 0.09, 0.07, 'N', 'O', '1995-07-06', '1995-05-14', '1995-07-31', 'DELIVER IN PERSON', 'RAIL', 'bove the even packages. accounts nag carefu'), + (66, 115117124, 7617136, 1, 31.00, 35196.47, 0.00, 0.08, 'R', 'F', '1994-02-19', '1994-03-11', '1994-02-20', 'TAKE BACK RETURN', 'RAIL', 'ut the unusual accounts sleep at the bo'), + (66, 173488357, 3488358, 2, 41.00, 54803.88, 0.04, 0.07, 'A', 'F', '1994-02-21', '1994-03-01', '1994-03-18', 'COLLECT COD', 'AIR', ' regular de'), + (67, 21635045, 9135052, 1, 4.00, 3915.84, 0.09, 0.04, 'N', 'O', '1997-04-17', '1997-01-31', '1997-04-20', 'NONE', 'SHIP', ' cajole thinly expres'), + (67, 20192396, 5192401, 2, 12.00, 17848.68, 0.09, 0.05, 'N', 'O', '1997-01-27', '1997-02-21', '1997-02-22', 'NONE', 'REG AIR', ' even packages cajole'), + (67, 173599543, 6099561, 3, 5.00, 8169.35, 0.03, 0.07, 'N', 'O', '1997-02-20', '1997-02-12', '1997-02-21', 'DELIVER IN PERSON', 'TRUCK', 'y unusual packages thrash pinto '), + (67, 87513573, 7513574, 4, 44.00, 69616.80, 0.08, 0.06, 'N', 'O', '1997-03-18', '1997-01-29', '1997-04-13', 'DELIVER IN PERSON', 'RAIL', 'se quickly above the even, express reques'), + (67, 40612740, 8112753, 5, 23.00, 37966.33, 0.05, 0.07, 'N', 'O', '1997-04-19', '1997-02-14', '1997-05-06', 'DELIVER IN PERSON', 'REG AIR', 'ly regular deposit'), + (67, 178305451, 805469, 6, 29.00, 41978.66, 0.02, 0.05, 'N', 'O', '1997-01-25', '1997-01-27', '1997-01-27', 'DELIVER IN PERSON', 'FOB', 'ultipliers '), + (68, 7067007, 9567008, 1, 3.00, 2920.95, 0.05, 0.02, 'N', 'O', '1998-07-04', '1998-06-05', '1998-07-21', 'NONE', 'RAIL', 'fully special instructions cajole. 
furious'), + (68, 175179091, 2679143, 2, 46.00, 53421.64, 0.02, 0.05, 'N', 'O', '1998-06-26', '1998-06-07', '1998-07-05', 'NONE', 'MAIL', ' requests are unusual, regular pinto '), + (68, 34979160, 7479164, 3, 46.00, 56921.32, 0.04, 0.05, 'N', 'O', '1998-08-13', '1998-07-08', '1998-08-29', 'NONE', 'RAIL', 'egular dependencies affix ironically along '), + (68, 94727362, 2227390, 4, 20.00, 27692.60, 0.07, 0.01, 'N', 'O', '1998-06-27', '1998-05-23', '1998-07-02', 'NONE', 'REG AIR', ' excuses integrate fluffily '), + (68, 82757337, 5257346, 5, 27.00, 37535.40, 0.03, 0.06, 'N', 'O', '1998-06-19', '1998-06-25', '1998-06-29', 'DELIVER IN PERSON', 'SHIP', 'ccounts. deposits use. furiously'), + (68, 102560793, 5060804, 6, 30.00, 55460.10, 0.05, 0.06, 'N', 'O', '1998-08-11', '1998-07-11', '1998-08-14', 'NONE', 'RAIL', 'oxes are slyly blithely fin'), + (68, 139246458, 1746472, 7, 41.00, 57297.09, 0.09, 0.08, 'N', 'O', '1998-06-24', '1998-06-27', '1998-07-06', 'NONE', 'SHIP', 'eposits nag special ideas. furiousl'), + (69, 115208198, 7708210, 1, 48.00, 52820.64, 0.01, 0.07, 'A', 'F', '1994-08-17', '1994-08-11', '1994-09-08', 'NONE', 'TRUCK', 'regular epitaphs. carefully even ideas hag'), + (69, 104179049, 9179070, 2, 32.00, 35930.88, 0.08, 0.06, 'A', 'F', '1994-08-24', '1994-08-17', '1994-08-31', 'NONE', 'REG AIR', 's sleep carefully bold, '), + (69, 137266467, 4766507, 3, 17.00, 24252.20, 0.09, 0.00, 'A', 'F', '1994-07-02', '1994-07-07', '1994-07-03', 'TAKE BACK RETURN', 'AIR', 'final, pending instr'), + (69, 37501760, 2501767, 4, 3.00, 5279.67, 0.09, 0.04, 'R', 'F', '1994-06-06', '1994-07-27', '1994-06-15', 'NONE', 'MAIL', ' blithely final d'), + (69, 92069882, 7069901, 5, 42.00, 77585.76, 0.07, 0.04, 'R', 'F', '1994-07-31', '1994-07-26', '1994-08-28', 'DELIVER IN PERSON', 'REG AIR', 'tect regular, speci'), + (69, 18503830, 1003832, 6, 23.00, 42156.93, 0.05, 0.00, 'A', 'F', '1994-10-03', '1994-08-06', '1994-10-24', 'NONE', 'SHIP', 'nding accounts ca'), + (70, 64127814, 9127827, 1, 8.00, 14708.88, 0.03, 0.08, 'R', 'F', '1994-01-12', '1994-02-27', '1994-01-14', 'TAKE BACK RETURN', 'FOB', 'ggle. carefully pending dependenc'), + (70, 196155163, 1155202, 2, 13.00, 15708.68, 0.06, 0.06, 'A', 'F', '1994-03-03', '1994-02-13', '1994-03-26', 'COLLECT COD', 'AIR', 'lyly special packag'), + (70, 179808755, 7308807, 3, 1.00, 1854.77, 0.03, 0.05, 'R', 'F', '1994-01-26', '1994-03-05', '1994-01-28', 'TAKE BACK RETURN', 'RAIL', 'quickly. fluffily unusual theodolites c'), + (70, 45733155, 733164, 4, 11.00, 13044.57, 0.01, 0.05, 'A', 'F', '1994-03-17', '1994-03-17', '1994-03-27', 'NONE', 'MAIL', 'alongside of the deposits. fur'), + (70, 37130699, 2130706, 5, 37.00, 63930.08, 0.09, 0.04, 'R', 'F', '1994-02-13', '1994-03-16', '1994-02-21', 'COLLECT COD', 'MAIL', 'n accounts are. q'), + (70, 55654148, 3154164, 6, 19.00, 20887.84, 0.06, 0.03, 'A', 'F', '1994-01-26', '1994-02-17', '1994-02-06', 'TAKE BACK RETURN', 'SHIP', ' packages wake pending accounts.'), + (71, 61930501, 1930502, 1, 25.00, 38210.25, 0.09, 0.07, 'N', 'O', '1998-04-10', '1998-04-22', '1998-04-11', 'COLLECT COD', 'FOB', 'ckly. slyly'), + (71, 65915062, 3415081, 2, 3.00, 3221.31, 0.09, 0.07, 'N', 'O', '1998-05-23', '1998-04-03', '1998-06-02', 'COLLECT COD', 'SHIP', 'y. 
pinto beans haggle after the'), + (71, 34431883, 1931893, 3, 45.00, 81592.20, 0.00, 0.07, 'N', 'O', '1998-02-23', '1998-03-20', '1998-03-24', 'DELIVER IN PERSON', 'SHIP', ' ironic packages believe blithely a'), + (71, 96644449, 9144459, 4, 33.00, 45824.13, 0.00, 0.01, 'N', 'O', '1998-04-12', '1998-03-20', '1998-04-15', 'NONE', 'FOB', ' serve quickly fluffily bold deposi'), + (71, 103254337, 5754348, 5, 39.00, 50160.63, 0.08, 0.06, 'N', 'O', '1998-01-29', '1998-04-07', '1998-02-18', 'DELIVER IN PERSON', 'RAIL', 'l accounts sleep across the pack'), + (71, 195634217, 634256, 6, 34.00, 38808.62, 0.04, 0.01, 'N', 'O', '1998-03-05', '1998-04-22', '1998-03-30', 'DELIVER IN PERSON', 'TRUCK', 's cajole. '), + (96, 123075825, 575862, 1, 23.00, 41277.41, 0.10, 0.06, 'A', 'F', '1994-07-19', '1994-06-29', '1994-07-25', 'DELIVER IN PERSON', 'TRUCK', 'ep-- carefully reg'), + (96, 135389770, 5389771, 2, 30.00, 55590.30, 0.01, 0.06, 'R', 'F', '1994-06-03', '1994-05-29', '1994-06-22', 'DELIVER IN PERSON', 'TRUCK', 'e quickly even ideas. furiou'), + (97, 119476978, 1976990, 1, 13.00, 25337.00, 0.00, 0.02, 'R', 'F', '1993-04-01', '1993-04-04', '1993-04-08', 'NONE', 'TRUCK', 'ayers cajole against the furiously'), + (97, 49567306, 2067311, 2, 37.00, 50720.71, 0.02, 0.06, 'A', 'F', '1993-04-13', '1993-03-30', '1993-04-14', 'DELIVER IN PERSON', 'SHIP', 'ic requests boost carefully quic'), + (97, 77698944, 5198966, 3, 19.00, 36842.14, 0.06, 0.08, 'R', 'F', '1993-05-14', '1993-03-05', '1993-05-25', 'TAKE BACK RETURN', 'RAIL', 'gifts. furiously ironic packages cajole. '), + (98, 40215967, 215968, 1, 28.00, 52666.60, 0.06, 0.07, 'A', 'F', '1994-12-24', '1994-10-25', '1995-01-16', 'COLLECT COD', 'REG AIR', ' pending, regular accounts s'), + (98, 109742650, 7242681, 2, 1.00, 1687.17, 0.00, 0.00, 'A', 'F', '1994-12-01', '1994-12-12', '1994-12-15', 'DELIVER IN PERSON', 'TRUCK', '. unusual instructions against'), + (98, 44705610, 4705611, 3, 14.00, 22587.32, 0.05, 0.02, 'A', 'F', '1994-12-30', '1994-11-22', '1995-01-27', 'COLLECT COD', 'AIR', ' cajole furiously. blithely ironic ideas '), + (98, 167179412, 7179413, 4, 10.00, 14830.60, 0.03, 0.03, 'A', 'F', '1994-10-23', '1994-11-08', '1994-11-09', 'COLLECT COD', 'RAIL', ' carefully. quickly ironic ideas'), + (99, 87113927, 4613952, 1, 10.00, 19365.70, 0.02, 0.01, 'A', 'F', '1994-05-18', '1994-06-03', '1994-05-23', 'COLLECT COD', 'RAIL', 'kages. requ'), + (99, 123765936, 3765937, 2, 5.00, 9978.75, 0.02, 0.07, 'R', 'F', '1994-05-06', '1994-05-28', '1994-05-20', 'TAKE BACK RETURN', 'RAIL', 'ests cajole fluffily waters. blithe'), + (99, 134081534, 1581574, 3, 42.00, 63370.86, 0.02, 0.02, 'A', 'F', '1994-04-19', '1994-05-18', '1994-04-20', 'NONE', 'RAIL', 'kages are fluffily furiously ir'), + (99, 108337010, 837021, 4, 36.00, 37497.60, 0.09, 0.02, 'A', 'F', '1994-07-04', '1994-04-17', '1994-07-30', 'DELIVER IN PERSON', 'AIR', 'slyly. slyly e'), + (100, 62028678, 2028679, 1, 28.00, 44899.96, 0.04, 0.05, 'N', 'O', '1998-05-08', '1998-05-13', '1998-06-07', 'COLLECT COD', 'TRUCK', 'sts haggle. slowl'), + (100, 115978233, 8478245, 2, 22.00, 28719.68, 0.00, 0.07, 'N', 'O', '1998-06-24', '1998-04-12', '1998-06-29', 'DELIVER IN PERSON', 'SHIP', 'nto beans alongside of the fi'), + (100, 46149701, 8649706, 3, 46.00, 80426.40, 0.03, 0.04, 'N', 'O', '1998-05-02', '1998-04-10', '1998-05-22', 'TAKE BACK RETURN', 'SHIP', 'ular accounts. even'), + (100, 38023053, 3023060, 4, 14.00, 13638.10, 0.06, 0.03, 'N', 'O', '1998-05-22', '1998-05-01', '1998-06-03', 'COLLECT COD', 'MAIL', 'y. 
furiously ironic ideas gr'), + (100, 53438259, 938275, 5, 37.00, 44199.46, 0.05, 0.00, 'N', 'O', '1998-03-06', '1998-04-16', '1998-03-31', 'TAKE BACK RETURN', 'TRUCK', 'nd the quickly s'), + (101, 118281867, 5781901, 1, 49.00, 90304.55, 0.10, 0.00, 'N', 'O', '1996-06-21', '1996-05-27', '1996-06-29', 'DELIVER IN PERSON', 'REG AIR', 'ts-- final packages sleep furiousl'), + (101, 163333041, 833090, 2, 36.00, 38371.68, 0.00, 0.01, 'N', 'O', '1996-05-19', '1996-05-01', '1996-06-04', 'DELIVER IN PERSON', 'AIR', 'tes. blithely pending dolphins x-ray f'), + (101, 138417252, 5917292, 3, 12.00, 13947.96, 0.06, 0.02, 'N', 'O', '1996-03-29', '1996-04-20', '1996-04-12', 'COLLECT COD', 'MAIL', '. quickly regular'), + (102, 88913503, 3913520, 1, 37.00, 55946.22, 0.06, 0.00, 'N', 'O', '1997-07-24', '1997-08-02', '1997-08-07', 'TAKE BACK RETURN', 'SHIP', 'ully across the ideas. final deposit'), + (102, 169237956, 6738005, 2, 34.00, 64106.66, 0.03, 0.08, 'N', 'O', '1997-08-09', '1997-07-28', '1997-08-26', 'TAKE BACK RETURN', 'SHIP', 'eposits cajole across'), + (102, 182320531, 4820550, 3, 25.00, 38560.50, 0.01, 0.01, 'N', 'O', '1997-07-31', '1997-07-24', '1997-08-17', 'NONE', 'RAIL', 'bits. ironic accoun'), + (102, 61157984, 8658003, 4, 15.00, 30583.95, 0.07, 0.07, 'N', 'O', '1997-06-02', '1997-07-13', '1997-06-04', 'DELIVER IN PERSON', 'SHIP', 'final packages. carefully even excu'), + (103, 194657609, 2157667, 1, 6.00, 9341.22, 0.03, 0.05, 'N', 'O', '1996-10-11', '1996-07-25', '1996-10-28', 'NONE', 'FOB', 'cajole. carefully ex'), + (103, 10425920, 2925922, 2, 37.00, 68279.80, 0.02, 0.07, 'N', 'O', '1996-09-17', '1996-07-27', '1996-09-20', 'TAKE BACK RETURN', 'MAIL', 'ies. quickly ironic requests use blithely'), + (103, 28430358, 8430359, 3, 23.00, 29599.39, 0.01, 0.04, 'N', 'O', '1996-09-11', '1996-09-18', '1996-09-26', 'NONE', 'FOB', 'ironic accou'), + (103, 29021558, 4021563, 4, 32.00, 47299.20, 0.01, 0.07, 'N', 'O', '1996-07-30', '1996-08-06', '1996-08-04', 'NONE', 'RAIL', 'kages doze. special, regular deposit'), + (128, 106827451, 9327462, 1, 38.00, 52178.18, 0.06, 0.01, 'A', 'F', '1992-09-01', '1992-08-27', '1992-10-01', 'TAKE BACK RETURN', 'FOB', ' cajole careful'), + (129, 2866970, 5366971, 1, 46.00, 89094.18, 0.08, 0.02, 'R', 'F', '1993-02-15', '1993-01-24', '1993-03-05', 'COLLECT COD', 'TRUCK', 'uietly bold theodolites. fluffil'), + (129, 185163292, 5163293, 2, 36.00, 48457.44, 0.01, 0.02, 'A', 'F', '1992-11-25', '1992-12-25', '1992-12-09', 'TAKE BACK RETURN', 'REG AIR', 'packages are care'), + (129, 39443990, 1943994, 3, 33.00, 63756.66, 0.04, 0.06, 'A', 'F', '1993-01-08', '1993-02-14', '1993-01-29', 'COLLECT COD', 'SHIP', 'sts nag bravely. fluffily'), + (129, 135136037, 136064, 4, 34.00, 36253.52, 0.00, 0.01, 'R', 'F', '1993-01-29', '1993-02-14', '1993-02-10', 'COLLECT COD', 'MAIL', 'quests. express ideas'), + (129, 31372467, 8872477, 5, 24.00, 36909.60, 0.06, 0.00, 'A', 'F', '1992-12-07', '1993-01-02', '1992-12-11', 'TAKE BACK RETURN', 'FOB', 'uests. foxes cajole slyly after the ca'), + (129, 77049359, 4549381, 6, 22.00, 28699.00, 0.06, 0.01, 'R', 'F', '1993-02-15', '1993-01-31', '1993-02-24', 'COLLECT COD', 'SHIP', 'e. fluffily regular '), + (129, 168568384, 3568417, 7, 1.00, 1443.96, 0.05, 0.04, 'R', 'F', '1993-01-26', '1993-01-08', '1993-02-24', 'DELIVER IN PERSON', 'FOB', 'e carefully blithely bold dolp'), + (130, 128815478, 8815479, 1, 14.00, 19418.42, 0.08, 0.05, 'A', 'F', '1992-08-15', '1992-07-25', '1992-09-13', 'COLLECT COD', 'RAIL', ' requests. 
final instruction'), + (130, 1738077, 4238078, 2, 48.00, 53519.52, 0.03, 0.02, 'R', 'F', '1992-07-01', '1992-07-12', '1992-07-24', 'NONE', 'AIR', 'lithely alongside of the regu'), + (130, 11859085, 1859086, 3, 18.00, 18782.82, 0.04, 0.08, 'A', 'F', '1992-07-04', '1992-06-14', '1992-07-29', 'DELIVER IN PERSON', 'MAIL', ' slyly ironic decoys abou'), + (130, 115634506, 3134540, 4, 13.00, 18651.36, 0.09, 0.02, 'R', 'F', '1992-06-26', '1992-07-29', '1992-07-05', 'NONE', 'FOB', ' pending dolphins sleep furious'), + (130, 69129320, 4129333, 5, 31.00, 41721.97, 0.06, 0.05, 'R', 'F', '1992-09-01', '1992-07-18', '1992-09-02', 'TAKE BACK RETURN', 'RAIL', 'thily about the ruth'), + (131, 167504270, 4287, 1, 45.00, 56965.50, 0.10, 0.02, 'R', 'F', '1994-09-14', '1994-09-02', '1994-10-04', 'NONE', 'FOB', 'ironic, bold accounts. careful'), + (131, 44254717, 9254726, 2, 50.00, 83475.00, 0.02, 0.04, 'A', 'F', '1994-09-17', '1994-08-10', '1994-09-21', 'NONE', 'SHIP', 'ending requests. final, ironic pearls slee'), + (131, 189020323, 1520342, 3, 4.00, 4935.48, 0.04, 0.03, 'A', 'F', '1994-09-20', '1994-08-30', '1994-09-23', 'COLLECT COD', 'REG AIR', ' are carefully slyly i'), + (132, 140448567, 2948582, 1, 18.00, 27153.72, 0.00, 0.08, 'R', 'F', '1993-07-10', '1993-08-05', '1993-07-13', 'NONE', 'TRUCK', 'ges. platelets wake furio'), + (132, 119052444, 9052445, 2, 43.00, 59791.07, 0.01, 0.08, 'R', 'F', '1993-09-01', '1993-08-16', '1993-09-22', 'NONE', 'TRUCK', 'y pending theodolites'), + (132, 114418283, 4418284, 3, 32.00, 38257.92, 0.04, 0.04, 'A', 'F', '1993-07-12', '1993-08-05', '1993-08-05', 'COLLECT COD', 'TRUCK', 'd instructions hagg'), + (132, 28081909, 5581916, 4, 23.00, 43458.50, 0.10, 0.00, 'A', 'F', '1993-06-16', '1993-08-27', '1993-06-23', 'DELIVER IN PERSON', 'AIR', 'refully blithely bold acco'), + (133, 103431682, 5931693, 1, 27.00, 43429.77, 0.00, 0.02, 'N', 'O', '1997-12-21', '1998-02-23', '1997-12-27', 'TAKE BACK RETURN', 'MAIL', 'yly even gifts after the sl'), + (133, 176278774, 3778826, 2, 12.00, 20927.52, 0.02, 0.06, 'N', 'O', '1997-12-02', '1998-01-15', '1997-12-29', 'DELIVER IN PERSON', 'REG AIR', 'ts cajole fluffily quickly i'), + (133, 117349311, 4849345, 3, 29.00, 39279.05, 0.09, 0.08, 'N', 'O', '1998-02-28', '1998-01-30', '1998-03-09', 'DELIVER IN PERSON', 'RAIL', ' the carefully regular theodoli'), + (133, 89854644, 7354669, 4, 11.00, 17535.65, 0.06, 0.01, 'N', 'O', '1998-03-21', '1998-01-15', '1998-04-04', 'DELIVER IN PERSON', 'REG AIR', 'e quickly across the dolphins'), + (134, 640486, 640487, 1, 21.00, 29955.45, 0.00, 0.03, 'A', 'F', '1992-07-17', '1992-07-08', '1992-07-26', 'COLLECT COD', 'SHIP', 's. quickly regular'), + (134, 164644985, 9645018, 2, 35.00, 67261.25, 0.06, 0.07, 'A', 'F', '1992-08-23', '1992-06-01', '1992-08-24', 'NONE', 'MAIL', 'ajole furiously. instructio'), + (134, 188251562, 3251599, 3, 26.00, 39107.90, 0.09, 0.06, 'A', 'F', '1992-06-20', '1992-07-12', '1992-07-16', 'NONE', 'RAIL', ' among the pending depos'), + (134, 144001617, 4001618, 4, 47.00, 80436.74, 0.05, 0.00, 'A', 'F', '1992-08-16', '1992-07-06', '1992-08-28', 'NONE', 'REG AIR', 's! 
carefully unusual requests boost careful'), + (134, 35171840, 5171841, 5, 12.00, 22921.08, 0.05, 0.02, 'A', 'F', '1992-07-03', '1992-06-01', '1992-07-11', 'COLLECT COD', 'TRUCK', 'nts are quic'); diff --git a/tests/resources/datasets/nation.sql b/tests/resources/datasets/nation.sql new file mode 100644 index 0000000000..9985dfa434 --- /dev/null +++ b/tests/resources/datasets/nation.sql @@ -0,0 +1,36 @@ +------------------------------------------------------------ +CREATE TABLE nation ( + n_nationkey BIGINT, + n_name VARCHAR(500), + n_regionkey BIGINT, + n_comment VARCHAR(500) + ); + +------------------------------------------------------------ + +INSERT INTO nation (n_nationkey, n_name, n_regionkey, n_comment) VALUES + (0, 'ALGERIA', 0, ' haggle. carefully final deposits detect slyly agai'), + (1, 'ARGENTINA', 1, 'al foxes promise slyly according to the regular accounts. bold requests alon'), + (2, 'BRAZIL', 1, 'y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special '), + (3, 'CANADA', 1, 'eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold'), + (4, 'EGYPT', 4, 'y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d'), + (5, 'ETHIOPIA', 0, 'ven packages wake quickly. regu'), + (6, 'FRANCE', 3, 'refully final requests. regular, ironi'), + (7, 'GERMANY', 3, 'l platelets. regular accounts x-ray: unusual, regular acco'), + (8, 'INDIA', 2, 'ss excuses cajole slyly across the packages. deposits print aroun'), + (9, 'INDONESIA', 2, ' slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull'), + (10, 'IRAN', 4, 'efully alongside of the slyly final dependencies. '), + (11, 'IRAQ', 4, 'nic deposits boost atop the quickly final requests? quickly regula'), + (12, 'JAPAN', 2, 'ously. final, express gifts cajole a'), + (13, 'JORDAN', 4, 'ic deposits are blithely about the carefully regular pa'), + (14, 'KENYA', 0, ' pending excuses haggle furiously deposits. pending, express pinto beans wake fluffily past t'), + (15, 'MOROCCO', 0, 'rns. blithely bold courts among the closely regular packages use furiously bold platelets?'), + (16, 'MOZAMBIQUE', 0, 's. ironic, unusual asymptotes wake blithely r'), + (17, 'PERU', 1, 'platelets. blithely pending dependencies use fluffily across the even pinto beans. carefully silent accoun'), + (18, 'CHINA', 2, 'c dependencies. furiously express notornis sleep slyly regular accounts. ideas sleep. depos'), + (19, 'ROMANIA', 3, 'ular asymptotes are about the furious multipliers. express dependencies nag above the ironically ironic account'), + (20, 'SAUDI ARABIA', 4, 'ts. silent requests haggle. closely express packages sleep across the blithely'), + (21, 'VIETNAM', 2, 'hely enticingly express accounts. even, final '), + (22, 'RUSSIA', 3, ' requests against the platelets use never according to the quickly regular pint'), + (23, 'UNITED KINGDOM', 3, 'eans boost carefully special requests. accounts are. carefull'), + (24, 'UNITED STATES', 1, 'y final packages. slow foxes cajole quickly. quickly silent platelets breach ironic accounts. 
unusual pinto be'); diff --git a/tests/resources/datasets/order.sql b/tests/resources/datasets/order.sql new file mode 100644 index 0000000000..9427667fd0 --- /dev/null +++ b/tests/resources/datasets/order.sql @@ -0,0 +1,116 @@ +---------------------------------------------- +CREATE TABLE orders ( + o_orderkey BIGINT, + o_custkey BIGINT, + o_orderstatus VARCHAR(500), + o_totalprice DECIMAL(18,2), + o_orderdate DATE, + o_orderpriority VARCHAR(500), + o_clerk STRING, + o_shippriority INT, + o_comment VARCHAR(500) + ); + +---------------------------------------------- + +INSERT INTO orders (o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment) VALUES + (1, 36899983, 'O', 192726.27, '1996-01-02', '5-LOW', 'Clerk#000950550', 0, 'nstructions sleep furiously among '), + (2, 78001627, 'O', 71278.15, '1996-12-01', '1-URGENT', 'Clerk#000879155', 0, ' foxes. pending accounts at the pending, silent asymptot'), + (3, 123313904, 'F', 245783.08, '1993-10-14', '5-LOW', 'Clerk#000954253', 0, 'sly final accounts boost. carefully regular ideas cajole carefully. depos'), + (4, 136776016, 'O', 50733.58, '1995-10-11', '5-LOW', 'Clerk#000123397', 0, 'sits. slyly regular warthogs cajole. regular, regular theodolites acro'), + (5, 44484776, 'F', 125096.86, '1994-07-30', '5-LOW', 'Clerk#000924797', 0, 'quickly. bold deposits sleep slyly. packages use slyly'), + (6, 55622011, 'F', 48506.10, '1992-02-21', '4-NOT SPECIFIED', 'Clerk#000057980', 0, 'ggle. special, final requests are against the furiously specia'), + (7, 39134293, 'O', 249584.44, '1996-01-10', '2-HIGH', 'Clerk#000469603', 0, 'ly special requests '), + (32, 130056940, 'O', 144289.97, '1995-07-16', '2-HIGH', 'Clerk#000615602', 0, 'ise blithely bold, regular requests. quickly unusual dep'), + (33, 66957875, 'F', 155195.60, '1993-10-27', '3-MEDIUM', 'Clerk#000408594', 0, 'uriously. furiously final request'), + (34, 61000036, 'O', 52712.78, '1998-07-21', '3-MEDIUM', 'Clerk#000222780', 0, 'ly final packages. fluffily final deposits wake blithely ideas. spe'), + (35, 127587628, 'O', 227783.45, '1995-10-23', '4-NOT SPECIFIED', 'Clerk#000258366', 0, 'zzle. carefully enticing deposits nag furio'), + (36, 115251748, 'O', 57653.34, '1995-11-03', '1-URGENT', 'Clerk#000357081', 0, ' quick packages are blithely. slyly silent accounts wake qu'), + (37, 86114407, 'F', 174548.94, '1992-06-03', '3-MEDIUM', 'Clerk#000455561', 0, 'kly regular pinto beans. carefully unusual waters cajole never'), + (38, 124827856, 'O', 74580.21, '1996-08-21', '4-NOT SPECIFIED', 'Clerk#000603657', 0, 'haggle blithely. furiously express ideas haggle blithely furiously regular re'), + (39, 81762487, 'O', 331637.68, '1996-09-20', '3-MEDIUM', 'Clerk#000658575', 0, 'ole express, ironic requests: ir'), + (64, 32112589, 'F', 39081.88, '1994-07-16', '3-MEDIUM', 'Clerk#000660299', 0, 'wake fluffily. sometimes ironic pinto beans about the dolphin'), + (65, 16250920, 'P', 109574.38, '1995-03-18', '1-URGENT', 'Clerk#000631984', 0, 'ular requests are blithely pending orbits-- even requests against the deposit'), + (66, 129199537, 'F', 94306.72, '1994-01-20', '5-LOW', 'Clerk#000742283', 0, 'y pending requests integrate'), + (67, 56612041, 'O', 178918.38, '1996-12-19', '4-NOT SPECIFIED', 'Clerk#000546406', 0, 'symptotes haggle slyly around the furiously iron'), + (68, 28546706, 'O', 291943.34, '1998-04-18', '3-MEDIUM', 'Clerk#000439361', 0, ' pinto beans sleep carefully. 
blithely ironic deposits haggle furiously acro'), + (69, 84486485, 'F', 233148.86, '1994-06-04', '4-NOT SPECIFIED', 'Clerk#000329882', 0, ' depths atop the slyly thin deposits detect among the furiously silent accou'), + (70, 64339465, 'F', 127237.02, '1993-12-18', '5-LOW', 'Clerk#000321668', 0, ' carefully ironic request'), + (71, 3371083, 'O', 260473.37, '1998-01-24', '4-NOT SPECIFIED', 'Clerk#000270150', 0, ' express deposits along the blithely regul'), + (96, 107777776, 'F', 97715.08, '1994-04-17', '2-HIGH', 'Clerk#000394512', 0, 'oost furiously. pinto'), + (97, 21059945, 'F', 115934.53, '1993-01-29', '3-MEDIUM', 'Clerk#000546823', 0, 'hang blithely along the regular accounts. furiously even ideas after the'), + (98, 104479618, 'F', 91363.58, '1994-09-25', '1-URGENT', 'Clerk#000447788', 0, 'c asymptotes. quickly regular packages should have to nag re'), + (99, 88909205, 'F', 127782.63, '1994-03-13', '4-NOT SPECIFIED', 'Clerk#000972249', 0, 'e carefully ironic packages. pending'), + (100, 147002917, 'O', 212317.22, '1998-02-28', '4-NOT SPECIFIED', 'Clerk#000576035', 0, 'heodolites detect slyly alongside of the ent'), + (101, 27997306, 'O', 133402.78, '1996-03-17', '3-MEDIUM', 'Clerk#000418136', 0, 'ding accounts above the slyly final asymptote'), + (102, 715720, 'O', 188738.28, '1997-05-09', '2-HIGH', 'Clerk#000595175', 0, ' slyly according to the asymptotes. carefully final packages integrate furious'), + (103, 29099422, 'O', 161691.76, '1996-06-20', '4-NOT SPECIFIED', 'Clerk#000089629', 0, 'ges. carefully unusual instructions haggle quickly regular f'), + (128, 73955563, 'F', 49537.95, '1992-06-15', '1-URGENT', 'Clerk#000384118', 0, 'ns integrate fluffily. ironic asymptotes after the regular excuses nag around '), + (129, 71133277, 'F', 297401.58, '1992-11-19', '5-LOW', 'Clerk#000858692', 0, 'ing tithes. carefully pending deposits boost about the silently express '), + (130, 36963334, 'F', 149676.16, '1992-05-08', '2-HIGH', 'Clerk#000035202', 0, 'le slyly unusual, regular packages? express deposits det'), + (131, 92748641, 'F', 142252.24, '1994-06-08', '3-MEDIUM', 'Clerk#000624829', 0, 'after the fluffily special foxes integrate s'), + (132, 26393408, 'F', 170563.96, '1993-06-11', '3-MEDIUM', 'Clerk#000487044', 0, 'sits are daringly accounts. carefully regular foxes sleep slyly about the'), + (133, 43999894, 'O', 121289.63, '1997-11-29', '1-URGENT', 'Clerk#000737236', 0, 'usly final asymptotes '), + (134, 6197447, 'F', 247819.51, '1992-05-01', '4-NOT SPECIFIED', 'Clerk#000710943', 0, 'lar theodolites boos'), + (135, 60479711, 'O', 263545.46, '1995-10-21', '4-NOT SPECIFIED', 'Clerk#000803377', 0, 'l platelets use according t'), + (160, 82493404, 'O', 157852.72, '1996-12-19', '4-NOT SPECIFIED', 'Clerk#000341306', 0, 'thely special sauternes wake slyly of t'), + (161, 16618916, 'F', 21545.55, '1994-08-31', '2-HIGH', 'Clerk#000321258', 0, 'carefully! special instructions sin'), + (162, 14115415, 'O', 3538.21, '1995-05-08', '3-MEDIUM', 'Clerk#000377842', 0, 'nts hinder fluffily ironic instructions. express, express excuses '), + (163, 87758125, 'O', 187536.78, '1997-09-05', '3-MEDIUM', 'Clerk#000378308', 0, 'y final packages. final foxes since the quickly even'), + (164, 778084, 'F', 291093.87, '1992-10-21', '5-LOW', 'Clerk#000208621', 0, 'cajole ironic courts. slyly final ideas are slyly. blithely final Tiresias sub'), + (165, 27236141, 'F', 197970.00, '1993-01-30', '4-NOT SPECIFIED', 'Clerk#000291810', 0, 'across the blithely regular accounts. 
bold'), + (166, 107810401, 'O', 145568.40, '1995-09-12', '2-HIGH', 'Clerk#000439087', 0, 'lets. ironic, bold asymptotes kindle'), + (167, 119401594, 'F', 87123.86, '1993-01-04', '4-NOT SPECIFIED', 'Clerk#000730677', 0, 's nag furiously bold excuses. fluffily iron'), + (192, 82568518, 'O', 186074.82, '1997-11-25', '5-LOW', 'Clerk#000482777', 0, 'y unusual platelets among the final instructions integrate rut'), + (193, 79060019, 'F', 61949.53, '1993-08-08', '1-URGENT', 'Clerk#000024535', 0, 'the furiously final pin'), + (194, 61723415, 'F', 167377.65, '1992-04-05', '3-MEDIUM', 'Clerk#000351236', 0, 'egular requests haggle slyly regular, regular pinto beans. asymptote'), + (195, 135420649, 'F', 186359.22, '1993-12-28', '3-MEDIUM', 'Clerk#000215922', 0, 'old forges are furiously sheaves. slyly fi'), + (196, 64823140, 'F', 52884.74, '1993-03-17', '2-HIGH', 'Clerk#000987147', 0, 'beans boost at the foxes. silent foxes'), + (197, 32510872, 'P', 145084.27, '1995-04-07', '2-HIGH', 'Clerk#000968307', 0, 'solve quickly about the even braids. carefully express deposits affix care'), + (198, 110216159, 'O', 212946.03, '1998-01-02', '4-NOT SPECIFIED', 'Clerk#000330844', 0, 'its. carefully ironic requests sleep. furiously express fox'), + (199, 52969150, 'O', 137876.35, '1996-03-07', '2-HIGH', 'Clerk#000488764', 0, 'g theodolites. special packag'), + (224, 2474150, 'F', 249317.98, '1994-06-18', '4-NOT SPECIFIED', 'Clerk#000641410', 0, 'r the quickly thin courts. carefully'), + (225, 33029540, 'P', 238678.08, '1995-05-25', '1-URGENT', 'Clerk#000176282', 0, 's. blithely ironic accounts wake quickly fluffily special acc'), + (226, 127465336, 'F', 324618.04, '1993-03-10', '2-HIGH', 'Clerk#000755990', 0, 's are carefully at the blithely ironic acc'), + (227, 9882436, 'O', 69285.58, '1995-11-10', '5-LOW', 'Clerk#000918223', 0, ' express instructions. slyly regul'), + (228, 44077571, 'F', 3811.67, '1993-02-25', '1-URGENT', 'Clerk#000561590', 0, 'es was slyly among the regular foxes. blithely regular dependenci'), + (229, 111727408, 'F', 226752.27, '1993-12-29', '1-URGENT', 'Clerk#000627687', 0, 'he fluffily even instructions. furiously i'), + (230, 102534112, 'F', 196932.63, '1993-10-27', '1-URGENT', 'Clerk#000519929', 0, 'odolites. carefully quick requ'), + (231, 90817804, 'F', 176997.79, '1994-09-29', '2-HIGH', 'Clerk#000445936', 0, ' packages haggle slyly after the carefully ironic instruct'), + (256, 124830292, 'F', 158742.21, '1993-10-19', '4-NOT SPECIFIED', 'Clerk#000833325', 0, 'he fluffily final ideas might are final accounts. carefully f'), + (257, 122692442, 'O', 6462.29, '1998-03-28', '3-MEDIUM', 'Clerk#000679132', 0, 'ts against the sly warhorses cajole slyly accounts'), + (258, 41860129, 'F', 252298.36, '1993-12-29', '1-URGENT', 'Clerk#000166217', 0, 'dencies. blithely quick packages cajole. ruthlessly final accounts'), + (259, 43184669, 'F', 121177.65, '1993-09-29', '4-NOT SPECIFIED', 'Clerk#000600938', 0, 'ages doubt blithely against the final foxes. carefully express deposits dazzle'), + (260, 104726965, 'O', 252588.62, '1996-12-10', '3-MEDIUM', 'Clerk#000959560', 0, 'lently regular pinto beans sleep after the slyly e'), + (261, 46071896, 'F', 303454.20, '1993-06-29', '3-MEDIUM', 'Clerk#000309250', 0, 'ully fluffily brave instructions. furiousl'), + (262, 30351406, 'O', 198674.95, '1995-11-25', '4-NOT SPECIFIED', 'Clerk#000550312', 0, 'l packages. blithely final pinto beans use carefu'), + (263, 116068132, 'F', 136416.46, '1994-05-17', '2-HIGH', 'Clerk#000087768', 0, ' pending instructions. 
blithely un'), + (288, 7092508, 'O', 273779.21, '1997-02-21', '1-URGENT', 'Clerk#000108295', 0, 'uriously final requests. even, final ideas det'), + (289, 103749508, 'O', 215349.07, '1997-02-10', '3-MEDIUM', 'Clerk#000102799', 0, 'sily. slyly special excuse'), + (290, 117950323, 'F', 99516.47, '1994-01-01', '4-NOT SPECIFIED', 'Clerk#000734987', 0, 'efully dogged deposits. furiou'), + (291, 141049697, 'F', 91007.26, '1994-03-13', '1-URGENT', 'Clerk#000922485', 0, 'dolites. carefully regular pinto beans cajol'), + (292, 22251439, 'F', 52236.30, '1992-01-13', '2-HIGH', 'Clerk#000192458', 0, 'g pinto beans will have to sleep f'), + (293, 29928037, 'F', 67168.42, '1992-10-02', '2-HIGH', 'Clerk#000628280', 0, 're bold, ironic deposits. platelets c'), + (294, 50497948, 'F', 43160.20, '1993-07-16', '3-MEDIUM', 'Clerk#000498135', 0, 'kly according to the frays. final dolphins affix quickly '), + (295, 18983180, 'F', 140545.02, '1994-09-29', '2-HIGH', 'Clerk#000154315', 0, ' unusual pinto beans play. regular ideas haggle'), + (320, 301816, 'O', 50366.95, '1997-11-21', '2-HIGH', 'Clerk#000572041', 0, 'ar foxes nag blithely'), + (321, 122590799, 'F', 102660.33, '1993-03-21', '3-MEDIUM', 'Clerk#000288878', 0, 'equests run. blithely final dependencies after the deposits wake caref'), + (322, 133544900, 'F', 168881.52, '1992-03-19', '1-URGENT', 'Clerk#000157254', 0, 'fully across the slyly bold packages. packages against the quickly regular i'), + (323, 39132812, 'F', 138213.59, '1994-03-26', '1-URGENT', 'Clerk#000958670', 0, 'arefully pending foxes sleep blithely. slyly express accoun'), + (324, 105154997, 'F', 46204.82, '1992-03-20', '1-URGENT', 'Clerk#000351868', 0, ' about the ironic, regular deposits run blithely against the excuses'), + (325, 40023562, 'F', 103112.84, '1993-10-17', '5-LOW', 'Clerk#000843427', 0, 'ly sometimes pending pa'), + (326, 75985876, 'O', 302704.68, '1995-06-04', '2-HIGH', 'Clerk#000465044', 0, ' requests. furiously ironic asymptotes mold carefully alongside of the blit'), + (327, 144597874, 'P', 34736.85, '1995-04-17', '5-LOW', 'Clerk#000991369', 0, 'ng the slyly final courts. slyly even escapades eat '), + (352, 106455751, 'F', 18839.74, '1994-03-08', '2-HIGH', 'Clerk#000931484', 0, 'ke slyly bold pinto beans. blithely regular accounts against the spe'), + (353, 1775348, 'F', 265740.68, '1993-12-31', '5-LOW', 'Clerk#000448118', 0, ' quiet ideas sleep. even instructions cajole slyly. silently spe'), + (354, 138267985, 'O', 271440.62, '1996-03-14', '2-HIGH', 'Clerk#000510117', 0, 'ly regular ideas wake across the slyly silent ideas. final deposits eat b'), + (355, 70006357, 'F', 105671.63, '1994-06-14', '5-LOW', 'Clerk#000531085', 0, 's. sometimes regular requests cajole. regular, pending accounts a'), + (356, 146808497, 'F', 252238.28, '1994-06-30', '4-NOT SPECIFIED', 'Clerk#000943673', 0, 'as wake along the bold accounts. even, '), + (357, 60394615, 'O', 137799.94, '1996-10-09', '2-HIGH', 'Clerk#000300672', 0, 'e blithely about the express, final accounts. quickl'), + (358, 2289380, 'F', 351172.09, '1993-09-20', '2-HIGH', 'Clerk#000391917', 0, 'l, silent instructions are slyly. silently even de'), + (359, 77599394, 'F', 202583.31, '1994-12-19', '3-MEDIUM', 'Clerk#000933731', 0, 'n dolphins. special courts above the carefully ironic requests use'), + (384, 113008447, 'F', 182911.31, '1992-03-03', '5-LOW', 'Clerk#000205724', 0, ', even accounts use furiously packages. 
slyly ironic pla'), + (385, 32945209, 'O', 97475.23, '1996-03-22', '5-LOW', 'Clerk#000599075', 0, 'hless accounts unwind bold pain'), + (386, 60109081, 'F', 151257.58, '1995-01-25', '2-HIGH', 'Clerk#000647931', 0, ' haggle quickly. stealthily bold asymptotes haggle among the furiously even re'), + (387, 3295930, 'O', 202423.80, '1997-01-26', '4-NOT SPECIFIED', 'Clerk#000767975', 0, ' are carefully among the quickly even deposits. furiously silent req'), + (388, 44667518, 'F', 193416.33, '1992-12-16', '4-NOT SPECIFIED', 'Clerk#000355749', 0, 'ar foxes above the furiously ironic deposits nag slyly final reque'); diff --git a/tests/resources/datasets/part.sql b/tests/resources/datasets/part.sql new file mode 100644 index 0000000000..b360150ebf --- /dev/null +++ b/tests/resources/datasets/part.sql @@ -0,0 +1,116 @@ +-------------------------------------------------------------------------------- +CREATE TABLE part ( + p_partkey BIGINT, + p_name VARCHAR(500), + p_mfgr VARCHAR(500), + p_brand VARCHAR(500), + p_type VARCHAR(500), + p_size INT, + p_container VARCHAR(500), + p_retailprice DECIMAL(18,2), + p_comment VARCHAR(500) + ); + +-------------------------------------------------------------------------------- + +INSERT INTO part (p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment) VALUES +(449928, 'violet pale medium cyan maroon', 'Manufacturer#4', 'Brand#43', 'ECONOMY PLATED BRASS', 11, 'SM BAG', 1877.90, 'ely regular a'), + (640486, 'bisque sandy sky floral frosted', 'Manufacturer#1', 'Brand#15', 'LARGE POLISHED STEEL', 20, 'MED CASE', 1426.45, 'ts. final pinto bean'), + (1387319, 'plum azure coral spring maroon', 'Manufacturer#3', 'Brand#31', 'PROMO BURNISHED TIN', 4, 'MED BOX', 1406.25, 'nding, un'), + (1738077, 'puff dim powder misty mint', 'Manufacturer#2', 'Brand#24', 'STANDARD PLATED STEEL', 23, 'JUMBO CASE', 1114.99, 'ely ironic deposi'), + (2131495, 'aquamarine ghost peru forest rose', 'Manufacturer#1', 'Brand#15', 'MEDIUM BRUSHED COPPER', 11, 'MED CAN', 1526.39, 'ic ideas cajole q'), + (2319664, 'cornflower hot black tan steel', 'Manufacturer#1', 'Brand#14', 'ECONOMY POLISHED COPPER', 7, 'SM PACK', 1683.55, ' final'), + (2742061, 'firebrick chartreuse snow bisque peru', 'Manufacturer#2', 'Brand#22', 'SMALL BRUSHED TIN', 7, 'JUMBO JAR', 1102.93, 'across the quickly fin'), + (2866970, 'cream chartreuse turquoise cyan white', 'Manufacturer#4', 'Brand#41', 'SMALL PLATED TIN', 44, 'MED BAG', 1936.83, 'gedly r'), + (4296962, 'dark sky wheat thistle gainsboro', 'Manufacturer#1', 'Brand#11', 'PROMO BRUSHED STEEL', 43, 'LG BAG', 1958.75, ' requests nag bl'), + (7067007, 'ghost lemon blue turquoise yellow', 'Manufacturer#3', 'Brand#31', 'MEDIUM BRUSHED TIN', 50, 'WRAP CAN', 973.65, 'ly fluffy'), + (10425920, 'pink cyan royal cream lime', 'Manufacturer#1', 'Brand#14', 'MEDIUM PLATED NICKEL', 28, 'SM PACK', 1845.40, ' accounts use f'), + (87113927, 'chartreuse green lace slate black', 'Manufacturer#3', 'Brand#33', 'SMALL ANODIZED STEEL', 8, 'SM CAN', 1936.57, 'refully ironic fo'), + (82703512, 'brown hot olive plum burnished', 'Manufacturer#5', 'Brand#53', 'SMALL BRUSHED TIN', 3, 'WRAP PACK', 1511.38, 'sts! 
regular '), + (82757337, 'firebrick frosted tan blue navy', 'Manufacturer#1', 'Brand#14', 'STANDARD ANODIZED STEEL', 13, 'WRAP JAR', 1390.20, 'fluffily even instr'), + (34979160, 'smoke lawn linen red honeydew', 'Manufacturer#5', 'Brand#55', 'SMALL ANODIZED TIN', 16, 'MED CAN', 1237.42, 'ess, unusual deposit'), + (28081909, 'drab royal orange thistle lavender', 'Manufacturer#4', 'Brand#41', 'STANDARD ANODIZED NICKEL', 1, 'WRAP DRUM', 1889.50, ' fluffily requests'), + (94367239, 'peru purple sky sandy olive', 'Manufacturer#4', 'Brand#41', 'ECONOMY POLISHED BRASS', 30, 'LG BOX', 1301.52, ' bold requests caj'), + (11614679, 'blue turquoise ivory dark olive', 'Manufacturer#3', 'Brand#32', 'MEDIUM BRUSHED BRASS', 11, 'JUMBO CAN', 1593.09, ', iron'), + (24026634, 'rose mint beige sky tan', 'Manufacturer#5', 'Brand#53', 'STANDARD ANODIZED BRASS', 4, 'SM DRUM', 1559.43, 'ts. final asymptot'), + (88361363, 'burnished deep green drab medium', 'Manufacturer#5', 'Brand#51', 'SMALL BRUSHED STEEL', 6, 'JUMBO JAR', 1419.95, ' quickly'), + (88913503, 'sandy chartreuse burnished metallic beige', 'Manufacturer#5', 'Brand#51', 'SMALL BURNISHED TIN', 24, 'SM JAR', 1512.06, 'lieve bravely. final, '), + (60518681, 'khaki cornflower salmon cream slate', 'Manufacturer#4', 'Brand#41', 'STANDARD POLISHED STEEL', 41, 'LG CASE', 1696.66, 'ts. pending, '), + (39443990, 'misty lime white indian mint', 'Manufacturer#1', 'Brand#14', 'STANDARD POLISHED TIN', 37, 'SM BAG', 1932.02, 'deposits. alway'), + (61157984, 'orange lavender cornflower medium tan', 'Manufacturer#4', 'Brand#42', 'ECONOMY BRUSHED TIN', 44, 'WRAP CASE', 2038.93, 'ans sleep agains'), + (61335189, 'plum pink lemon indian lace', 'Manufacturer#3', 'Brand#33', 'MEDIUM PLATED NICKEL', 12, 'LG PKG', 1221.12, 'lly even platelet'), + (61930501, 'chartreuse pink coral blue rosy', 'Manufacturer#3', 'Brand#32', 'STANDARD PLATED STEEL', 37, 'SM JAR', 1528.41, 'phins'), + (62028678, 'blush moccasin smoke dodger pink', 'Manufacturer#1', 'Brand#12', 'ECONOMY PLATED COPPER', 25, 'SM DRUM', 1603.57, 's: quickly ironic pack'), + (62142591, 'blanched salmon steel sienna chartreuse', 'Manufacturer#3', 'Brand#31', 'MEDIUM POLISHED NICKEL', 31, 'WRAP BAG', 1630.49, 'n foxes'), + (53438259, 'ivory deep dark lime goldenrod', 'Manufacturer#2', 'Brand#21', 'MEDIUM BURNISHED TIN', 30, 'LG BAG', 1194.58, 'hely regular req'), + (79250148, 'almond ivory dim midnight metallic', 'Manufacturer#1', 'Brand#11', 'MEDIUM POLISHED COPPER', 3, 'SM BOX', 1094.18, 'te at the sly'), + (28430358, 'chiffon orange olive medium dodger', 'Manufacturer#2', 'Brand#23', 'PROMO PLATED TIN', 17, 'SM BOX', 1286.93, 'en packages could h'), + (29021558, 'blush medium lace peru puff', 'Manufacturer#4', 'Brand#44', 'LARGE POLISHED BRASS', 45, 'JUMBO CAN', 1478.10, '. furious'), + (45733155, 'pink papaya chiffon red tomato', 'Manufacturer#2', 'Brand#22', 'STANDARD BURNISHED STEEL', 10, 'JUMBO CASE', 1185.87, 'nto beans after the'), + (30761725, 'dodger spring orange blue chocolate', 'Manufacturer#3', 'Brand#32', 'MEDIUM POLISHED STEEL', 41, 'SM PKG', 1785.19, 'ily fluffily silent'), + (18503830, 'drab dark wheat orange pale', 'Manufacturer#1', 'Brand#13', 'PROMO BRUSHED NICKEL', 32, 'SM PKG', 1832.91, 'ons. '), + (96644449, 'slate antique floral white olive', 'Manufacturer#2', 'Brand#25', 'STANDARD ANODIZED NICKEL', 40, 'MED CASE', 1388.61, 's after t'), + (89413313, 'azure rose thistle blue dodger', 'Manufacturer#1', 'Brand#11', 'STANDARD POLISHED COPPER', 9, 'JUMBO JAR', 1221.84, 'oxes. 
thin acco'), + (89854644, 'tan plum cream medium magenta', 'Manufacturer#4', 'Brand#45', 'ECONOMY POLISHED STEEL', 37, 'SM JAR', 1594.15, ' pinto beans! blithel'), + (65915062, 'cyan chocolate frosted navy purple', 'Manufacturer#2', 'Brand#25', 'ECONOMY POLISHED NICKEL', 5, 'WRAP PACK', 1073.77, 'arefully about the '), + (44160335, 'snow smoke indian thistle gainsboro', 'Manufacturer#4', 'Brand#44', 'LARGE ANODIZED NICKEL', 37, 'JUMBO JAR', 1393.13, 'arhorses. furio'), + (44705610, 'cornflower wheat blue aquamarine lavender', 'Manufacturer#2', 'Brand#21', 'MEDIUM BRUSHED COPPER', 18, 'JUMBO CASE', 1613.38, 'h idly care'), + (85810176, 'orange cornflower chiffon sienna peach', 'Manufacturer#4', 'Brand#41', 'PROMO BURNISHED COPPER', 45, 'SM DRUM', 1081.88, ' about the'), + (88034684, 'rosy blanched peach khaki pale', 'Manufacturer#4', 'Brand#43', 'STANDARD BRUSHED BRASS', 17, 'LG DRUM', 1614.28, 'y. bold'), + (67830106, 'indian sienna cornflower chocolate dodger', 'Manufacturer#1', 'Brand#15', 'ECONOMY PLATED TIN', 10, 'WRAP JAR', 1032.71, 'sly across the'), + (77698944, 'rosy deep tomato tan white', 'Manufacturer#5', 'Brand#53', 'STANDARD BRUSHED TIN', 47, 'SM PKG', 1939.06, 'blithely carefu'), + (22629071, 'red puff sienna tan aquamarine', 'Manufacturer#1', 'Brand#11', 'PROMO ANODIZED BRASS', 14, 'JUMBO PKG', 998.94, 'refull'), + (55654148, 'blanched honeydew violet orange maroon', 'Manufacturer#4', 'Brand#42', 'ECONOMY POLISHED COPPER', 28, 'MED PACK', 1099.36, ' blith'), + (54518616, 'chocolate brown tan plum cream', 'Manufacturer#1', 'Brand#13', 'LARGE BURNISHED BRASS', 6, 'WRAP CASE', 1631.89, 'unts; '), + (73814565, 'yellow almond red drab aquamarine', 'Manufacturer#2', 'Brand#22', 'LARGE BRUSHED TIN', 5, 'JUMBO CAN', 1475.87, 'scapades. express'), + (33917488, 'metallic royal white navajo deep', 'Manufacturer#4', 'Brand#42', 'ECONOMY BURNISHED COPPER', 12, 'SM BAG', 1503.79, 'jole a'), + (34431883, 'burnished sky cream saddle royal', 'Manufacturer#3', 'Brand#32', 'STANDARD BURNISHED STEEL', 32, 'LG PKG', 1813.16, ' sleep furiously bold '), + (37130699, 'red violet sandy cornsilk peru', 'Manufacturer#4', 'Brand#45', 'PROMO BURNISHED STEEL', 35, 'MED PKG', 1727.84, 'usly acro'), + (37530180, 'grey turquoise deep ivory papaya', 'Manufacturer#2', 'Brand#23', 'MEDIUM BURNISHED COPPER', 10, 'LG BAG', 1208.31, 'deposits h'), + (38023053, 'snow slate lace orange light', 'Manufacturer#2', 'Brand#24', 'PROMO ANODIZED NICKEL', 42, 'SM PACK', 974.15, 'ously re'), + (20192396, 'cornflower black lavender deep dark', 'Manufacturer#4', 'Brand#41', 'LARGE POLISHED BRASS', 26, 'JUMBO PACK', 1487.39, 't slyly. re'), + (20589905, 'snow bisque magenta lemon orchid', 'Manufacturer#4', 'Brand#45', 'STANDARD ANODIZED BRASS', 3, 'LG DRUM', 1993.88, 'ep furiously b'), + (49567306, 'dark snow chiffon sky midnight', 'Manufacturer#2', 'Brand#21', 'LARGE BURNISHED TIN', 18, 'SM BAG', 1370.83, 'inal, ironic req'), + (15634450, 'peru cornsilk deep gainsboro maroon', 'Manufacturer#1', 'Brand#11', 'MEDIUM ANODIZED COPPER', 35, 'SM CASE', 1383.67, 'close foxes. 
express t'), + (12902948, 'wheat spring salmon green violet', 'Manufacturer#5', 'Brand#52', 'STANDARD BRUSHED BRASS', 16, 'WRAP BAG', 1950.30, 'even accoun'), + (168568384, 'yellow chiffon powder papaya honeydew', 'Manufacturer#1', 'Brand#13', 'MEDIUM POLISHED COPPER', 30, 'JUMBO PKG', 1443.96, 'ole fluffily a'), + (169237956, 'cornflower maroon frosted grey navy', 'Manufacturer#2', 'Brand#24', 'ECONOMY BRUSHED TIN', 15, 'MED JAR', 1885.49, 'ions wake regular, '), + (123075825, 'indian midnight azure ghost slate', 'Manufacturer#3', 'Brand#31', 'LARGE PLATED BRASS', 43, 'LG PACK', 1794.67, 'e the theodolites.'), + (123765936, 'papaya orchid tomato sienna blue', 'Manufacturer#5', 'Brand#52', 'ECONOMY PLATED STEEL', 37, 'LG BAG', 1995.75, 'ls wake carefully. bol'), + (123926789, 'rosy saddle burnished cyan salmon', 'Manufacturer#3', 'Brand#34', 'SMALL BRUSHED TIN', 33, 'SM PKG', 1809.59, 'ests ha'), + (195634217, 'blue forest gainsboro bisque royal', 'Manufacturer#5', 'Brand#55', 'MEDIUM POLISHED COPPER', 13, 'SM PKG', 1141.43, 'usly final packages us'), + (176278774, 'floral tan mint sky cyan', 'Manufacturer#3', 'Brand#32', 'LARGE POLISHED NICKEL', 15, 'SM CASE', 1743.96, 'o beans'), + (115117124, 'ivory gainsboro dark plum drab', 'Manufacturer#2', 'Brand#23', 'SMALL BURNISHED STEEL', 31, 'LG JAR', 1135.37, 'uickly bold as'), + (115208198, 'lawn cornsilk ghost snow peru', 'Manufacturer#4', 'Brand#45', 'LARGE PLATED COPPER', 50, 'SM DRUM', 1100.43, 'l theodolites among '), + (118281867, 'bisque almond cream magenta lime', 'Manufacturer#5', 'Brand#52', 'PROMO BURNISHED TIN', 42, 'LG DRUM', 1842.95, 'ithel'), + (163072047, 'blue green rosy ivory maroon', 'Manufacturer#4', 'Brand#44', 'MEDIUM PLATED COPPER', 47, 'SM BOX', 1010.89, 'r accoun'), + (163333041, 'plum puff slate turquoise misty', 'Manufacturer#5', 'Brand#55', 'PROMO BRUSHED COPPER', 16, 'WRAP BAG', 1065.88, 'efully abo'), + (139246458, 'spring orange moccasin cornsilk dark', 'Manufacturer#2', 'Brand#24', 'MEDIUM BURNISHED NICKEL', 46, 'JUMBO PKG', 1397.49, 'sual depen'), + (109742650, 'blush goldenrod brown drab cyan', 'Manufacturer#5', 'Brand#55', 'ECONOMY POLISHED STEEL', 42, 'MED DRUM', 1687.17, 'ges doubt. fluffi'), + (115634506, 'navy wheat turquoise chocolate yellow', 'Manufacturer#1', 'Brand#13', 'MEDIUM PLATED NICKEL', 37, 'JUMBO DRUM', 1434.72, ' regular excuses. 
f'), + (137468550, 'indian papaya saddle dim olive', 'Manufacturer#3', 'Brand#32', 'LARGE PLATED BRASS', 22, 'LG PKG', 1511.68, 'refully'), + (167179412, 'pale misty blue grey midnight', 'Manufacturer#1', 'Brand#14', 'ECONOMY ANODIZED NICKEL', 43, 'SM CAN', 1483.06, 'ng to the unusual'), + (155189345, 'orange midnight turquoise maroon antique', 'Manufacturer#4', 'Brand#43', 'ECONOMY POLISHED STEEL', 20, 'LG PACK', 1426.59, ' final accounts are f'), + (128815478, 'snow orange hot puff metallic', 'Manufacturer#1', 'Brand#14', 'MEDIUM BURNISHED BRASS', 38, 'LG CASE', 1387.03, 'above the p'), + (188251562, 'brown pink burlywood slate seashell', 'Manufacturer#2', 'Brand#23', 'MEDIUM BRUSHED NICKEL', 13, 'SM BOX', 1504.15, 'nooze'), + (197920162, 'ivory aquamarine burnished gainsboro bisque', 'Manufacturer#1', 'Brand#11', 'PROMO ANODIZED NICKEL', 43, 'WRAP BOX', 1172.27, 'nstruc'), + (108337010, 'blanched navy thistle rosy tan', 'Manufacturer#2', 'Brand#22', 'MEDIUM BRUSHED COPPER', 22, 'SM JAR', 1041.60, ' requests haggl'), + (108569283, 'almond yellow plum olive tan', 'Manufacturer#5', 'Brand#54', 'STANDARD POLISHED TIN', 49, 'LG BOX', 1346.86, 'ns sleep carefully'), + (106169722, 'tan honeydew floral navy dim', 'Manufacturer#3', 'Brand#35', 'ECONOMY PLATED TIN', 49, 'SM BOX', 1786.42, 'eep blithely. regula'), + (194657609, 'linen ivory yellow indian blanched', 'Manufacturer#2', 'Brand#23', 'LARGE POLISHED BRASS', 16, 'MED BOX', 1556.87, 'n asy'), + (134081534, 'linen purple peru aquamarine indian', 'Manufacturer#4', 'Brand#44', 'SMALL ANODIZED TIN', 44, 'LG PKG', 1508.83, 'ully unusual instructi'), + (178305451, 'goldenrod navy beige smoke almond', 'Manufacturer#3', 'Brand#34', 'STANDARD BURNISHED COPPER', 18, 'JUMBO DRUM', 1447.54, 'bold instru'), + (144001617, 'olive lavender misty forest hot', 'Manufacturer#2', 'Brand#22', 'SMALL ANODIZED STEEL', 4, 'JUMBO BAG', 1711.42, 'e slyly quick'), + (175179091, 'light forest violet blue plum', 'Manufacturer#5', 'Brand#55', 'MEDIUM PLATED NICKEL', 42, 'WRAP DRUM', 1161.34, 'r foxes! final ins'), + (173488357, 'ghost violet steel snow honeydew', 'Manufacturer#3', 'Brand#32', 'PROMO PLATED NICKEL', 24, 'LG CASE', 1336.68, 'ent packa'), + (173599543, 'linen tan chartreuse firebrick orchid', 'Manufacturer#2', 'Brand#24', 'LARGE PLATED COPPER', 43, 'LG DRUM', 1633.87, ' excuses. specia'), + (103254337, 'saddle aquamarine khaki white lace', 'Manufacturer#4', 'Brand#41', 'PROMO BRUSHED BRASS', 14, 'JUMBO DRUM', 1286.17, ' nag slyly accounts.'), + (103431682, 'blue slate navajo sienna maroon', 'Manufacturer#5', 'Brand#55', 'STANDARD ANODIZED BRASS', 27, 'JUMBO BOX', 1608.51, 'quickly pe'), + (164644985, 'peru medium maroon ivory lace', 'Manufacturer#3', 'Brand#32', 'SMALL ANODIZED TIN', 37, 'MED DRUM', 1921.75, 'counts wake fur'), + (119476978, 'firebrick slate bisque violet burlywood', 'Manufacturer#3', 'Brand#32', 'LARGE ANODIZED NICKEL', 22, 'WRAP PKG', 1949.00, 'thely ir'), + (186581058, 'white snow grey plum moccasin', 'Manufacturer#1', 'Brand#13', 'LARGE BRUSHED COPPER', 39, 'SM PKG', 1129.73, 'the unusual, regular '), + (169543048, 'ghost aquamarine steel dark pale', 'Manufacturer#2', 'Brand#22', 'MEDIUM ANODIZED BRASS', 43, 'WRAP CAN', 1082.57, 'efully-- pendi'), + (151893810, 'ivory coral lace royal rose', 'Manufacturer#4', 'Brand#45', 'MEDIUM ANODIZED NICKEL', 29, 'MED CAN', 1796.22, 'oxes integrate'), + (135389770, 'peru moccasin azure peach ivory', 'Manufacturer#1', 'Brand#15', 'PROMO BRUSHED BRASS', 6, 'WRAP CASE', 1853.01, 'c accounts. 
carefully'), + (140448567, 'lace khaki maroon indian blanched', 'Manufacturer#1', 'Brand#11', 'LARGE ANODIZED COPPER', 48, 'LG JAR', 1508.54, 'telets. re'), + (157237027, 'wheat maroon rosy pink spring', 'Manufacturer#4', 'Brand#43', 'LARGE ANODIZED STEEL', 8, 'WRAP BOX', 956.16, 'xcuses '); diff --git a/tests/resources/datasets/partsupp.sql b/tests/resources/datasets/partsupp.sql new file mode 100644 index 0000000000..d0383153b4 --- /dev/null +++ b/tests/resources/datasets/partsupp.sql @@ -0,0 +1,112 @@ +-------------------------------------------------------------------------------- +CREATE TABLE partsupp ( + ps_partkey BIGINT, + ps_suppkey BIGINT, + ps_availqty INT, + ps_supplycost DECIMAL(18,2), + ps_comment VARCHAR(500) + ); + +-------------------------------------------------------------------------------- + +INSERT INTO partsupp (ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment) VALUES + (449928, 2949929, 2780, 651.80, 'ts cajole slyly stealthy, fluffy accounts. pending deposits poac'), + (640486, 640487, 7505, 575.02, 'inst the stealthy, ironic accounts. slyly silent accounts wake. special, ironic pearls detect qui'), + (2866970, 5366971, 4245, 662.77, 'uests kindle dependencies. carefully regular notornis cajole blithely. final foxes hag'), + (1387319, 3887320, 8879, 856.46, 'boost above the express deposits. furiously blithe orbits sleep carefully. fluffily special deposits serve fluf'), + (2319664, 9819665, 4662, 803.81, 'sts. silent accounts sleep. final theodolites haggle above the furiou'), + (28081909, 5581916, 2377, 777.65, 'ke furiously blithely express deposits. quickly ironic requests nag final ideas. final dep'), + (34979160, 7479164, 2478, 318.93, 'refully. requests haggle slyly. slyly even deposits boost furiously. deposits among '), + (88361363, 861372, 1009, 870.94, ' slyly unusual instructions. even requests sleep express requests. bold accounts after the sauternes integrate a'), + (12902948, 5402950, 5661, 19.78, 's haggle carefully above the slyly regular requests. quickly even'), + (18503830, 1003832, 9817, 995.33, 'packages are-- carefully even pinto beans solve always slyly sly accounts. regular deposits detect slyly blithely special foxes. blithely idle foxes about the carefully express asymptote'), + (28430358, 8430359, 4870, 20.35, 'c, final requests boost quickly blithely pending theodolites: furiously final deposits are furiously pending deposits. quickly bold dolphins across the accounts wake furiously r'), + (157237027, 2237058, 9026, 74.12, ' beans. ideas haggle furiously slyly regular requests. quickly regular accounts haggle slyly carefully express pinto beans. carefully ironic theodolites haggle. packages boost furiously slyly '), + (34431883, 1931893, 3307, 147.59, 'osits wake with the accounts? carefully brave packages boost carefu'), + (194657609, 2157667, 8695, 529.40, 'lyly express deposits about the regular instructions haggle carefully against the ironic gifts. blit'), + (128815478, 8815479, 154, 55.49, ' the carefully daring packages. ironic accounts use fluff'), + (176278774, 3778826, 5177, 203.52, 'int fluffily ironic instructions. blithely express packages wake furiously unusual deposits! packages wak'), + (118281867, 5781901, 6451, 283.44, 'lly along the bold accounts; special, final packages cajole blithely according to the special deposits. br'), + (108337010, 837021, 384, 357.98, ' the carefully even deposits. 
furiously regular sentiments above the regular, unusual dependencies are carefully against the carefully even packages. special packages sleep busily al'), + (195634217, 634256, 5848, 808.62, 'dencies wake carefully against the blithely pending ideas. daringly unusual foxes sleep '), + (108569283, 8569284, 3407, 403.55, 'onic, regular deposits snooze-- carefully express requests cajole among the furiously regular dependencies. quiet, busy packages sleep blithely express accounts-- bold, unusual pinto beans hinde'), + (109742650, 7242681, 1705, 605.75, 'lyly unusual braids. fluffily slow asymptotes are carefully according to the blithely final theodolites. blithely special foxes sleep furiously quickly pending platelets. even pinto beans integ'), + (167179412, 7179413, 3388, 579.22, 'ckages. ironic requests wake. requests haggle carefully unusual accounts. blithely final ideas snooze blithely. carefully ruthless foxes was. sometimes even theodolites '), + (168568384, 3568417, 1289, 668.75, 'lithely alongside of the requests. special, brave foxes across the regular, ironic dependencies boost slyly against the express pinto beans. blithely express dolphins run slyly '), + (173599543, 6099561, 9401, 818.84, 'silent accounts. even accounts eat slyly. furiously ironic theodolites haggle. carefully unusual requests sleep fluffily! quickly bo'), + (169543048, 4543081, 2091, 855.72, ' across the furiously special ideas. blithely even instructions across the care'), + (1738077, 4238078, 1851, 22.23, 'sly. blithely express packages haggle quickly slyly unusual packages. instructions haggle about the slyly final foxes. bold, ironic gifts doubt quickly about the slyly p'), + (22629071, 5129074, 9256, 340.97, 'gular theodolites. furiously regular packages use slyly furiously final requests. dogged, final p'), + (82703512, 7703529, 4675, 95.01, 'fully bold packages need to wake carefully. quickly bold pinto beans engage bol'), + (82757337, 5257346, 2997, 206.84, 'lites wake slyly along the pending pinto beans. furious theodolites boost slyly above the ironic accounts. ideas boost. furiously pending accounts above '), + (15634450, 634453, 8198, 382.30, 'ideas among the slyly regular pinto beans wake across the bold pinto beans. pending ideas play. dolphins haggle finally above the quickly even pinto beans. special, silent pinto '), + (30761725, 3261729, 21, 426.39, 'ully. even foxes are carefully. pending, regular deposits nag. regular courts nag furiously. pinto beans sleep furiously around the regular packages. pending requests p'), + (2742061, 7742062, 150, 689.99, 'y furiously even deposits. furiously even packages haggle carefully. foxes s'), + (2131495, 4631496, 3823, 295.83, 'gular deposits play furiously. furiously ironic excuses believe carefully furiously express requests. regular, final requests cajole unu'), + (4296962, 1796963, 2568, 864.38, 'ly brave requests are against the final, pending requests'), + (55654148, 3154164, 7686, 713.99, ' packages wake about the slyly special requests. ironic requests beside the idly special d'), + (38023053, 3023060, 6072, 351.79, 'atelets use slyly. dependencies cajole permanently. fluffily special packages cajole furiously. deposits wake furiously. furiously regular theodolites according to the carefully even packag'), + (65915062, 3415081, 1233, 979.78, 'le furiously. furiously even excuses detect. pending, regular accounts along the slyly ironic deposits haggle daring foxes. 
requests against the bold, silent ideas affix fluffily final accounts.'), + (20589905, 3089908, 731, 17.46, 'lithely carefully unusual somas? furiously regular theodolites haggle pending, ironic packages. excuses nag across the fluffily regular pai'), + (37530180, 30184, 7192, 745.18, 's are carefully furiously regular braids. final packages wake. ironic sheaves haggle under the pinto '), + (11614679, 4114681, 3901, 529.90, 'sts. blithely even foxes sleep after the special foxes. final instructio'), + (7067007, 9567008, 4636, 581.75, 'egular accounts haggle blithely boldly regular packages. packages atop the blithe pinto beans thrash furiously alongside of the plat'), + (77698944, 5198966, 1244, 631.78, 'onic orbits boost slyly according to the carefully ironic excuses. carefully unus'), + (44705610, 4705611, 8387, 176.61, 'according to the carefully idle packages. slyly even ideas en'), + (49567306, 2067311, 9493, 214.81, 'lar deposits wake never after the regular pinto beans. always final depos'), + (88034684, 5534709, 7359, 160.97, 'ular ideas. fluffily regular packages cajole carefully across the quickly unusual theodolites'), + (29021558, 4021563, 2443, 688.46, 'ould have to use fluffily against the ironic foxes. slyly final pinto beans wake slyly regular accounts. regular requests integrate along the regular packages. unusual accounts along'), + (37130699, 2130706, 1104, 449.12, ' deposits affix across the quickly regular platelets. final instructions use. dar'), + (88913503, 3913520, 5780, 632.16, 'he even ideas. bold, busy notornis cajole slyly slowly fluffy dependencies-- fluffil'), + (39443990, 1943994, 9593, 11.66, ' accounts alongside of the express deposits sleep about the carefully final packages. even, special packages sleep ruthlessly furiously regular excuses. blithely '), + (94367239, 6867249, 9850, 969.34, '; quickly regular requests sleep slyly slyly final excuses. quickly final packa'), + (87113927, 4613952, 3758, 517.84, 'instructions are quickly ideas. fluffily final requests on the carefully final dependencies sleep quickly blithely even depths. slyly expr'), + (45733155, 733164, 4744, 483.76, 'are blithely at the pending packages. slyly even '), + (62142591, 9642610, 1770, 247.41, 'nts at the furiously even deposits cajole final instructions. blithely regular foxes are furiously slyly ironic packages. packages solve blithely. packages are slyly. spe'), + (89413313, 1913322, 894, 192.56, 'sly regular theodolites haggle. carefully regular deposits cajole qui'), + (96644449, 9144459, 179, 555.50, 'iously special packages after the ironic asymptotes wake quickly even excuses. quickly final instructions after the blithely regular escapades cajole blithely'), + (67830106, 5330125, 2411, 278.01, 'y. fluffily bold requests boost bold, bold dependencies. furiously bold pinto beans use. accounts wake slyly among the carefully ironic packages. carefully silent excuses haggle slyly '), + (73814565, 8814580, 7788, 435.56, 'e among the slyly even ideas. carefully bold sauternes sleep. accounts up the carefully ironic instructions maintain alongside of the enticingly even dolphins. ironic accounts '), + (61335189, 8835208, 7802, 874.02, 'structions sleep quickly. quickly pending foxes wake furiously final instructions. blithely ironic asymptotes haggle daringly regular theodoli'), + (24026634, 1526641, 3330, 853.01, 'ep quickly according to the slyly dogged excuses. blithely even deposits detect. 
fluffily ironic foxes after the express requests sleep since the regular '), + (60518681, 5518694, 6790, 545.24, 'n asymptotes. final, regular theodolites boost slyly silent instructions.'), + (44160335, 6660340, 4814, 679.36, 'boost past the theodolites. quickly final packages poach dur'), + (175179091, 2679143, 3628, 419.48, 'ccounts kindle carefully after the furiously special excuses. slyly ironic platelets thrash alongside'), + (115117124, 7617136, 8617, 990.39, 'ts dazzle blithely above the carefully dogged deposits. permanently ironic asy'), + (115208198, 7708210, 2766, 41.63, ' the regular, even requests. even theodolites haggle ruthlessly. furiously express accounts boost quickly caref'), + (85810176, 8310185, 9646, 986.82, 'ress theodolites haggle furiously final, enticing asymptotes. carefully even deposits use blithely blithely ironic foxes-- busily final dolphins across the carefully '), + (20192396, 5192401, 6090, 809.70, 'equests nag carefully-- instructions boost furiously. bold patterns according to the slyly bold pinto beans kindle careful'), + (115634506, 3134540, 3679, 119.52, 'sts. blithely ironic requests sleep idly along the blithely regular instructions. blithely pending dependencies sleep furiously after '), + (135389770, 5389771, 840, 944.81, 'ns are slyly express ideas: blithely even packages detect carefully about the pending warhorses. blithely unusual packages are regular, bold requests. furiously final packages c'), + (10425920, 2925922, 1094, 196.53, 'counts of the unusual requests cajole furiously after '), + (103254337, 5754348, 8265, 606.66, 'leep slyly. slyly pending instructions boost carefully. final requests cajole furiously. slyly unusual packages sleep packages. accounts sleep carefully fin'), + (103431682, 5931693, 619, 524.31, 'ly unusual platelets. carefully regular deposits promise slyly regular ideas. dugouts detect packages. slyly pending ideas unwind quickly. blithely even theodolites ar'), + (178305451, 805469, 5872, 956.83, 'manently furiously unusual instructions. regular, express packages haggle. theodolites haggle slyly beside the bold, express requests. silent platelets'), + (197920162, 420182, 8703, 528.71, 'ke blithely ironic asymptotes. slyly bold foxes detect carefully deposits. bold accounts wake blithely about the carefully bold asymptotes. accounts use slyly. blithely '), + (137468550, 9968564, 2822, 828.67, ' accounts integrate. regularly final theodolites cajole blithely requests. carefully ironic foxes are fluffily. furiously express asymptotes hang about the fluffily ironic accounts. slyly regular re'), + (163072047, 3072048, 2793, 685.77, 'lyly pending sentiments about the furiously dogged requests'), + (140448567, 2948582, 860, 820.96, ' regular requests. ironic, pending packages are blithely regular foxes. furiously ironic ideas across the fluffy requests grow quickly final foxes. special packages '), + (173488357, 3488358, 5288, 84.12, 'le sometimes. blithely permanent deposits x-ray slyly above the bravely even foxes. furiously even deposits haggle from the fluffily careful accounts. idle '), + (163333041, 833090, 4577, 324.55, 'as alongside of the furiously regular accounts boost slyly regular deposits. quiet excuses sleep quickly final accounts. furiously unusual deposits haggle slyly carefully ironic a'), + (123075825, 575862, 798, 466.45, 'ong the carefully regular packages. even dolphins wake furiously always ironic instructions. 
blithely even theodolites across the quickly ironic accounts hang slyly final theodolites. furiou'), + (155189345, 7689361, 590, 340.90, 'e furiously furiously even requests. blithely unusual dependencies among the special excuses nod above the'), + (53438259, 938275, 4742, 206.49, 't the accounts. evenly unusual theodolites detect besides the furiously final accounts. slyly regular requests sleep quickly. express deposits according to the even pains haggle slyly regular req'), + (119476978, 1976990, 1653, 564.95, 'ing theodolites are blithely according to the quickly fina'), + (186581058, 4081113, 7556, 495.87, 'ar pinto beans boost blithely; escapades wake requests. carefully unusual platelets alongside of the bold, busy requests use carefully quickly express ideas. furiously bold excuses are'), + (123765936, 3765937, 8557, 619.29, 'ic foxes sublate carefully pending foxes. bravely regular deposits cajole fur'), + (123926789, 3926790, 136, 440.46, 'lar accounts hang furiously according to the quickly special requests. theodolites outside the pinto beans x-ray fluffily silen'), + (139246458, 1746472, 7819, 308.47, 'l pinto beans could have to nag carefully blithely final '), + (54518616, 9518627, 5446, 265.36, 'theodolites sleep furiously. blithely ironic courts above the special dependencies sleep silent accounts; blithely unusual'), + (89854644, 7354669, 5233, 276.77, 'sts. even, regular instructions according to the furiously ironic courts sleep quickly ironic ideas; slyly silent packages might'), + (164644985, 9645018, 5182, 754.18, 'gular deposits cajole blithely against the fluffily regular epitaphs. accounts alongside of the foxes slee'), + (61930501, 1930502, 9706, 119.60, 'odolites are carefully alongside of the pending excuse'), + (62028678, 2028679, 1768, 430.22, 'ic asymptotes wake quickly above the blithely bold accounts. dolphins play above the regular packages. regular, regular accounts boost. slyly ironic deposits affix foxes-- fluff'), + (106169722, 1169743, 9365, 265.13, 'kages. carefully silent theodolites believe blithely about the blithely regular excuses. carefully final packages are furiou'), + (134081534, 1581574, 3356, 485.56, 'gly final accounts solve quickly. fluffily bold foxes are fluffily against the ruthl'), + (33917488, 3917489, 5569, 664.32, 'pending packages eat furiously along the slyly special deposits. doggedly '), + (151893810, 9393856, 1477, 268.37, 'nding packages haggle against the ideas. furiously enticing ideas affix final deposits. slyly special requests promise carefully. '), + (61157984, 8658003, 5557, 209.93, 'y bold foxes may nag quickly ironic, special platelets-- regular, bold deposits thrash across the special deposits. pending, special ideas nag blithely. quickly reg'), + (144001617, 4001618, 3167, 861.64, 'ar foxes. regular, ironic requests sleep permanently among the never bold requests. pending dependencies haggle across the slyly ironic grouches. close requests ar'), + (79250148, 1750156, 7498, 529.62, 'beans haggle fluffily according to the slyly regular asymptotes. slyly express accounts integrate against the quickly silent packages. special theodolites alongside of the regular'), + (169237956, 6738005, 5184, 360.72, 'elieve slyly final packages. slyly even pinto beans cajole stealthily. even deposits thrash slyly dolphins. blithely special d'), + (188251562, 3251599, 274, 857.43, 'equests. final requests cajole furiously final, regular deposits. requests nag along the slyly regular accounts. daring packages sleep quickly. 
regularl'); diff --git a/tests/resources/datasets/region.sql b/tests/resources/datasets/region.sql new file mode 100644 index 0000000000..6e4149059a --- /dev/null +++ b/tests/resources/datasets/region.sql @@ -0,0 +1,14 @@ +-------------------------------------------------------------------------------- +CREATE TABLE region ( + r_regionkey BIGINT, + r_name VARCHAR(500), + r_comment VARCHAR(500) +); +-------------------------------------------------------------------------------- + +INSERT INTO region (r_regionkey, r_name, r_comment) VALUES + (0, 'AFRICA', 'lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to '), + (1, 'AMERICA', 'hs use ironic, even requests. s'), + (2, 'ASIA', 'ges. thinly even pinto beans ca'), + (3, 'EUROPE', 'ly final courts cajole furiously final excuse'), + (4, 'MIDDLE EAST', 'uickly special accounts cajole carefully blithely close requests. carefully final asymptotes haggle furiousl'); diff --git a/tests/resources/datasets/supplier.sql b/tests/resources/datasets/supplier.sql new file mode 100644 index 0000000000..402f9f136f --- /dev/null +++ b/tests/resources/datasets/supplier.sql @@ -0,0 +1,114 @@ +-------------------------------------------------------------------------------- +CREATE TABLE supplier ( + s_suppkey BIGINT, + s_name VARCHAR(500), + s_address VARCHAR(500), + s_nationkey VARCHAR(500), + s_phone VARCHAR(500), + s_acctbal DECIMAL(18,2), + s_comment VARCHAR(500) + ); + +-------------------------------------------------------------------------------- + +INSERT INTO supplier (s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment) VALUES + (30184, 'Supplier#000030184', 'Cj6lONc7GEnIFK', 6, '16-626-330-8040', 9262.48, ' Customer blithely regular pinto beans. slyly express ideas uRecommendsy'), + (634256, 'Supplier#000634256', 'p,iow4V1zcU6Q93CywpawZsLLXyu94', 23, '33-885-643-6448', 2212.11, 'ites integrate furiously across the slyly final ac'), + (634453, 'Supplier#000634453', '1ReLIcraal8hBSvF6o', 11, '21-689-946-5527', 9745.78, 'above the finally special packages. carefully pending deposits unwind'), + (640487, 'Supplier#000640487', 'gNQnqgSrIF5v9L1mqZab', 5, '15-909-462-8185', 7543.51, 'os alongside of the theodolites detect carefully even theodolit'), + (733164, 'Supplier#000733164', 'W0GmDtlZUpexLw7EKxAG', 13, '23-811-826-9078', 9167.25, 'oss the fluffily dogged asymptotes. final, pending requests boost al'), + (805469, 'Supplier#000805469', 'umZXt8nWpSg QOULl0Tr9Bexg5 0', 9, '19-142-118-8226', 7930.84, 'iously regular requests. furiously regular dependen'), + (833090, 'Supplier#000833090', 'PFql654teyZ,oN1aO43oQMixfRMpaL7TISWOIEL ', 1, '11-514-568-7497', 3418.87, 'its. regular, regular frets wake furiously. pending'), + (837021, 'Supplier#000837021', 'VkObuG2PjI 0C7JY52lgoPCo3heE0SW,w9a', 8, '18-377-707-4807', 609.94, 'es above the furiously unusual requests are c'), + (861372, 'Supplier#000861372', 'EpxxJ9VjaA8n6NRGvw uldqh51auUuq5lu6yQVN6', 9, '19-290-207-7734', 6267.18, 'xpress, even ideas. regular dolphins na'), + (575862, 'Supplier#000575862', 'mIE6AfQ1OQVt0BGv0I4ahuo7', 1, '11-552-378-5982', 1230.97, 'l, final foxes. dolphins use slyly f'), + (420182, 'Supplier#000420182', 'baI38Xr885NLuvqydr71qnVV', 19, '29-816-556-6872', 7205.50, 'ly. final foxes use according to the final accounts. unusual, even theodol'), + (938275, 'Supplier#000938275', 'vKvdB,0Jaycz3VHaYcV1qa', 3, '13-620-293-7481', 3379.93, 'carefully. blithely fluffy requests haggle. 
final instructions along the blithely'), + (1003832, 'Supplier#001003832', 'aGL9y eidMj', 6, '16-327-893-2142', 4984.40, 'jole after the packages. quick packages hinder. escapades boost slyl'), + (1796963, 'Supplier#001796963', 'j,zWwyj0DUET8DTS7iQPCAdFkSbempW59', 12, '22-639-278-6874', 6313.61, 'ccounts? blithely ironic waters wake against th'), + (9518627, 'Supplier#009518627', 'b8dB2K2r1,m,D eI1f8TZX5yVAN', 14, '24-125-338-4367', 3611.79, 'ts. deposits promise slyly. express, ironic ideas wake blithely bold dependencies. caref'), + (5129074, 'Supplier#005129074', '8eGmqE99LT8v2OF5,Th1oU', 12, '22-587-461-5549', 732.36, 'n deposits at the regular, pending accounts integrate blithely around the'), + (5754348, 'Supplier#005754348', 'HJQ,4uDORoGa5Mw8yP3UEkI', 3, '13-422-795-6458', 543.35, 'eans are behind the ironic, regular dinos. instructions cajo'), + (7479164, 'Supplier#007479164', 'vn6pKxDZmDC11qFuXc3mekqXyWHNyBK5NLVBG', 12, '22-570-676-4453', 9455.20, 'arefully ironic depths detect; regular, regular pl'), + (5192401, 'Supplier#005192401', ' AZhe8xXaHCOvqT6V', 13, '23-436-823-7057', 8192.12, 'ove the furiously ironic ideas. fluffily unusu'), + (5198966, 'Supplier#005198966', 'xiN6bjDo0zOmW3VBrtot9hTo26krwj', 22, '32-345-648-5995', 7539.50, 'cuses haggle fluffily on the foxes'), + (5257346, 'Supplier#005257346', 'lCfBlP veJWGU5wwhXKIM1lY8ERZsx17XY1', 6, '16-717-531-7295', 8811.17, '. carefully ironic accounts integrate quickly after'), + (6867249, 'Supplier#006867249', 'BLIAIpAOjb9suf', 20, '30-842-521-7951', 2591.45, 'y unusual ideas. pending courts boost alongside of the carefully silent deposits. furiously c'), + (1526641, 'Supplier#001526641', 'p3SdEwGqPRnaja,', 16, '26-270-848-6770', 2177.18, 'counts wake furiously upon th'), + (1581574, 'Supplier#001581574', 'VQAj0GfIRoVkIz6di9kby 7inc9', 1, '11-197-372-7269', 7849.68, ' shall affix against the furiously ironic asymptotes. regular escapades detect slyly'), + (2679143, 'Supplier#002679143', 'YO8of8M,MbsYUwRCLPkWT47', 11, '21-939-288-9961', 3957.06, 'gle quickly. carefully ironic foxes are furiously un'), + (8814580, 'Supplier#008814580', 'iPLQQ8eMBUM9P7Pc7JgL20xYRoScAkzBwlQtrF z', 24, '34-419-160-6605', -850.36, 'uests. carefully unusual instructions boost after the blithely specia'), + (8815479, 'Supplier#008815479', 'RjExqJUmCUXhfKewZIln6N', 18, '28-513-594-1369', -971.27, 'ost carefully along the regularly ironic warthogs. bold '), + (8835208, 'Supplier#008835208', 'S8d i03eRx9SP', 3, '13-971-361-3533', 3586.37, 'osits. regular, idle instructions cajole slyly bravely express packages. furiously even p'), + (2028679, 'Supplier#002028679', 'aULBzzOJMo ovVGsED7docm3ghg', 2, '12-382-293-4441', 2491.73, 'apades. deposits cajole daringly after the furiously even deposits. f'), + (2067311, 'Supplier#002067311', '7aLyj4pIkboYesGaXks3', 18, '28-206-936-1745', 212.13, 'of the blithely ironic deposits. final, final instructions hang furiously fluffily regular'), + (2130706, 'Supplier#002130706', 'kJSH021vt7WX3Xfn7NMAWK', 1, '11-111-354-4128', 1085.77, 'wake. daring requests cajole enticingly quick requests. requests are furiously across the express r'), + (7354669, 'Supplier#007354669', 'hKcUpT7d0XXVeRp,gk8MUMS1,tFcv4iDsRs', 22, '32-782-561-6838', 4214.96, 's about the regular deposits wake requests. carefully special packages above the quickly reg'), + (3913520, 'Supplier#003913520', 'EdT4Y0466ce', 5, '15-966-246-6640', 6107.67, 'oost stealthily above the quickly final ideas. 
'), + (3917489, 'Supplier#003917489', 'x6LQ2l3WFMqyBUZj', 14, '24-485-752-4470', 6692.07, 'lent theodolites. regular deposits about the regular, regular foxes d'), + (3926790, 'Supplier#003926790', 'Y,,YJX5MMrY1', 13, '23-960-312-1375', 3292.52, 'the blithely even dolphins. carefully ironic packages across the slyly special requests nag furious'), + (4001618, 'Supplier#004001618', 'EGfsFtHaqvP', 15, '25-258-126-6839', 7508.01, 'ously fluffily express grouches. blithely ironic asymptotes use ca'), + (7179413, 'Supplier#007179413', 'ivRp3FOePUCats3ybbKARn0x72NODHd0l', 19, '29-225-162-7910', 5888.47, 'furiously furiously ironic warhorses. unusual, ironic requests cajole quickly about t'), + (7242681, 'Supplier#007242681', 'QW4dDvFVa23PDO0cZUyTJJSAm5t', 15, '25-240-816-4136', 1241.32, 'carefully silent multipliers use about the fluffily special requests. caref'), + (8658003, 'Supplier#008658003', 'Z6pCtXUMNjTw fX87O', 21, '31-802-495-9045', -401.95, 'ously alongside of the final deposits? silent instructions haggle at the accounts. plate'), + (3023060, 'Supplier#003023060', 'gduPIt,DDJ3QJ NjbH', 16, '26-350-576-8263', 9281.34, 'y alongside of the furiously bold deposits. furiously ironic pinto beans according to the carefully '), + (3072048, 'Supplier#003072048', '9TYx252u0BrRLe5UOxZCa9,wlE6VA9', 19, '29-431-623-5626', 6385.14, 'ons. furiously regular packages according to the final foxes sleep furiously silent requests. '), + (3089908, 'Supplier#003089908', 'hxXnvsP2m9DMkZW', 17, '27-790-169-3158', 5929.76, 'ully above the carefully final requests. blit'), + (3134540, 'Supplier#003134540', 'oQv2qtwRQ54A0Y0OT0zlHZ,YtIb H4t', 17, '27-885-534-4104', -143.69, 'onic foxes are slyly after the regular deposi'), + (1746472, 'Supplier#001746472', 'a1HnwV wDHkgU226QD96G0b1dCOrNJtzd ZKh', 16, '26-306-814-8036', -238.44, 'egular instructions. unusual packages lose against the furiously even packages. ironic'), + (1750156, 'Supplier#001750156', 'fnp1WXp GGAJkAZmLh', 5, '15-302-324-2465', 3322.21, 'old ideas detect carefully according to the slyly regular dependenci'), + (5534709, 'Supplier#005534709', 'dG7uxj xMHRe0B1A3R6zZFyt', 12, '22-581-479-1675', 7710.11, 'ses. requests sleep furiously. carefully quiet asymptotes use quickly. careful'), + (5581916, 'Supplier#005581916', 'hFMdlBZgHk0hyuPn,', 13, '23-603-567-5086', 2177.86, 'uests. fluffily furious theodolites nag slyly ironic packages'), + (9144459, 'Supplier#009144459', 'bquL718C1bXhD9kj6WLH8', 1, '11-316-607-6437', 1699.52, ' special packages; furiously regular dependencies boost. carefully bold brai'), + (5330125, 'Supplier#005330125', 'yTKohN5z6Vo13EdkusomP0rtInJKw9WBOkR7XMX', 22, '32-115-946-1683', 1074.73, 'lly along the blithely regular excuses. regular pack'), + (5366971, 'Supplier#005366971', 'ibZ6URl4Nl0Z0ifLamX338vxsL,', 20, '30-637-352-2377', 597.65, 'blithely express packages boost never bold deposits. fluffily final deposits impress slyly bold '), + (5389771, 'Supplier#005389771', 'R9pa4R5 j5wKTFkkwZqaZNj', 22, '32-763-345-9809', 9433.32, 'eodolites. blithely ironic packages wake slyly against the fu'), + (1913322, 'Supplier#001913322', 'CUa8ycLrBfJT72RiwY CT Cjrzv1ZX', 12, '22-518-846-2149', 5113.37, 'yly pending instructions. blithely bold requests '), + (1930502, 'Supplier#001930502', 'rC0yyIzIfhdrS', 15, '25-961-841-7975', 6747.94, 'eep carefully among the blithely bold '), + (1931893, 'Supplier#001931893', 'Xbvx7p2P4Gk8vYX', 2, '12-717-704-6074', 2630.06, 'structions are across the final dinos. 
bravely bold pearls'), + (1943994, 'Supplier#001943994', 'ZKNLFVjgvOHarPkWcgajZYK ', 2, '12-733-694-8098', 7535.23, 'ilent deposits. slyly even e'), + (1976990, 'Supplier#001976990', '0OqALoA6wNXA0xc', 21, '31-587-639-6171', 147.77, 'accounts. quickly even packages will hav'), + (7689361, 'Supplier#007689361', 'exWrfV,eGK OIK2ZMWA0IeYQEGqtyoYY3r1vq', 18, '28-383-878-1195', -375.50, 's sleep slyly fluffily final deposits. blithely regular packages wake quickly d'), + (7703529, 'Supplier#007703529', 'T3vUWJuNbQzDuQFCgC,1 NTuDpozj7WooLvP', 7, '17-200-633-5636', 2044.89, 'n excuses detect slyly fi'), + (7708210, 'Supplier#007708210', '2tEVI8j6s4Hv6VJtQD,rw5FX 0v', 20, '30-142-976-8868', 2300.51, 'ecial ideas nag carefully above the bold sentiments. quickly final instructions boost. '), + (7742062, 'Supplier#007742062', 'LU99He0q7oq7HY0yWac8WCj', 16, '26-162-248-2839', 6037.68, 'pending, bold theodolites. sheaves haggle'), + (4543081, 'Supplier#004543081', 'ga55lpk3eEw2v9ED', 14, '24-589-396-7297', 7477.11, 'otes. unusual requests are from th'), + (4613952, 'Supplier#004613952', 'F7jSHm,ZYI7', 22, '32-296-196-9373', 480.27, 'larly around the slyly bold foxes. blithely bold pa'), + (4631496, 'Supplier#004631496', 'jMAM1WPNQ5EVARMF', 13, '23-265-483-4796', 442.65, 'y regular patterns. quickly even theodolites lose'), + (8310185, 'Supplier#008310185', ',qpWjGTUgjr8ghXPyjzmnInxhng5pitF9OTC', 19, '29-869-953-7543', 3053.33, 'he slyly final ideas sleep across the slyl'), + (5931693, 'Supplier#005931693', 'v0XbUWZyWIAfSvLLtr ,4 vDSi', 9, '19-920-377-7668', 1241.04, 'ccounts. bravely fluffy braids sleep. slyly express deposits affix blithely af'), + (6660340, 'Supplier#006660340', 'u1Oc1tH4G7PXMt0mUmDc1nGkVZWD6ckHqDz', 14, '24-478-444-5079', 5747.31, 'regular, final foxes. careful'), + (6738005, 'Supplier#006738005', 'hA9wSDZmPraLmZirNX8QTR', 1, '11-291-888-5954', 9158.70, 'ts. slyly special packages according to t'), + (4021563, 'Supplier#004021563', 'sgw87qPVeafBkTjDi3EGCs', 13, '23-223-951-9961', -826.07, 'c deposits nod carefully. slyly pending deposits accordi'), + (4081113, 'Supplier#004081113', '5CCTXFb7FTyO54WBrvgKPQ6JjRf', 19, '29-226-908-1611', 8002.20, 'al foxes maintain. final excuses use carefully across the bu'), + (4114681, 'Supplier#004114681', 'owZXGdiA2qL', 8, '18-645-618-6331', 4516.05, 'ly unusual deposits are fluffily bold, pending excuses. ironic packages haggle furiously regu'), + (3154164, 'Supplier#003154164', 'VRbaHmdYKBjlTNWYCwC', 23, '33-110-448-3837', 3164.09, 'ronic packages. quickly final requests wake furiously. c'), + (3251599, 'Supplier#003251599', 'OpDQD Ld5sgPawt,euTFMjJbkZe7JiUuFIbuv', 19, '29-180-457-5770', 6521.57, '. special, final frays integrate quickly unusual pinto beans. carefully regular requests integrate. '), + (3261729, 'Supplier#003261729', ' J0p1cH,dffvGg fzmcovB', 21, '31-346-575-1608', 3010.20, 'se carefully about the pending braids. instructions sleep furiously bold requests'), + (9819665, 'Supplier#009819665', 'cw5IO6m705D8U8xMY4kBSmQzy895hLj', 0, '10-863-742-7262', 6097.66, 'ing requests. 
quickly enticing accounts need to nag slyly quickl'), + (1169743, 'Supplier#001169743', '0kMLux4TAp1 BCEphU,uyT4ptJx', 0, '10-979-714-9103', 5282.99, ' furiously ironic packages sleep carefully carefully pending '), + (3765937, 'Supplier#003765937', 'rGbsQvBQZWFedJN4p1k4HGoUAZRQ01h', 22, '32-974-149-1769', 6081.69, 'riously final theodolites '), + (3778826, 'Supplier#003778826', 'ZpY4XKwJpVabm5VZyb2ibB8BVD7wrPmZK', 12, '22-315-325-3702', 3155.37, 'kages cajole furiously ironic foxes. deposits about the final excuses h'), + (3887320, 'Supplier#003887320', 'YA ctuGCtgMM6vicqOHB7evL,ucLm', 15, '25-640-226-6606', 1367.62, 'oost carefully. blithely regular asymptotes are. carefully silent somas a'), + (6099561, 'Supplier#006099561', 'gnOv6vMqE95hQQrTImSloOAJWTqm98f', 21, '31-652-306-4562', 5240.90, 'fily silent pinto beans. bold, ironic requests abo'), + (8430359, 'Supplier#008430359', 'jV 2f9K7GIBWEXmEwtCiAezRT', 18, '28-246-899-4357', 9140.28, 'uests print packages. carefully unusual theodolites are regular platelet'), + (4238078, 'Supplier#004238078', 'pMVOoRb,eFYxXN5MyRQ5zVx61j', 13, '23-313-447-8012', 6781.42, 'efully furiously even requests. caref'), + (4705611, 'Supplier#004705611', 'Rw0w0l,HBlQcELBeW jL2RT5A', 2, '12-196-336-6788', 2419.13, 's. furiously regular ideas cajole never s'), + (9567008, 'Supplier#009567008', 'yNf1Cb8spTPI', 14, '24-933-404-6385', 8811.96, 'ts are carefully alongside of the carefully bold requests. blithely bold packages are f'), + (9642610, 'Supplier#009642610', 'r5qZkUp2PgDzpzh6', 6, '16-932-842-8046', 9000.51, 'ainst the stealthy ideas. slyly ironic requests are slyly. furiously final packages nag furiou'), + (9645018, 'Supplier#009645018', '3YOElF4xanxqaOaZySO1d8Z,rM0pkX1UlP', 6, '16-957-117-6777', 4865.79, 'ages. carefully ironic pinto beans eat blithely regular ide'), + (7617136, 'Supplier#007617136', 'gawA2tJvw6JbuG,osn3MZ3s1Gtwj', 10, '20-796-394-1809', 6053.34, 'leep blithely bravely unusual deposits. even deposits according to the express accounts ought t'), + (8569284, 'Supplier#008569284', 'YVvvxiCKSmBjOT', 18, '28-743-299-1091', 298.21, 'express pinto beans. regular, final requests cajole carefully. unusu'), + (3568417, 'Supplier#003568417', 'cLLFgVVVDd2J19QEQsyl', 11, '21-601-981-6077', 2080.55, 'ss the slyly ironic ideas sleep blithely pending instructions. special deposits boost blithely spec'), + (5402950, 'Supplier#005402950', 'ROrDmKh8vkoPCRVMX', 7, '17-348-710-9287', 8438.73, 'even theodolites haggle. daring foxes wake furiously special acc'), + (5518694, 'Supplier#005518694', 'S7VBptFyETj', 2, '12-586-778-1563', 7544.26, 'ainst the regular instructions. special, ironic accounts are slyly regular warthogs. bold '), + (2157667, 'Supplier#002157667', 'oFYlRQL2P1rGw1a', 13, '23-324-269-2980', 5197.19, 'hely. stealthy theodolites'), + (2237058, 'Supplier#002237058', 'mF6FCu5pWxBbFTO0', 6, '16-679-470-2385', 2483.07, ' packages. blithely unusual requests doubt slyly wi'), + (5781901, 'Supplier#005781901', ' h72xq5k4NhinR', 9, '19-478-445-6420', 7591.65, 'refully regular instructions use toward the fluffily s'), + (3415081, 'Supplier#003415081', '1pdGc1Ew9ClfLqNEP6NfMkBtKEFkW5zV', 15, '25-473-439-1803', 1713.85, 'c deposits. slyly ironic platelets are carefully carefully special plate'), + (3488358, 'Supplier#003488358', 'cZcnlzBrf0eGzpk6gEC3voY1w0', 1, '11-105-263-7524', 8256.38, 'atelets. quickly ironic dependencies detect. 
special packages c'), + (2925922, 'Supplier#002925922', 'VehGItTuVY60U39YR5Jsw', 24, '34-430-539-7866', 8617.32, 'es cajole carefully across'), + (2948582, 'Supplier#002948582', '0wD0zy FbWyftoPUhpC2OhCvSEmN3fi1k', 19, '29-308-329-9077', -776.49, 'oldly ironic foxes. carefully even deposits affix quickly. furiously ironic ac'), + (2949929, 'Supplier#002949929', 'hmQfNRFC2s91z,Oz6Z', 9, '19-967-661-3107', -751.86, 'fix slyly. quickly ruthless theodolites impress furiously pending platelets. ironic dolphins '), + (9393856, 'Supplier#009393856', 'qoL9lXWn0Q1OXdQcNRoy LPqIQnXkvZ', 24, '34-878-515-8337', -727.40, 'lites. fluffily even accounts after the pending pinto beans solve fluffily above the reques'), + (9968564, 'Supplier#009968564', 'hkqraabXr gOinA', 23, '33-338-836-8120', 3434.20, ' stealthily unusual accounts kindle carefully quickly final'); diff --git a/tests/resources/functional/bigquery/datetime/format_date_1.sql b/tests/resources/functional/bigquery/datetime/format_date_1.sql new file mode 100644 index 0000000000..2c35e067cb --- /dev/null +++ b/tests/resources/functional/bigquery/datetime/format_date_1.sql @@ -0,0 +1,77 @@ +-- bigquery sql: +SELECT + FORMAT_DATE("%A", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS full_weekday_name, + FORMAT_DATE("%a", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS abbreviated_weekday_name, + FORMAT_DATE("%B", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS full_month_name, + FORMAT_DATE("%b", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS abbreviated_month_name, + FORMAT_DATE("%C", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS century, + FORMAT_DATE("%c", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS date_time_representation, + FORMAT_DATE("%D", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS date_mm_dd_yy, + FORMAT_DATE("%d", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS day_of_month_two_digits, + FORMAT_DATE("%e", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS day_of_month_single_digit, + FORMAT_DATE("%F", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS iso_8601_date, + FORMAT_DATE("%H", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS hour_24, + FORMAT_DATE("%h", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS abbreviated_month_name_duplicate, + FORMAT_DATE("%I", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS hour_12, + FORMAT_DATE("%j", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS day_of_year, + FORMAT_DATE("%k", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS hour_24_no_leading_zero, + FORMAT_DATE("%l", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS hour_12_no_leading_zero, + FORMAT_DATE("%M", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS minutes, + FORMAT_DATE("%m", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS month_2_digits, + FORMAT_DATE("%P", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS am_pm_lowercase, + FORMAT_DATE("%p", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS am_pm_uppercase, + FORMAT_DATE("%Q", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS quarter, + FORMAT_DATE("%R", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS time_hh_mm, + FORMAT_DATE("%S", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS seconds, + FORMAT_DATE("%s", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS epoch_seconds, + FORMAT_DATE("%T", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS time_hh_mm_ss, + FORMAT_DATE("%u", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS iso_weekday_monday_start, + FORMAT_DATE("%V", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS iso_week_number, + FORMAT_DATE("%w", TIMESTAMP("2008-12-28 16:44:12.277404 
UTC")) AS weekday_sunday_start, + FORMAT_DATE("%X", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS time_representation, + FORMAT_DATE("%x", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS date_representation, + FORMAT_DATE("%Y", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS year_with_century, + FORMAT_DATE("%y", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS year_without_century, + FORMAT_DATE("%Z", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS time_zone_name, + FORMAT_DATE("%Ez", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS rfc3339_time_zone, + FORMAT_DATE("%E*S", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS seconds_with_full_fractional, + FORMAT_DATE("%E4Y", TIMESTAMP("2008-12-28 16:44:12.277404 UTC")) AS four_character_years; + +-- databricks sql: +SELECT + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'EEEE') AS full_weekday_name, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'EEE') AS abbreviated_weekday_name, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'MMMM') AS full_month_name, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'MMM') AS abbreviated_month_name, + ROUND(EXTRACT(YEAR FROM TIMESTAMP('2008-12-28 16:44:12.277404 UTC')) / 100) AS century, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'EEE MMM dd HH:mm:ss yyyy') AS date_time_representation, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'MM/dd/yy') AS date_mm_dd_yy, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'dd') AS day_of_month_two_digits, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'd') AS day_of_month_single_digit, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'yyyy-MM-dd') AS iso_8601_date, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'HH') AS hour_24, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'MMM') AS abbreviated_month_name_duplicate, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'hh') AS hour_12, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'DDD') AS day_of_year, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'H') AS hour_24_no_leading_zero, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'h') AS hour_12_no_leading_zero, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'mm') AS minutes, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'MM') AS month_2_digits, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'a') AS am_pm_lowercase, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'a') AS am_pm_uppercase, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'q') AS quarter, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'HH:mm') AS time_hh_mm, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'ss') AS seconds, + UNIX_TIMESTAMP(TIMESTAMP('2008-12-28 16:44:12.277404 UTC')) AS epoch_seconds, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'HH:mm:ss') AS time_hh_mm_ss, + EXTRACT(DAYOFWEEK_ISO FROM TIMESTAMP('2008-12-28 16:44:12.277404 UTC')) AS iso_weekday_monday_start, + EXTRACT(W FROM TIMESTAMP('2008-12-28 16:44:12.277404 UTC')) AS iso_week_number, + EXTRACT(DAYOFWEEK FROM TIMESTAMP('2008-12-28 16:44:12.277404 UTC')) - 1 AS weekday_sunday_start, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'HH:mm:ss') AS time_representation, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'MM/dd/yy') AS date_representation, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'yyyy') AS year_with_century, + 
DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'yy') AS year_without_century, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'z') AS time_zone_name, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'xxx') AS rfc3339_time_zone, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'ss.SSSSSS') AS seconds_with_full_fractional, + DATE_FORMAT(TIMESTAMP('2008-12-28 16:44:12.277404 UTC'), 'yyyy') AS four_character_years; diff --git a/tests/resources/functional/oracle/test_long_datatype/test_long_datatype_1.sql b/tests/resources/functional/oracle/test_long_datatype/test_long_datatype_1.sql new file mode 100644 index 0000000000..9ff4e4d587 --- /dev/null +++ b/tests/resources/functional/oracle/test_long_datatype/test_long_datatype_1.sql @@ -0,0 +1,6 @@ + +-- oracle sql: +SELECT cast(col1 as long) AS col1 FROM dual; + +-- databricks sql: +SELECT CAST(COL1 AS STRING) AS COL1 FROM DUAL; diff --git a/tests/resources/functional/presto/test_any_keys_match/test_any_keys_match_1.sql b/tests/resources/functional/presto/test_any_keys_match/test_any_keys_match_1.sql new file mode 100644 index 0000000000..439e2a252a --- /dev/null +++ b/tests/resources/functional/presto/test_any_keys_match/test_any_keys_match_1.sql @@ -0,0 +1,15 @@ +-- presto sql: +SELECT + any_keys_match( + map(array ['a', 'b', 'c'], array [1, 2, 3]), + x -> x = 'a' + ) as col; + +-- databricks sql: +SELECT + EXISTS( + MAP_KEYS( + MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(1, 2, 3)) + ), + x -> x = 'a' + ) AS col; diff --git a/tests/resources/functional/presto/test_any_keys_match/test_any_keys_match_2.sql b/tests/resources/functional/presto/test_any_keys_match/test_any_keys_match_2.sql new file mode 100644 index 0000000000..a7afbe212a --- /dev/null +++ b/tests/resources/functional/presto/test_any_keys_match/test_any_keys_match_2.sql @@ -0,0 +1,25 @@ +-- presto sql: +SELECT + *, + any_keys_match( + metadata, + k -> ( + k LIKE 'config_%' + OR k = 'active' + ) + ) AS has_config_or_active +FROM + your_table; + +-- databricks sql: +SELECT + *, + EXISTS( + MAP_KEYS(metadata), + k -> ( + k LIKE 'config_%' + OR k = 'active' + ) + ) AS has_config_or_active +FROM + your_table; diff --git a/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_1.sql b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_1.sql new file mode 100644 index 0000000000..a123ba3288 --- /dev/null +++ b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_1.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT approx_percentile(height, 0.5) FROM people; + +-- databricks sql: +SELECT approx_percentile(height, 0.5) FROM people; diff --git a/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_2.sql b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_2.sql new file mode 100644 index 0000000000..e0e9b94e01 --- /dev/null +++ b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_2.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT approx_percentile(height, 0.5, 0.01) FROM people; + +-- databricks sql: +SELECT approx_percentile(height, 0.5, 10000) FROM people; diff --git a/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_3.sql b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_3.sql new file mode 100644 index 0000000000..5b6981dd1a --- /dev/null +++ b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_3.sql @@ 
-0,0 +1,6 @@ + +-- presto sql: +SELECT approx_percentile(height, ARRAY[0.25, 0.5, 0.75]) FROM people; + +-- databricks sql: +SELECT approx_percentile(height, ARRAY(0.25, 0.5, 0.75)) FROM people; diff --git a/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_4.sql b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_4.sql new file mode 100644 index 0000000000..8b156243d1 --- /dev/null +++ b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_4.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT approx_percentile(height, ARRAY[0.25, 0.5, 0.75], 0.5) FROM people; + +-- databricks sql: +SELECT approx_percentile(height, ARRAY(0.25, 0.5, 0.75), 200) FROM people; diff --git a/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_6.sql b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_6.sql new file mode 100644 index 0000000000..eae5642aab --- /dev/null +++ b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_6.sql @@ -0,0 +1,7 @@ + +-- presto sql: +SELECT approx_percentile(height, weight, 0.5, 0.09) FROM people; + +-- databricks sql: +SELECT approx_percentile(height, weight, 0.5, 1111 ) FROM people; + diff --git a/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_8.sql b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_8.sql new file mode 100644 index 0000000000..81f9335ce1 --- /dev/null +++ b/tests/resources/functional/presto/test_approx_percentile/test_approx_percentile_8.sql @@ -0,0 +1,7 @@ + +-- presto sql: +SELECT approx_percentile(height, weight, ARRAY[0.25, 0.5, 0.75], 0.9) FROM people; + +-- databricks sql: +SELECT approx_percentile(height, weight, ARRAY(0.25, 0.5, 0.75), 111 ) FROM people; + diff --git a/tests/resources/functional/presto/test_array_average/test_array_average_1.sql b/tests/resources/functional/presto/test_array_average/test_array_average_1.sql new file mode 100644 index 0000000000..38db44d6ec --- /dev/null +++ b/tests/resources/functional/presto/test_array_average/test_array_average_1.sql @@ -0,0 +1,40 @@ +-- presto sql: +select + id, + sum(array_average(arr)) as sum_arr +FROM + ( + SELECT + 1 as id, + ARRAY [1,2,3] AS arr + UNION + SELECT + 2 as id, + ARRAY [10.20,20.108,30.4,40.0] as arr + ) AS t +group by + id; + +-- databricks sql: +SELECT + id, + SUM( + AGGREGATE( + FILTER(arr, x -> x IS NOT NULL), + NAMED_STRUCT('sum', CAST(0 AS DOUBLE), 'cnt', 0), + (acc, x) -> NAMED_STRUCT('sum', acc.sum + x, 'cnt', acc.cnt + 1), + acc -> TRY_DIVIDE(acc.sum, acc.cnt) + ) + ) AS sum_arr +FROM + ( + SELECT + 1 AS id, + ARRAY(1, 2, 3) AS arr + UNION + SELECT + 2 AS id, + ARRAY(10.20, 20.108, 30.4, 40.0) AS arr + ) AS t +GROUP BY + id; diff --git a/tests/resources/functional/presto/test_cast_as_json/test_cast_as_json_1.sql b/tests/resources/functional/presto/test_cast_as_json/test_cast_as_json_1.sql new file mode 100644 index 0000000000..5b80cab642 --- /dev/null +++ b/tests/resources/functional/presto/test_cast_as_json/test_cast_as_json_1.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT CAST(extra AS JSON) FROM dual; + +-- databricks sql: +SELECT CAST(extra AS STRING) FROM dual; diff --git a/tests/resources/functional/presto/test_format_datetime/test_format_datetime_1.sql b/tests/resources/functional/presto/test_format_datetime/test_format_datetime_1.sql new file mode 100644 index 0000000000..f4a175724e --- /dev/null +++ 
b/tests/resources/functional/presto/test_format_datetime/test_format_datetime_1.sql @@ -0,0 +1,13 @@ +-- presto sql: +select + format_datetime(current_timestamp,'EEEE') as col1 +, format_datetime(current_date,'EEEE') as col2 +, format_datetime(from_unixtime(1732723200), 'hh:mm:ss a') as col3 +, format_datetime(from_unixtime(1732723200), 'yyyy-MM-dd HH:mm:ss EEEE') as col4; + +-- databricks sql: +SELECT + DATE_FORMAT(CURRENT_TIMESTAMP(), 'EEEE') AS col1, + DATE_FORMAT(CURRENT_DATE(), 'EEEE') AS col2, + DATE_FORMAT(CAST(FROM_UNIXTIME(1732723200) AS TIMESTAMP), 'hh:mm:ss a') AS col3, + DATE_FORMAT(CAST(FROM_UNIXTIME(1732723200) AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss EEEE') AS col4 diff --git a/tests/resources/functional/presto/test_if/test_if_1.sql b/tests/resources/functional/presto/test_if/test_if_1.sql new file mode 100644 index 0000000000..2b811329ad --- /dev/null +++ b/tests/resources/functional/presto/test_if/test_if_1.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT if(cond, 1, 0) FROM dual; + +-- databricks sql: +SELECT IF(cond, 1, 0) FROM dual; diff --git a/tests/resources/functional/presto/test_if/test_if_2.sql b/tests/resources/functional/presto/test_if/test_if_2.sql new file mode 100644 index 0000000000..11d5e49fde --- /dev/null +++ b/tests/resources/functional/presto/test_if/test_if_2.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT if(cond, 1) FROM dual; + +-- databricks sql: +SELECT IF(cond, 1, NULL) FROM dual; diff --git a/tests/resources/functional/presto/test_json_extract/test_json_extract_1.sql b/tests/resources/functional/presto/test_json_extract/test_json_extract_1.sql new file mode 100644 index 0000000000..a3cc18fb73 --- /dev/null +++ b/tests/resources/functional/presto/test_json_extract/test_json_extract_1.sql @@ -0,0 +1,16 @@ +-- presto sql: +select + "json_extract"(params, '$.query') query, + "json_extract"(params, '$.dependencies') IS NOT NULL AS TEST, + "json_extract"(params, '$.dependencies') IS NULL AS TEST1, + "json_extract"(params, '$.dependencies') AS TEST2 +FROM + drone_job_manager dm; + +-- databricks sql: +SELECT + params:query AS query, + params:dependencies IS NOT NULL AS TEST, + params:dependencies IS NULL AS TEST1, + params:dependencies AS TEST2 +FROM drone_job_manager AS dm diff --git a/tests/resources/functional/presto/test_json_size/test_json_size_1.sql b/tests/resources/functional/presto/test_json_size/test_json_size_1.sql new file mode 100644 index 0000000000..b2aecb0019 --- /dev/null +++ b/tests/resources/functional/presto/test_json_size/test_json_size_1.sql @@ -0,0 +1,76 @@ +-- presto sql: +SELECT + json_size(col, path) as j_size, + id +from + ( + select + '{"x": {"a": 1, "b": 2}}' as col, + '$.x' as path, + 1 as id + union + select + '{"x": [1, 2, 3]}' as col, + '$.x' as path, + 2 as id + union + select + '{"x": {"a": 1, "b": 2}}' as col, + '$.x.a' as path, + 3 as id + union + select + '42' as col, + '$' as path, + 4 as id + union + select + 'invalid json' as col, + '$' as path, + 5 as id + ) tmp +order by + id; + +-- databricks sql: +SELECT + CASE + WHEN GET_JSON_OBJECT(col, path) LIKE '{%' THEN SIZE( + FROM_JSON(GET_JSON_OBJECT(col, path), 'map') + ) + WHEN GET_JSON_OBJECT(col, path) LIKE '[%' THEN SIZE( + FROM_JSON(GET_JSON_OBJECT(col, path), 'array') + ) + WHEN GET_JSON_OBJECT(col, path) IS NOT NULL THEN 0 + ELSE NULL + END AS j_size, + id +from + ( + select + '{"x": {"a": 1, "b": 2}}' as col, + '$.x' as path, + 1 as id + union + select + '{"x": [1, 2, 3]}' as col, + '$.x' as path, + 2 as id + union + select + '{"x": {"a": 1, "b": 2}}' as col, + 
'$.x.a' as path, + 3 as id + union + select + '42' as col, + '$' as path, + 4 as id + union + select + 'invalid json' as col, + '$' as path, + 5 as id + ) as tmp +order by + id nulls last; diff --git a/tests/resources/functional/presto/test_nested_json_with_cast/test_nested_json_with_cast_1.sql b/tests/resources/functional/presto/test_nested_json_with_cast/test_nested_json_with_cast_1.sql new file mode 100644 index 0000000000..13c61e8a74 --- /dev/null +++ b/tests/resources/functional/presto/test_nested_json_with_cast/test_nested_json_with_cast_1.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT AVG(CAST(json_extract_scalar(CAST(extra AS JSON), '$.satisfaction_rating') AS INT)) FROM dual; + +-- databricks sql: +SELECT AVG(CAST(CAST(extra AS STRING):satisfaction_rating AS INT)) FROM dual; diff --git a/tests/resources/functional/presto/test_strpos/test_strpos_1.sql b/tests/resources/functional/presto/test_strpos/test_strpos_1.sql new file mode 100644 index 0000000000..c6203b34d0 --- /dev/null +++ b/tests/resources/functional/presto/test_strpos/test_strpos_1.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT strpos('Hello world', 'l', 2); + +-- databricks sql: +SELECT LOCATE('l', 'Hello world', 2); diff --git a/tests/resources/functional/presto/test_strpos/test_strpos_2.sql b/tests/resources/functional/presto/test_strpos/test_strpos_2.sql new file mode 100644 index 0000000000..d950825b19 --- /dev/null +++ b/tests/resources/functional/presto/test_strpos/test_strpos_2.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT CASE WHEN strpos(greeting_message, 'hello') > 0 THEN 'Contains hello' ELSE 'Does not contain hello' END FROM greetings_table; + +-- databricks sql: +SELECT CASE WHEN LOCATE('hello', greeting_message) > 0 THEN 'Contains hello' ELSE 'Does not contain hello' END FROM greetings_table; diff --git a/tests/resources/functional/presto/test_unnest/test_unnest_1.sql b/tests/resources/functional/presto/test_unnest/test_unnest_1.sql new file mode 100644 index 0000000000..d9ce67baa2 --- /dev/null +++ b/tests/resources/functional/presto/test_unnest/test_unnest_1.sql @@ -0,0 +1,17 @@ +-- presto sql: +SELECT + * +FROM + default.sync_gold + CROSS JOIN UNNEST(engine_waits) t (col1, col2) +WHERE + ("cardinality"(engine_waits) > 0); + +-- databricks sql: +SELECT + * +FROM + default.sync_gold LATERAL VIEW EXPLODE(engine_waits) As col1, + col2 +WHERE + (SIZE(engine_waits) > 0); diff --git a/tests/resources/functional/presto/test_unnest/test_unnest_2.sql b/tests/resources/functional/presto/test_unnest/test_unnest_2.sql new file mode 100644 index 0000000000..d8a8707a2d --- /dev/null +++ b/tests/resources/functional/presto/test_unnest/test_unnest_2.sql @@ -0,0 +1,31 @@ +-- presto sql: +SELECT + day, + build_number, + error, + error_count +FROM + sch.tab + CROSS JOIN UNNEST(CAST(extra AS map(varchar, integer))) e (error, error_count) +WHERE + ( + (event_type = 'fp_daemon_crit_errors_v2') + AND (error_count > 0) + ); + +-- databricks sql: +SELECT + day, + build_number, + error, + error_count +FROM + sch.tab LATERAL VIEW EXPLODE(CAST(extra AS MAP)) As error, + error_count +WHERE + ( + ( + event_type = 'fp_daemon_crit_errors_v2' + ) + AND (error_count > 0) + ); diff --git a/tests/resources/functional/presto_expected_exceptions/test_approx_percentile_5.sql b/tests/resources/functional/presto_expected_exceptions/test_approx_percentile_5.sql new file mode 100644 index 0000000000..8d0116f423 --- /dev/null +++ b/tests/resources/functional/presto_expected_exceptions/test_approx_percentile_5.sql @@ -0,0 +1,6 @@ + +-- presto sql: 
+SELECT approx_percentile(height, 0.5, 'non_integer') FROM people; + +-- databricks sql: +SELECT approx_percentile(height, 0.5, 0.01) FROM people; diff --git a/tests/resources/functional/presto_expected_exceptions/test_approx_percentile_7.sql b/tests/resources/functional/presto_expected_exceptions/test_approx_percentile_7.sql new file mode 100644 index 0000000000..65016293a1 --- /dev/null +++ b/tests/resources/functional/presto_expected_exceptions/test_approx_percentile_7.sql @@ -0,0 +1,6 @@ + +-- presto sql: +SELECT approx_percentile(height, weight, 0.5, 'non_integer') FROM people; + +-- databricks sql: +SELECT approx_percentile(height, weight, 0.5, 0.01) FROM people; diff --git a/tests/resources/functional/snowflake/aggregates/least_1.sql b/tests/resources/functional/snowflake/aggregates/least_1.sql new file mode 100644 index 0000000000..9247a0dfdd --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/least_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT least(col1) AS least_col1 FROM tabl; + +-- databricks sql: +SELECT LEAST(col1) AS least_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_1.sql b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_1.sql new file mode 100644 index 0000000000..fcac4173b6 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_1.sql @@ -0,0 +1,10 @@ +-- snowflake sql: +SELECT LISTAGG(col1, ' ') FROM test_table WHERE col2 > 10000; + +-- databricks sql: +SELECT + ARRAY_JOIN(ARRAY_AGG(col1), ' ') +FROM test_table +WHERE + col2 > 10000 +; diff --git a/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_2.sql b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_2.sql new file mode 100644 index 0000000000..5211f6c608 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_2.sql @@ -0,0 +1,8 @@ +-- snowflake sql: +SELECT LISTAGG(col1) FROM test_table; + +-- databricks sql: +SELECT + ARRAY_JOIN(ARRAY_AGG(col1), '') +FROM test_table +; diff --git a/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_3.sql b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_3.sql new file mode 100644 index 0000000000..a06d19a6e2 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_3.sql @@ -0,0 +1,11 @@ +-- snowflake sql: +SELECT LISTAGG(DISTINCT col3, '|') + FROM test_table WHERE col2 > 10000; + +-- databricks sql: +SELECT + ARRAY_JOIN(ARRAY_AGG(DISTINCT col3), '|') +FROM test_table +WHERE + col2 > 10000 +; diff --git a/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_4.sql b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_4.sql new file mode 100644 index 0000000000..11e80f1e60 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/listagg/test_listagg_4.sql @@ -0,0 +1,28 @@ +-- snowflake sql: +SELECT col3, listagg(col4, ', ') WITHIN GROUP (ORDER BY col2 DESC) +FROM +test_table +WHERE col2 > 10000 GROUP BY col3; + +-- databricks sql: +SELECT + col3, + ARRAY_JOIN( + TRANSFORM( + ARRAY_SORT( + ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col2)), + (left, right) -> CASE + WHEN left.sort_by_0 < right.sort_by_0 THEN 1 + WHEN left.sort_by_0 > right.sort_by_0 THEN -1 + ELSE 0 + END + ), + s -> s.value + ), + ', ' + ) +FROM test_table +WHERE + col2 > 10000 +GROUP BY + col3; diff --git a/tests/resources/functional/snowflake/aggregates/test_booland_agg_1.sql 
b/tests/resources/functional/snowflake/aggregates/test_booland_agg_1.sql new file mode 100644 index 0000000000..c63cc79423 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/test_booland_agg_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select booland_agg(k) from bool_example; + +-- databricks sql: +SELECT BOOL_AND(k) FROM bool_example; diff --git a/tests/resources/functional/snowflake/aggregates/test_booland_agg_2.sql b/tests/resources/functional/snowflake/aggregates/test_booland_agg_2.sql new file mode 100644 index 0000000000..feeae02ea5 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/test_booland_agg_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select s2, booland_agg(k) from bool_example group by s2; + +-- databricks sql: +SELECT s2, BOOL_AND(k) FROM bool_example GROUP BY s2; diff --git a/tests/resources/functional/snowflake/aggregates/test_count_1.sql b/tests/resources/functional/snowflake/aggregates/test_count_1.sql new file mode 100644 index 0000000000..bce3802eb9 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/test_count_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT count(col1) AS count_col1 FROM tabl; + +-- databricks sql: +SELECT COUNT(col1) AS count_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/aggregates/test_count_if_1.sql b/tests/resources/functional/snowflake/aggregates/test_count_if_1.sql new file mode 100644 index 0000000000..1644f13324 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/test_count_if_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT COUNT_IF(j_col > i_col) FROM basic_example; + +-- databricks sql: +SELECT COUNT_IF(j_col > i_col) FROM basic_example; diff --git a/tests/resources/functional/snowflake/aggregates/test_dense_rank_1.sql b/tests/resources/functional/snowflake/aggregates/test_dense_rank_1.sql new file mode 100644 index 0000000000..7558c9545c --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/test_dense_rank_1.sql @@ -0,0 +1,19 @@ +-- snowflake sql: +SELECT + dense_rank() OVER ( + PARTITION BY col1 + ORDER BY + col2 + ) AS dense_rank_col1 +FROM + tabl; + +-- databricks sql: +SELECT + DENSE_RANK() OVER ( + PARTITION BY col1 + ORDER BY + col2 ASC NULLS LAST + ) AS dense_rank_col1 +FROM + tabl; diff --git a/tests/resources/functional/snowflake/aggregates/test_dense_rank_2.sql b/tests/resources/functional/snowflake/aggregates/test_dense_rank_2.sql new file mode 100644 index 0000000000..8614ba97f0 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/test_dense_rank_2.sql @@ -0,0 +1,21 @@ +-- snowflake sql: +SELECT + dense_rank() OVER ( + PARTITION BY col1 + ORDER BY + col2 DESC RANGE BETWEEN UNBOUNDED PRECEDING + AND CURRENT ROW + ) AS dense_rank_col1 +FROM + tabl; + +-- databricks sql: +SELECT + DENSE_RANK() OVER ( + PARTITION BY col1 + ORDER BY + col2 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING + AND CURRENT ROW + ) AS dense_rank_col1 +FROM + tabl; diff --git a/tests/resources/functional/snowflake/aggregates/test_greatest_1.sql b/tests/resources/functional/snowflake/aggregates/test_greatest_1.sql new file mode 100644 index 0000000000..691a41eec0 --- /dev/null +++ b/tests/resources/functional/snowflake/aggregates/test_greatest_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT greatest(col_1, col_2, col_3) AS greatest_col1 FROM tabl; + +-- databricks sql: +SELECT GREATEST(col_1, col_2, col_3) AS greatest_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_append_1.sql 
b/tests/resources/functional/snowflake/arrays/test_array_append_1.sql new file mode 100644 index 0000000000..251e0d4317 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_append_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_append(array, elem) AS array_append_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_APPEND(array, elem) AS array_append_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_cat_1.sql b/tests/resources/functional/snowflake/arrays/test_array_cat_1.sql new file mode 100644 index 0000000000..7b768b7b7a --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_cat_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select array_cat(col1, col2) FROM tbl; + +-- databricks sql: +SELECT CONCAT(col1, col2) FROM tbl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_cat_2.sql b/tests/resources/functional/snowflake/arrays/test_array_cat_2.sql new file mode 100644 index 0000000000..a9d268be86 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_cat_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select array_cat([1, 3], [2, 4]); + +-- databricks sql: +SELECT CONCAT(ARRAY(1, 3), ARRAY(2, 4)); diff --git a/tests/resources/functional/snowflake/arrays/test_array_compact_1.sql b/tests/resources/functional/snowflake/arrays/test_array_compact_1.sql new file mode 100644 index 0000000000..268c85f022 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_compact_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_compact(col1) AS array_compact_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_COMPACT(col1) AS array_compact_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_construct_1.sql b/tests/resources/functional/snowflake/arrays/test_array_construct_1.sql new file mode 100644 index 0000000000..60d29e13e7 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_construct_1.sql @@ -0,0 +1,32 @@ +-- snowflake sql: +WITH users AS ( + SELECT + 1 AS user_id, + ARRAY_CONSTRUCT('item1', 'item2', 'item3') AS items + UNION ALL + SELECT + 2 AS user_id, + ARRAY_CONSTRUCT('itemA', 'itemB') AS items +) +SELECT + user_id, + value AS item +FROM + users, + LATERAL FLATTEN(input => items) as value; + +-- databricks sql: +WITH users AS ( + SELECT + 1 AS user_id, + ARRAY('item1', 'item2', 'item3') AS items + UNION ALL + SELECT + 2 AS user_id, + ARRAY('itemA', 'itemB') AS items +) +SELECT + user_id, + value AS item +FROM users + LATERAL VIEW EXPLODE(items) AS value; diff --git a/tests/resources/functional/snowflake/arrays/test_array_construct_2.sql b/tests/resources/functional/snowflake/arrays/test_array_construct_2.sql new file mode 100644 index 0000000000..e712cfeb4a --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_construct_2.sql @@ -0,0 +1,40 @@ +-- snowflake sql: +WITH orders AS ( + SELECT + 101 AS order_id, + ARRAY_CONSTRUCT( + OBJECT_CONSTRUCT('product_id', 1, 'name', 'ProductA'), + OBJECT_CONSTRUCT('product_id', 2, 'name', 'ProductB') + ) AS order_details + UNION ALL + SELECT + 102 AS order_id, + ARRAY_CONSTRUCT( + OBJECT_CONSTRUCT('product_id', 3, 'name', 'ProductC') + ) AS order_details +) +SELECT + order_id, + value AS product +FROM + orders, + LATERAL FLATTEN(input => order_details) as value; + +-- databricks sql: +WITH orders AS ( + SELECT + 101 AS order_id, + ARRAY( + STRUCT(1 AS product_id, 'ProductA' AS name), + STRUCT(2 AS product_id, 'ProductB' AS name) + ) AS order_details + 
UNION ALL + SELECT + 102 AS order_id, + ARRAY(STRUCT(3 AS product_id, 'ProductC' AS name)) AS order_details +) +SELECT + order_id, + value AS product +FROM + orders LATERAL VIEW EXPLODE(order_details) AS value diff --git a/tests/resources/functional/snowflake/arrays/test_array_construct_compact_1.sql b/tests/resources/functional/snowflake/arrays/test_array_construct_compact_1.sql new file mode 100644 index 0000000000..2f6359b3e6 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_construct_compact_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ARRAY_CONSTRUCT(null, 'hello', 3::double, 4, 5); + +-- databricks sql: +SELECT ARRAY(NULL, 'hello', CAST(3 AS DOUBLE), 4, 5); diff --git a/tests/resources/functional/snowflake/arrays/test_array_construct_compact_2.sql b/tests/resources/functional/snowflake/arrays/test_array_construct_compact_2.sql new file mode 100644 index 0000000000..c465c2421d --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_construct_compact_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ARRAY_CONSTRUCT_COMPACT(null, 'hello', 3::double, 4, 5); + +-- databricks sql: +SELECT ARRAY_EXCEPT(ARRAY(NULL, 'hello', CAST(3 AS DOUBLE), 4, 5), ARRAY(NULL)); diff --git a/tests/resources/functional/snowflake/arrays/test_array_contains_1.sql b/tests/resources/functional/snowflake/arrays/test_array_contains_1.sql new file mode 100644 index 0000000000..1a0ca8201a --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_contains_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_contains(33, arr_col) AS array_contains_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_CONTAINS(arr_col, 33) AS array_contains_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_distinct_1.sql b/tests/resources/functional/snowflake/arrays/test_array_distinct_1.sql new file mode 100644 index 0000000000..be46d7a68c --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_distinct_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_distinct(col1) AS array_distinct_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_DISTINCT(col1) AS array_distinct_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_except_1.sql b/tests/resources/functional/snowflake/arrays/test_array_except_1.sql new file mode 100644 index 0000000000..a261e1eb51 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_except_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_except(a, b) AS array_except_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_EXCEPT(a, b) AS array_except_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_intersection_1.sql b/tests/resources/functional/snowflake/arrays/test_array_intersection_1.sql new file mode 100644 index 0000000000..40a141a07f --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_intersection_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ARRAY_INTERSECTION(col1, col2); + +-- databricks sql: +SELECT ARRAY_INTERSECT(col1, col2); diff --git a/tests/resources/functional/snowflake/arrays/test_array_intersection_2.sql b/tests/resources/functional/snowflake/arrays/test_array_intersection_2.sql new file mode 100644 index 0000000000..1a9337259d --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_intersection_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ARRAY_INTERSECTION(ARRAY_CONSTRUCT(1, 2, 3), ARRAY_CONSTRUCT(1, 2)); + +-- databricks 
sql: +SELECT ARRAY_INTERSECT(ARRAY(1, 2, 3), ARRAY(1, 2)); diff --git a/tests/resources/functional/snowflake/arrays/test_array_position_1.sql b/tests/resources/functional/snowflake/arrays/test_array_position_1.sql new file mode 100644 index 0000000000..9dd59bb912 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_position_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_position(col1, array1) AS array_position_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_POSITION(col1, array1) AS array_position_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_prepend_1.sql b/tests/resources/functional/snowflake/arrays/test_array_prepend_1.sql new file mode 100644 index 0000000000..11422346d4 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_prepend_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_prepend(array, elem) AS array_prepend_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_PREPEND(array, elem) AS array_prepend_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_remove_1.sql b/tests/resources/functional/snowflake/arrays/test_array_remove_1.sql new file mode 100644 index 0000000000..0b8d79004b --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_remove_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_remove(array, element) AS array_remove_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_REMOVE(array, element) AS array_remove_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_size_1.sql b/tests/resources/functional/snowflake/arrays/test_array_size_1.sql new file mode 100644 index 0000000000..2f93088cf6 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_size_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_size(col1) AS array_size_col1 FROM tabl; + +-- databricks sql: +SELECT SIZE(col1) AS array_size_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_array_slice_1.sql b/tests/resources/functional/snowflake/arrays/test_array_slice_1.sql new file mode 100644 index 0000000000..ffe27227d7 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_slice_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_slice(array_construct(0,1,2,3,4,5,6), 0, 2); + +-- databricks sql: +SELECT SLICE(ARRAY(0, 1, 2, 3, 4, 5, 6), 1, 2); diff --git a/tests/resources/functional/snowflake/arrays/test_array_slice_2.sql b/tests/resources/functional/snowflake/arrays/test_array_slice_2.sql new file mode 100644 index 0000000000..ed34df96de --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_slice_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_slice(array_construct(90,91,92,93,94,95,96), -5, 3); + +-- databricks sql: +SELECT SLICE(ARRAY(90, 91, 92, 93, 94, 95, 96), -5, 3); diff --git a/tests/resources/functional/snowflake/arrays/test_array_to_string_1.sql b/tests/resources/functional/snowflake/arrays/test_array_to_string_1.sql new file mode 100644 index 0000000000..98b6576bcd --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_array_to_string_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ARRAY_TO_STRING(ary_column1, '') AS no_separation FROM tbl; + +-- databricks sql: +SELECT ARRAY_JOIN(ary_column1, '') AS no_separation FROM tbl; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_1.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_1.sql new file mode 100644 index 
0000000000..194c2dae10 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +select array_agg(col1) FROM test_table; + +-- databricks sql: +SELECT ARRAY_AGG(col1) FROM test_table; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_2.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_2.sql new file mode 100644 index 0000000000..64f3b4c523 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_2.sql @@ -0,0 +1,11 @@ +-- snowflake sql: +SELECT ARRAY_AGG(DISTINCT col2) WITHIN GROUP (ORDER BY col2 ASC) +FROM test_table +WHERE col3 > 10000; + +-- databricks sql: +SELECT + SORT_ARRAY(ARRAY_AGG(DISTINCT col2)) +FROM test_table +WHERE + col3 > 10000; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_3.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_3.sql new file mode 100644 index 0000000000..7e0969f194 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_3.sql @@ -0,0 +1,12 @@ +-- snowflake sql: + +SELECT ARRAY_AGG(col2) WITHIN GROUP (ORDER BY col2 ASC) +FROM test_table +WHERE col3 > 10000; + +-- databricks sql: +SELECT + SORT_ARRAY(ARRAY_AGG(col2)) +FROM test_table +WHERE + col3 > 10000; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_4.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_4.sql new file mode 100644 index 0000000000..794d5218aa --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_4.sql @@ -0,0 +1,30 @@ +-- snowflake sql: +SELECT + col2, + ARRAYAGG(col4) WITHIN GROUP (ORDER BY col3 DESC) +FROM test_table +WHERE col3 > 450000 +GROUP BY col2 +ORDER BY col2 DESC; + +-- databricks sql: +SELECT + col2, + TRANSFORM( + ARRAY_SORT( + ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3)), + (left, right) -> CASE + WHEN left.sort_by_0 < right.sort_by_0 THEN 1 + WHEN left.sort_by_0 > right.sort_by_0 THEN -1 + ELSE 0 + END + ), + s -> s.value + ) +FROM test_table +WHERE + col3 > 450000 +GROUP BY + col2 +ORDER BY + col2 DESC NULLS FIRST; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_5.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_5.sql new file mode 100644 index 0000000000..ac19c9730c --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_5.sql @@ -0,0 +1,30 @@ +-- snowflake sql: +SELECT + col2, + ARRAYAGG(col4) WITHIN GROUP (ORDER BY col3) +FROM test_table +WHERE col3 > 450000 +GROUP BY col2 +ORDER BY col2 DESC; + +-- databricks sql: + SELECT + col2, + TRANSFORM( + ARRAY_SORT( + ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3)), + (left, right) -> CASE + WHEN left.sort_by_0 < right.sort_by_0 THEN -1 + WHEN left.sort_by_0 > right.sort_by_0 THEN 1 + ELSE 0 + END + ), + s -> s.value + ) + FROM test_table + WHERE + col3 > 450000 + GROUP BY + col2 + ORDER BY + col2 DESC NULLS FIRST; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_6.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_6.sql new file mode 100644 index 0000000000..0c433015a3 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_6.sql @@ -0,0 +1,7 @@ +-- snowflake sql: +SELECT ARRAY_AGG(col2) WITHIN GROUP (ORDER BY col2 DESC) FROM test_table; + +-- databricks sql: +SELECT + SORT_ARRAY(ARRAY_AGG(col2), false) +FROM test_table; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_7.sql 
b/tests/resources/functional/snowflake/arrays/test_arrayagg_7.sql new file mode 100644 index 0000000000..138a23dbad --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_7.sql @@ -0,0 +1,38 @@ +-- snowflake sql: +WITH cte AS ( + SELECT + id, + tag, + SUM(tag:count) AS item_count + FROM another_table +) +SELECT +id +, ARRAY_AGG(tag) WITHIN GROUP(ORDER BY item_count DESC) AS agg_tags +FROM cte +GROUP BY 1; + +-- databricks sql: +WITH cte AS ( + SELECT + id, + tag, + SUM(tag:count) AS item_count + FROM another_table + ) + SELECT + id, + TRANSFORM( + ARRAY_SORT( + ARRAY_AGG(NAMED_STRUCT('value', tag, 'sort_by_0', item_count)), + (left, right) -> CASE + WHEN left.sort_by_0 < right.sort_by_0 THEN 1 + WHEN left.sort_by_0 > right.sort_by_0 THEN -1 + ELSE 0 + END + ), + s -> s.value + ) AS agg_tags + FROM cte + GROUP BY + 1; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_8.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_8.sql new file mode 100644 index 0000000000..7ad991ff7f --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_8.sql @@ -0,0 +1,32 @@ +-- snowflake sql: +SELECT + col2, + ARRAYAGG(col4) WITHIN GROUP (ORDER BY col3, col5) +FROM test_table +WHERE col3 > 450000 +GROUP BY col2 +ORDER BY col2 DESC; + +-- databricks sql: + SELECT + col2, + TRANSFORM( + ARRAY_SORT( + ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3, 'sort_by_1', col5)), + (left, right) -> CASE + WHEN left.sort_by_0 < right.sort_by_0 THEN -1 + WHEN left.sort_by_0 > right.sort_by_0 THEN 1 + WHEN left.sort_by_1 < right.sort_by_1 THEN -1 + WHEN left.sort_by_1 > right.sort_by_1 THEN 1 + ELSE 0 + END + ), + s -> s.value + ) + FROM test_table + WHERE + col3 > 450000 + GROUP BY + col2 + ORDER BY + col2 DESC NULLS FIRST; diff --git a/tests/resources/functional/snowflake/arrays/test_arrayagg_9.sql b/tests/resources/functional/snowflake/arrays/test_arrayagg_9.sql new file mode 100644 index 0000000000..d0440b9077 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrayagg_9.sql @@ -0,0 +1,32 @@ +-- snowflake sql: +SELECT + col2, + ARRAYAGG(col4) WITHIN GROUP (ORDER BY col3, col5 DESC) +FROM test_table +WHERE col3 > 450000 +GROUP BY col2 +ORDER BY col2 DESC; + +-- databricks sql: +SELECT + col2, + TRANSFORM( + ARRAY_SORT( + ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3, 'sort_by_1', col5)), + (left, right) -> CASE + WHEN left.sort_by_0 < right.sort_by_0 THEN -1 + WHEN left.sort_by_0 > right.sort_by_0 THEN 1 + WHEN left.sort_by_1 < right.sort_by_1 THEN 1 + WHEN left.sort_by_1 > right.sort_by_1 THEN -1 + ELSE 0 + END + ), + s -> s.value + ) +FROM test_table +WHERE + col3 > 450000 +GROUP BY + col2 +ORDER BY + col2 DESC NULLS FIRST; diff --git a/tests/resources/functional/snowflake/arrays/test_arrays_overlap_1.sql b/tests/resources/functional/snowflake/arrays/test_arrays_overlap_1.sql new file mode 100644 index 0000000000..b1e2b09dac --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_arrays_overlap_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ARRAYS_OVERLAP(ARRAY_CONSTRUCT(1, 2, NULL), ARRAY_CONSTRUCT(3, NULL, 5)); + +-- databricks sql: +SELECT ARRAYS_OVERLAP(ARRAY(1, 2, NULL), ARRAY(3, NULL, 5)); diff --git a/tests/resources/functional/snowflake/arrays/test_concat_ws_1.sql b/tests/resources/functional/snowflake/arrays/test_concat_ws_1.sql new file mode 100644 index 0000000000..02a09b1b2f --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_concat_ws_1.sql @@ -0,0 +1,6 @@ + +-- 
snowflake sql: +SELECT CONCAT_WS(',', 'one', 'two', 'three'); + +-- databricks sql: +SELECT CONCAT_WS(',', 'one', 'two', 'three'); diff --git a/tests/resources/functional/snowflake/arrays/test_extract_1.sql b/tests/resources/functional/snowflake/arrays/test_extract_1.sql new file mode 100644 index 0000000000..70faafe8ef --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_extract_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT extract(week FROM col1) AS extract_col1 FROM tabl; + +-- databricks sql: +SELECT EXTRACT(week FROM col1) AS extract_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_flatten_1.sql b/tests/resources/functional/snowflake/arrays/test_flatten_1.sql new file mode 100644 index 0000000000..3f263f9492 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_flatten_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT flatten(col1) AS flatten_col1 FROM tabl; + +-- databricks sql: +SELECT EXPLODE(col1) AS flatten_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/arrays/test_get_1.sql b/tests/resources/functional/snowflake/arrays/test_get_1.sql new file mode 100644 index 0000000000..b242ea5777 --- /dev/null +++ b/tests/resources/functional/snowflake/arrays/test_get_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT get(col1, idx1) AS get_col1 FROM tabl; + +-- databricks sql: +SELECT GET(col1, idx1) AS get_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/cast/test_cast_date.sql b/tests/resources/functional/snowflake/cast/test_cast_date.sql new file mode 100644 index 0000000000..74dfc054dd --- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_cast_date.sql @@ -0,0 +1,7 @@ +-- snowflake sql: +SELECT + '2024-01-01'::DATE AS date_val + +-- databricks sql: +SELECT + CAST('2024-01-01' AS DATE) AS date_val; diff --git a/tests/resources/functional/snowflake/cast/test_cast_decimal.sql b/tests/resources/functional/snowflake/cast/test_cast_decimal.sql new file mode 100644 index 0000000000..74560062e7 --- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_cast_decimal.sql @@ -0,0 +1,13 @@ +-- snowflake sql: +SELECT + 12345::DECIMAL(10, 2) AS decimal_val, + 12345::NUMBER(10, 2) AS number_val, + 12345::NUMERIC(10, 2) AS numeric_val, + 12345::BIGINT AS bigint_val + +-- databricks sql: +SELECT + CAST(12345 AS DECIMAL(10, 2)) AS decimal_val, + CAST(12345 AS DECIMAL(10, 2)) AS number_val, + CAST(12345 AS DECIMAL(10, 2)) AS numeric_val, + CAST(12345 AS DECIMAL(38, 0)) AS bigint_val; diff --git a/tests/resources/functional/snowflake/cast/test_cast_double.sql b/tests/resources/functional/snowflake/cast/test_cast_double.sql new file mode 100644 index 0000000000..06e777612f --- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_cast_double.sql @@ -0,0 +1,17 @@ +-- snowflake sql: +SELECT + 12345.678::DOUBLE AS double_val, + 12345.678::DOUBLE PRECISION AS double_precision_val, + 12345.678::FLOAT AS float_val, + 12345.678::FLOAT4 AS float4_val, + 12345.678::FLOAT8 AS float8_val, + 12345.678::REAL AS real_val + +-- databricks sql: +SELECT + CAST(12345.678 AS DOUBLE) AS double_val, + CAST(12345.678 AS DOUBLE) AS double_precision_val, + CAST(12345.678 AS DOUBLE) AS float_val, + CAST(12345.678 AS DOUBLE) AS float4_val, + CAST(12345.678 AS DOUBLE) AS float8_val, + CAST(12345.678 AS DOUBLE) AS real_val; diff --git a/tests/resources/functional/snowflake/cast/test_cast_int.sql b/tests/resources/functional/snowflake/cast/test_cast_int.sql new file mode 100644 index 0000000000..4ff05cec0b 
--- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_cast_int.sql @@ -0,0 +1,17 @@ +-- snowflake sql: +SELECT + 123::BYTEINT AS byteint_val, + 123::SMALLINT AS smallint_val, + 123::INT AS int_val, + 123::INTEGER AS integer_val, + 123::BIGINT AS bigint_val, + 123::TINYINT AS tinyint_val + +-- databricks sql: +SELECT + CAST(123 AS DECIMAL(38, 0)) AS byteint_val, + CAST(123 AS DECIMAL(38, 0)) AS smallint_val, + CAST(123 AS DECIMAL(38, 0)) AS int_val, + CAST(123 AS DECIMAL(38, 0)) AS integer_val, + CAST(123 AS DECIMAL(38, 0)) AS bigint_val, + CAST(123 AS TINYINT) AS tinyint_val; diff --git a/tests/resources/functional/snowflake/cast/test_cast_strings.sql b/tests/resources/functional/snowflake/cast/test_cast_strings.sql new file mode 100644 index 0000000000..05fe1e0973 --- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_cast_strings.sql @@ -0,0 +1,15 @@ +-- snowflake sql: +SELECT + '12345'::VARCHAR(10) AS varchar_val, + '12345'::STRING AS string_val, + '12345'::TEXT AS text_val, + 'A'::CHAR(1) AS char_val, + 'A'::CHARACTER(1) AS character_val + +-- databricks sql: +SELECT + CAST('12345' AS STRING) AS varchar_val, + CAST('12345' AS STRING) AS string_val, + CAST('12345' AS STRING) AS text_val, + CAST('A' AS STRING) AS char_val, + CAST('A' AS STRING) AS character_val; diff --git a/tests/resources/functional/snowflake/cast/test_cast_timestamp.sql b/tests/resources/functional/snowflake/cast/test_cast_timestamp.sql new file mode 100644 index 0000000000..f38bff6588 --- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_cast_timestamp.sql @@ -0,0 +1,15 @@ +-- snowflake sql: +SELECT + '12:34:56'::TIME AS time_val, + '2024-01-01 12:34:56'::TIMESTAMP AS timestamp_val, + '2024-01-01 12:34:56 +00:00'::TIMESTAMP_LTZ AS timestamp_ltz_val, + '2024-01-01 12:34:56'::TIMESTAMP_NTZ AS timestamp_ntz_val, + '2024-01-01 12:34:56 +00:00'::TIMESTAMP_TZ AS timestamp_tz_val + +-- databricks sql: +SELECT + CAST('12:34:56' AS TIMESTAMP) AS time_val, + CAST('2024-01-01 12:34:56' AS TIMESTAMP) AS timestamp_val, + CAST('2024-01-01 12:34:56 +00:00' AS TIMESTAMP) AS timestamp_ltz_val, + CAST('2024-01-01 12:34:56' AS TIMESTAMP_NTZ) AS timestamp_ntz_val, + CAST('2024-01-01 12:34:56 +00:00' AS TIMESTAMP) AS timestamp_tz_val; diff --git a/tests/resources/functional/snowflake/cast/test_colon_cast.sql b/tests/resources/functional/snowflake/cast/test_colon_cast.sql new file mode 100644 index 0000000000..2353248e73 --- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_colon_cast.sql @@ -0,0 +1,7 @@ +-- snowflake sql: +SELECT + ARRAY_REMOVE([2, 3, 4::DOUBLE, 4, NULL], 4) + +-- databricks sql: +SELECT + ARRAY_REMOVE(ARRAY(2, 3, CAST(4 AS DOUBLE), 4, NULL), 4); diff --git a/tests/resources/functional/snowflake/cast/test_typecasts.sql b/tests/resources/functional/snowflake/cast/test_typecasts.sql new file mode 100644 index 0000000000..46b513caff --- /dev/null +++ b/tests/resources/functional/snowflake/cast/test_typecasts.sql @@ -0,0 +1,13 @@ +-- snowflake sql: +SELECT + PARSE_JSON('[1,2,3]')::ARRAY(INTEGER) AS array_val, + 'deadbeef'::BINARY AS binary_val, + 'true'::BOOLEAN AS boolean_val, + 'deadbeef'::VARBINARY AS varbinary_val + +-- databricks sql: +SELECT + FROM_JSON('[1,2,3]', 'ARRAY') AS array_val, + CAST('deadbeef' AS BINARY) AS binary_val, + CAST('true' AS BOOLEAN) AS boolean_val, + CAST('deadbeef' AS BINARY) AS varbinary_val; diff --git a/tests/resources/functional/snowflake/core_engine/aggregates/last_value_1.sql 
b/tests/resources/functional/snowflake/core_engine/aggregates/last_value_1.sql new file mode 100644 index 0000000000..2c7525caa2 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/aggregates/last_value_1.sql @@ -0,0 +1,20 @@ +-- snowflake sql: +SELECT + last_value(col1) over ( + partition by col1 + order by + col2 + ) AS last_value_col1 +FROM + tabl; + +-- databricks sql: +SELECT + LAST(col1) OVER ( + PARTITION BY col1 + ORDER BY + col2 ASC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING + AND UNBOUNDED FOLLOWING + ) AS last_value_col1 +FROM + tabl; diff --git a/tests/resources/functional/snowflake/core_engine/aggregates/last_value_2.sql b/tests/resources/functional/snowflake/core_engine/aggregates/last_value_2.sql new file mode 100644 index 0000000000..03192fdfea --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/aggregates/last_value_2.sql @@ -0,0 +1,33 @@ +-- snowflake sql: +SELECT + taba.col_a, + taba.col_b, + last_value( + CASE + WHEN taba.col_c IN ('xyz', 'abc') THEN taba.col_d + END + ) ignore nulls OVER ( + partition BY taba.col_e + ORDER BY + taba.col_f DESC RANGE BETWEEN UNBOUNDED PRECEDING + AND CURRENT ROW + ) AS derived_col_a +FROM + schema_a.table_a taba; + +-- databricks sql: +SELECT + taba.col_a, + taba.col_b, + LAST( + CASE + WHEN taba.col_c IN ('xyz', 'abc') THEN taba.col_d + END + ) IGNORE NULLS OVER ( + PARTITION BY taba.col_e + ORDER BY + taba.col_f DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING + AND CURRENT ROW + ) AS derived_col_a +FROM + schema_a.table_a AS taba; diff --git a/tests/resources/functional/snowflake/core_engine/aggregates/last_value_3.sql b/tests/resources/functional/snowflake/core_engine/aggregates/last_value_3.sql new file mode 100644 index 0000000000..7cef7b85b1 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/aggregates/last_value_3.sql @@ -0,0 +1,11 @@ +-- snowflake sql: +SELECT + last_value(col1) AS last_value_col1 +FROM + tabl; + +-- databricks sql: +SELECT + LAST(col1) AS last_value_col1 +FROM + tabl; diff --git a/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_1.sql b/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_1.sql new file mode 100644 index 0000000000..d48fda055c --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_1.sql @@ -0,0 +1,11 @@ +-- snowflake sql: +SELECT + first_value(col1) AS first_value_col1 +FROM + tabl; + +-- databricks sql: +SELECT + FIRST(col1) AS first_value_col1 +FROM + tabl; diff --git a/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_2.sql b/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_2.sql new file mode 100644 index 0000000000..8bb3973706 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_2.sql @@ -0,0 +1,20 @@ +-- snowflake sql: +SELECT + first_value(col1) over ( + partition by col1 + order by + col2 + ) AS first_value_col1 +FROM + tabl; + +-- databricks sql: +SELECT + FIRST(col1) OVER ( + PARTITION BY col1 + ORDER BY + col2 ASC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING + AND UNBOUNDED FOLLOWING + ) AS first_value_col1 +FROM + tabl; diff --git a/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_3.sql b/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_3.sql new file mode 100644 index 0000000000..20febd600d --- /dev/null +++ 
b/tests/resources/functional/snowflake/core_engine/aggregates/test_first_value_3.sql @@ -0,0 +1,35 @@ +-- snowflake sql: +SELECT + tabb.col_a, + tabb.col_b, + first_value( + CASE + WHEN tabb.col_c IN ('xyz', 'abc') THEN tabb.col_d + END + ) ignore nulls OVER ( + partition BY tabb.col_e + ORDER BY + tabb.col_f DESC RANGE BETWEEN UNBOUNDED PRECEDING + AND CURRENT ROW + ) AS derived_col_a +FROM + schema_a.table_a taba + LEFT JOIN schema_b.table_b AS tabb ON taba.col_e = tabb.col_e; + +-- databricks sql: +SELECT + tabb.col_a, + tabb.col_b, + FIRST( + CASE + WHEN tabb.col_c IN ('xyz', 'abc') THEN tabb.col_d + END + ) IGNORE NULLS OVER ( + PARTITION BY tabb.col_e + ORDER BY + tabb.col_f DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING + AND CURRENT ROW + ) AS derived_col_a +FROM + schema_a.table_a AS taba + LEFT JOIN schema_b.table_b AS tabb ON taba.col_e = tabb.col_e; diff --git a/tests/resources/functional/snowflake/core_engine/functions/conversion/test_to_time_1.sql b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_to_time_1.sql new file mode 100644 index 0000000000..483fe5b9d1 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_to_time_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_TIME('2018-05-15', 'yyyy-MM-dd'); + +-- databricks sql: +SELECT DATE_FORMAT(TO_TIMESTAMP('2018-05-15', 'yyyy-MM-dd'), 'HH:mm:ss'); diff --git a/tests/resources/functional/snowflake/core_engine/functions/conversion/test_to_time_2.sql b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_to_time_2.sql new file mode 100644 index 0000000000..cde2d03870 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_to_time_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_TIME('2018-05-15 00:01:02'); + +-- databricks sql: +SELECT DATE_FORMAT(TO_TIMESTAMP('2018-05-15 00:01:02'), 'HH:mm:ss'); diff --git a/tests/resources/functional/snowflake/core_engine/functions/conversion/test_try_to_date_1.sql b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_try_to_date_1.sql new file mode 100644 index 0000000000..34f87bfc54 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_try_to_date_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_DATE('2018-05-15'); + +-- databricks sql: +SELECT DATE(TRY_TO_TIMESTAMP('2018-05-15')); diff --git a/tests/resources/functional/snowflake/core_engine/functions/conversion/test_try_to_date_3.sql b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_try_to_date_3.sql new file mode 100644 index 0000000000..bc89a4ba05 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/conversion/test_try_to_date_3.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +SELECT TRY_TO_DATE('2012.20.12', 'yyyy.dd.MM'), TRY_TO_DATE(d.col1) FROM dummy d; + +-- databricks sql: +SELECT DATE(TRY_TO_TIMESTAMP('2012.20.12', 'yyyy.dd.MM')), + DATE(TRY_TO_TIMESTAMP(d.col1)) FROM dummy AS d; diff --git a/tests/resources/functional/snowflake/core_engine/functions/conversion/to_timestamp_1.sql b/tests/resources/functional/snowflake/core_engine/functions/conversion/to_timestamp_1.sql new file mode 100644 index 0000000000..c669c00b11 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/conversion/to_timestamp_1.sql @@ -0,0 +1,30 @@ + +-- snowflake sql: +SELECT to_timestamp(col1) AS to_timestamp_col1 FROM tabl; + +-- databricks sql: +SELECT + CASE + TYPEOF(col1) + 
WHEN 'string' THEN IFNULL( + COALESCE( + TRY_TO_TIMESTAMP(TRY_CAST(col1 AS INT)), + TRY_TO_TIMESTAMP(col1, 'yyyy-MM-dd\'T\'HH:mmXXX'), + TRY_TO_TIMESTAMP(col1, 'yyyy-MM-dd HH:mmXXX'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy HH:mm:ss ZZZ'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy HH:mm:ss.SSSSSSSSS ZZZ'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy hh:mm:ss a ZZZ'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy hh:mm:ss.SSSSSSSSS a ZZZ'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy HH:mm:ss'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy HH:mm:ss.SSSSSSSSS'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy hh:mm:ss a'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ', dd MMM yyyy hh:mm:ss.SSSSSSSSS a'), + TRY_TO_TIMESTAMP(col1, 'M/dd/yyyy HH:mm:ss'), + TRY_TO_TIMESTAMP(SUBSTR(col1, 4), ' MMM dd HH:mm:ss ZZZ yyyy') + ), + TO_TIMESTAMP(col1) + ) + ELSE CAST(col1 AS TIMESTAMP) + END AS to_timestamp_col1 +FROM + tabl; diff --git a/tests/resources/functional/snowflake/core_engine/functions/conversion/to_timestamp_variable_format.sql b/tests/resources/functional/snowflake/core_engine/functions/conversion/to_timestamp_variable_format.sql new file mode 100644 index 0000000000..6c3bc0806a --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/conversion/to_timestamp_variable_format.sql @@ -0,0 +1,179 @@ +-- snowflake sql: + SELECT TO_TIMESTAMP(str, fmt) FROM (VALUES ('2024-11-20T18:05:59.123456789', 'YYYY-MM-DD"T"HH24:MI:SS.FF'), ('Thu, 21 Dec 2000 04:01:07 PM +0200', 'DY, DD MON YYYY HH12:MI:SS AM TZHTZM') ) AS vals(str, fmt); + +-- databricks sql: +SELECT IF( + STARTSWITH(fmt, 'DY'), + TO_TIMESTAMP(SUBSTR(str, 4), SUBSTR(REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(fmt, 'YYYY', 'yyyy'), 'YY', 'yy'), 'MON', 'MMM'), 'DD', 'dd'), 'DY', 'EEE'), + 'HH24', + 'HH' + ), + 'HH12', + 'hh' + ), + 'AM', + 'a' + ), + 'PM', + 'a' + ), + 'MI', + 'mm' + ), + 'SS', + 'ss' + ), + 'FF9', + 'SSSSSSSSS' + ), + 'FF8', + 'SSSSSSSS' + ), + 'FF7', + 'SSSSSSS' + ), + 'FF6', + 'SSSSSS' + ), + 'FF5', + 'SSSSS' + ), + 'FF4', + 'SSSS' + ), + 'FF3', + 'SSS' + ), + 'FF2', + 'SS' + ), + 'FF1', + 'S' + ), + 'FF0', + '' + ), + 'FF', + 'SSSSSSSSS' + ), + 'TZH:TZM', + 'ZZZ' + ), + 'TZHTZM', + 'ZZZ' + ), + 'TZH', + 'ZZZ' + ), + 'UUUU', + 'yyyy' + ), '"', '\''), 4)), + TO_TIMESTAMP(str, REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE( + REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(fmt, 'YYYY', 'yyyy'), 'YY', 'yy'), 'MON', 'MMM'), 'DD', 'dd'), 'DY', 'EEE'), + 'HH24', + 'HH' + ), + 'HH12', + 'hh' + ), + 'AM', + 'a' + ), + 'PM', + 'a' + ), + 'MI', + 'mm' + ), + 'SS', + 'ss' + ), + 'FF9', + 'SSSSSSSSS' + ), + 'FF8', + 'SSSSSSSS' + ), + 'FF7', + 'SSSSSSS' + ), + 'FF6', + 'SSSSSS' + ), + 'FF5', + 'SSSSS' + ), + 'FF4', + 'SSSS' + ), + 'FF3', + 'SSS' + ), + 'FF2', + 'SS' + ), + 'FF1', + 'S' + ), + 'FF0', + '' + ), + 'FF', + 'SSSSSSSSS' + ), + 'TZH:TZM', + 'ZZZ' + ), + 'TZHTZM', + 'ZZZ' + ), + 'TZH', + 'ZZZ' + ), + 'UUUU', + 'yyyy' + ), '"', '\'')) +) FROM VALUES ('2024-11-20T18:05:59.123456789', 'YYYY-MM-DD"T"HH24:MI:SS.FF'), ('Thu, 21 Dec 2000 04:01:07 PM +0200', 
'DY, DD MON YYYY HH12:MI:SS AM TZHTZM') AS vals(str, fmt); diff --git a/tests/resources/functional/snowflake/core_engine/functions/dates/test_date_trunc_2.sql b/tests/resources/functional/snowflake/core_engine/functions/dates/test_date_trunc_2.sql new file mode 100644 index 0000000000..b5de0c552a --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/dates/test_date_trunc_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +DELETE FROM table1 WHERE cre_at >= DATE_TRUNC('month', TRY_TO_DATE('2022-01-15')); + +-- databricks sql: +DELETE FROM table1 WHERE cre_at >= DATE_TRUNC('MONTH', DATE(TRY_TO_TIMESTAMP('2022-01-15'))); diff --git a/tests/resources/functional/snowflake/core_engine/functions/dates/test_date_trunc_3.sql b/tests/resources/functional/snowflake/core_engine/functions/dates/test_date_trunc_3.sql new file mode 100644 index 0000000000..4a84dd8c8e --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/dates/test_date_trunc_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select DATE_TRUNC('month', TRY_TO_DATE(COLUMN1)) from table; + +-- databricks sql: +SELECT DATE_TRUNC('MONTH', DATE(TRY_TO_TIMESTAMP(COLUMN1))) FROM table; diff --git a/tests/resources/functional/snowflake/core_engine/functions/dates/test_dayname_1.sql b/tests/resources/functional/snowflake/core_engine/functions/dates/test_dayname_1.sql new file mode 100644 index 0000000000..91cdf05f10 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/dates/test_dayname_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT DAYNAME(TO_TIMESTAMP('2015-04-03 10:00:00')) AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT(TO_TIMESTAMP('2015-04-03 10:00:00'), 'E') AS MONTH; diff --git a/tests/resources/functional/snowflake/core_engine/functions/dates/test_monthname_1.sql b/tests/resources/functional/snowflake/core_engine/functions/dates/test_monthname_1.sql new file mode 100644 index 0000000000..b7a529d354 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/dates/test_monthname_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT MONTHNAME(TO_TIMESTAMP('2015-04-03 10:00:00')) AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT(TO_TIMESTAMP('2015-04-03 10:00:00'), 'MMM') AS MONTH; diff --git a/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_3.sql b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_3.sql new file mode 100644 index 0000000000..7346ba5777 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_3.sql @@ -0,0 +1,7 @@ +-- snowflake sql: + +SELECT REGEXP_SUBSTR('The real world of The Doors', 'The\\W+\\w+', 2); + +-- databricks sql: + +SELECT REGEXP_EXTRACT(SUBSTR('The real world of The Doors', 2), 'The\\W+\\w+', 0); diff --git a/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_4.sql b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_4.sql new file mode 100644 index 0000000000..5d9c432eb3 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_4.sql @@ -0,0 +1,7 @@ +-- snowflake sql: + +SELECT REGEXP_SUBSTR('The real world of The Doors', 'The\\W+\\w+', 1, 2); + +-- databricks sql: + +SELECT REGEXP_EXTRACT_ALL(SUBSTR('The real world of The Doors', 1), 'The\\W+\\w+', 0)[1]; diff --git a/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_5.sql 
b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_5.sql new file mode 100644 index 0000000000..66a8881027 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_5.sql @@ -0,0 +1,7 @@ +-- snowflake sql: + +SELECT REGEXP_SUBSTR('The real world of The Doors', 'the\\W+\\w+', 1, 2, 'i'); + +-- databricks sql: + +SELECT REGEXP_EXTRACT_ALL(SUBSTR('The real world of The Doors', 1), '(?i)the\\W+\\w+', 0)[1]; diff --git a/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_6.sql b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_6.sql new file mode 100644 index 0000000000..3e2d2b4bb2 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_6.sql @@ -0,0 +1,22 @@ +-- snowflake sql: +WITH + params(p) AS (SELECT 'i') +SELECT REGEXP_SUBSTR('The real world of The Doors', 'the\\W+\\w+', 1, 2, p) FROM params; + +-- databricks sql: +WITH + params (p) AS (SELECT 'i') +SELECT REGEXP_EXTRACT_ALL( + SUBSTR('The real world of The Doors', 1), + AGGREGATE( + SPLIT(p, ''), + CAST(ARRAY() AS ARRAY), + (agg, item) -> + CASE + WHEN item = 'c' THEN FILTER(agg, item -> item != 'i') + WHEN item IN ('i', 's', 'm') THEN ARRAY_APPEND(agg, item) + ELSE agg + END + , + filtered -> '(?' || ARRAY_JOIN(filtered, '') || ')' || 'the\\W+\\w+' + ), 0)[1] FROM params; diff --git a/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_7.sql b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_7.sql new file mode 100644 index 0000000000..754f05c203 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/functions/strings/regexp_substr_7.sql @@ -0,0 +1,7 @@ +-- snowflake sql: + +SELECT REGEXP_SUBSTR('The real world of The Doors', 'the\\W+(\\w+)', 1, 2, 'i', 1); + +-- databricks sql: + +SELECT REGEXP_EXTRACT_ALL(SUBSTR('The real world of The Doors', 1), '(?i)the\\W+(\\w+)', 1)[1]; diff --git a/tests/resources/functional/snowflake/core_engine/lca/lca_function.sql b/tests/resources/functional/snowflake/core_engine/lca/lca_function.sql new file mode 100644 index 0000000000..a58d5e1d1e --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/lca/lca_function.sql @@ -0,0 +1,12 @@ +-- snowflake sql: +SELECT + col1 AS ca + FROM table1 + WHERE substr(ca,1,5) = '12345' + ; + +-- databricks sql: +SELECT + col1 AS ca + FROM table1 + WHERE SUBSTR(col1, 1,5) = '12345'; diff --git a/tests/resources/functional/snowflake/core_engine/lca/lca_homonym.sql.saved b/tests/resources/functional/snowflake/core_engine/lca/lca_homonym.sql.saved new file mode 100644 index 0000000000..8bc234a1cd --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/lca/lca_homonym.sql.saved @@ -0,0 +1,28 @@ +-- snowflake sql: +select + ca_zip +from ( + SELECT + substr(ca_zip,1,5) ca_zip, + trim(name) as name, + -- ca_zip should not be transpiled + count(*) over( partition by ca_zip) + FROM + customer_address + WHERE + -- ca_zip should not be transpiled + ca_zip IN ('89436', '30868')); +-- databricks sql: +SELECT + ca_zip +FROM +SELECT + SUBSTR(ca_zip,1,5) AS ca_zip, + TRIM(name) AS name, + COUNT(*) OVER ( + PARTITION BY ca_zip + ) +FROM + customer_address +WHERE + ca_zip IN ('89436', '30868'); diff --git a/tests/resources/functional/snowflake/core_engine/sample.sql b/tests/resources/functional/snowflake/core_engine/sample.sql new file mode 100644 index 0000000000..c9c9dc5f26 --- /dev/null +++ 
b/tests/resources/functional/snowflake/core_engine/sample.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +select * from table_name; + +-- databricks sql: +select * from table_name; diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/except.sql b/tests/resources/functional/snowflake/core_engine/set-operations/except.sql new file mode 100644 index 0000000000..bd6b08dcd9 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/except.sql @@ -0,0 +1,14 @@ +-- ## ... EXCEPT ... +-- +-- Verify simple EXCEPT handling. +-- +-- snowflake sql: + +SELECT 1 +EXCEPT +SELECT 2; + +-- databricks sql: +(SELECT 1) +EXCEPT +(SELECT 2); diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/intersect.sql b/tests/resources/functional/snowflake/core_engine/set-operations/intersect.sql new file mode 100644 index 0000000000..88f91e71ac --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/intersect.sql @@ -0,0 +1,14 @@ +-- ## ... INTERSECT ... +-- +-- Verify simple INTERSECT handling. +-- +-- snowflake sql: + +SELECT 1 +INTERSECT +SELECT 2; + +-- databricks sql: +(SELECT 1) +INTERSECT +(SELECT 2); diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/minus.sql b/tests/resources/functional/snowflake/core_engine/set-operations/minus.sql new file mode 100644 index 0000000000..1b9dc4d608 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/minus.sql @@ -0,0 +1,14 @@ +-- ## ... MINUS ... +-- +-- Verify simple MINUS handling: it is an alias for EXCEPT. +-- +-- snowflake sql: + +SELECT 1 +MINUS +SELECT 2; + +-- databricks sql: +(SELECT 1) +EXCEPT +(SELECT 2); diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/precedence.sql b/tests/resources/functional/snowflake/core_engine/set-operations/precedence.sql new file mode 100644 index 0000000000..f42df26df4 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/precedence.sql @@ -0,0 +1,119 @@ +-- +-- Verify the precedence rules are being correctly handled. Order of evaluation when chaining is: +-- 1. Brackets. +-- 2. INTERSECT +-- 3. UNION and EXCEPT, evaluated left to right. +-- + +-- snowflake sql: + +-- Verifies UNION/EXCEPT/MINUS as left-to-right (1/3), with brackets. +(SELECT 1 + UNION + SELECT 2 + EXCEPT + SELECT 3 + MINUS + (SELECT 4 + UNION + SELECT 5)) + +UNION ALL + +-- Verifies UNION/EXCEPT/MINUS as left-to-right (2/3) when the order is rotated from the previous. +(SELECT 6 + EXCEPT + SELECT 7 + MINUS + SELECT 8 + UNION + SELECT 9) + +UNION ALL + +-- Verifies UNION/EXCEPT/MINUS as left-to-right (3/3) when the order is rotated from the previous. +(SELECT 10 + MINUS + SELECT 11 + UNION + SELECT 12 + EXCEPT + SELECT 13) + +UNION ALL + +-- Verifies that INTERSECT has precedence over UNION/EXCEPT/MINUS. +(SELECT 14 + UNION + SELECT 15 + EXCEPT + SELECT 16 + MINUS + SELECT 17 + INTERSECT + SELECT 18) + +UNION ALL + +-- Verifies that INTERSECT is left-to-right, although brackets have precedence. 
+(SELECT 19 + INTERSECT + SELECT 20 + INTERSECT + (SELECT 21 + INTERSECT + SELECT 22)); + +-- databricks sql: + + ( + ( + ( + ( + ( + ((SELECT 1) UNION (SELECT 2)) + EXCEPT + (SELECT 3) + ) + EXCEPT + ((SELECT 4) UNION (SELECT 5)) + ) + UNION ALL + ( + ( + ((SELECT 6) EXCEPT (SELECT 7)) + EXCEPT + (SELECT 8) + ) + UNION + (SELECT 9) + ) + ) + UNION ALL + ( + ( + ((SELECT 10) EXCEPT (SELECT 11)) + UNION + (SELECT 12) + ) + EXCEPT + (SELECT 13) + ) + ) + UNION ALL + ( + ( + ((SELECT 14) UNION (SELECT 15)) + EXCEPT + (SELECT 16) + ) + EXCEPT + ((SELECT 17) INTERSECT (SELECT 18)) + ) + ) +UNION ALL + ( + ((SELECT 19) INTERSECT (SELECT 20)) + INTERSECT + ((SELECT 21) INTERSECT (SELECT 22)) + ); diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/union-all.sql b/tests/resources/functional/snowflake/core_engine/set-operations/union-all.sql new file mode 100644 index 0000000000..9cc59fbdf5 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/union-all.sql @@ -0,0 +1,14 @@ +-- ## ... UNION ALL ... +-- +-- Verify simple UNION ALL handling. +-- +-- snowflake sql: + +SELECT 1 +UNION ALL +SELECT 2; + +-- databricks sql: +(SELECT 1) +UNION ALL +(SELECT 2); diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/union.sql b/tests/resources/functional/snowflake/core_engine/set-operations/union.sql new file mode 100644 index 0000000000..28ded38b72 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/union.sql @@ -0,0 +1,14 @@ +-- ## ... UNION ... +-- +-- Verify simple UNION handling. +-- +-- snowflake sql: + +SELECT 1 +UNION +SELECT 2; + +-- databricks sql: +(SELECT 1) +UNION +(SELECT 2); diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/union_all_left_grouped.sql b/tests/resources/functional/snowflake/core_engine/set-operations/union_all_left_grouped.sql new file mode 100644 index 0000000000..93c462e0b9 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/union_all_left_grouped.sql @@ -0,0 +1,10 @@ +-- ## (SELECT …) UNION ALL SELECT … +-- +-- Verify UNION handling when the LHS of the union is explicitly wrapped in parentheses. +-- +-- snowflake sql: + +(SELECT a, b from c) UNION ALL SELECT x, y from z; + +-- databricks sql: +(SELECT a, b FROM c) UNION ALL (SELECT x, y FROM z); diff --git a/tests/resources/functional/snowflake/core_engine/set-operations/union_left_grouped.sql b/tests/resources/functional/snowflake/core_engine/set-operations/union_left_grouped.sql new file mode 100644 index 0000000000..03d3cff80d --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/set-operations/union_left_grouped.sql @@ -0,0 +1,10 @@ +-- ## (SELECT …) UNION SELECT … +-- +-- Verify UNION handling when the LHS of the union is explicitly wrapped in parentheses. 
+-- +-- snowflake sql: + +(SELECT a, b from c) UNION SELECT x, y from z; + +-- databricks sql: +(SELECT a, b FROM c) UNION (SELECT x, y FROM z); diff --git a/tests/resources/functional/snowflake/core_engine/test_command/test_command_1.sql b/tests/resources/functional/snowflake/core_engine/test_command/test_command_1.sql new file mode 100644 index 0000000000..9445b516fe --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_command/test_command_1.sql @@ -0,0 +1,10 @@ + +-- snowflake sql: +!set exit_on_error = true; + +-- databricks sql: +/* The following issues were detected: + + Unknown command in SnowflakeAstBuilder.visitSnowSqlCommand + !set exit_on_error = true; + */ diff --git a/tests/resources/functional/snowflake/core_engine/test_cte/cte_set_operation_precedence.sql b/tests/resources/functional/snowflake/core_engine/test_cte/cte_set_operation_precedence.sql new file mode 100644 index 0000000000..457c85f7bd --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_cte/cte_set_operation_precedence.sql @@ -0,0 +1,17 @@ +-- +-- CTEs are visible to all the SELECT queries within a subsequent sequence of set operations. +-- + +-- snowflake sql: +WITH a AS (SELECT 1, 2, 3) + +SELECT 4, 5, 6 +UNION +SELECT * FROM a; + +-- databricks sql: +WITH a AS (SELECT 1, 2, 3) + +(SELECT 4, 5, 6) +UNION +(SELECT * FROM a); diff --git a/tests/resources/functional/snowflake/core_engine/test_cte/cte_simple.sql b/tests/resources/functional/snowflake/core_engine/test_cte/cte_simple.sql new file mode 100644 index 0000000000..bb0ed6c61b --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_cte/cte_simple.sql @@ -0,0 +1,16 @@ +-- snowflake sql: +WITH employee_hierarchy AS ( + SELECT + employee_id, + manager_id, + employee_name + FROM + employees + WHERE + manager_id IS NULL +) +SELECT * +FROM employee_hierarchy; + +-- databricks sql: +WITH employee_hierarchy AS (SELECT employee_id, manager_id, employee_name FROM employees WHERE manager_id IS NULL) SELECT * FROM employee_hierarchy; diff --git a/tests/resources/functional/snowflake/core_engine/test_cte/multiple_cte.sql b/tests/resources/functional/snowflake/core_engine/test_cte/multiple_cte.sql new file mode 100644 index 0000000000..b330b86a87 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_cte/multiple_cte.sql @@ -0,0 +1,19 @@ +-- +-- Verify a few CTEs that include multiple expressions. +-- + +-- snowflake sql: +WITH a AS (SELECT 1, 2, 3), + b AS (SELECT 4, 5, 6), + c AS (SELECT * FROM a) +SELECT * from b +UNION +SELECT * FROM c; + +-- databricks sql: +WITH a AS (SELECT 1, 2, 3), + b AS (SELECT 4, 5, 6), + c AS (SELECT * FROM a) +(SELECT * from b) +UNION +(SELECT * FROM c); diff --git a/tests/resources/functional/snowflake/core_engine/test_cte/nested_set_operation.sql b/tests/resources/functional/snowflake/core_engine/test_cte/nested_set_operation.sql new file mode 100644 index 0000000000..43e5d39189 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_cte/nested_set_operation.sql @@ -0,0 +1,19 @@ +-- +-- Verify a CTE that includes set operations. 
+-- + +-- snowflake sql: +WITH a AS ( + SELECT 1, 2, 3 + UNION + SELECT 4, 5, 6 +) +SELECT * FROM a; + +-- databricks sql: +WITH a AS ( + (SELECT 1, 2, 3) + UNION + (SELECT 4, 5, 6) +) +SELECT * FROM a; diff --git a/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_1.sql b/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_1.sql new file mode 100644 index 0000000000..253aac2f17 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_1.sql @@ -0,0 +1,30 @@ +-- Note that here we have two commas in the select clause. Although in other circumstances the +-- parser could notice that there is an additional comma, in this case it cannot do so because +-- what can appear between the commas is just about anything. Because any ID is accepted as +-- possibly being some kind of command, the parser has to assume that the following tokens +-- form some valid command. +-- Hence the error reported is a no viable alternative at input ',', and the parser recovers to something +-- that looks like a valid command because of the let rule, where LET is optional and the next token +-- is an ID, which is therefore predicted, and we accumulate a lot of erroneous errors. + +-- snowflake sql: +select col1,, col2 from table_name; + +-- databricks sql: +/* The following issues were detected: + + Unparsed input - ErrorNode encountered + Unparsable text: select col1,, + */ +/* The following issues were detected: + + Unparsed input - ErrorNode encountered + Unparsable text: select + Unparsable text: col1 + Unparsable text: , + Unparsable text: , + Unparsable text: col2 + Unparsable text: from + Unparsable text: table_name + Unparsable text: parser recovered by ignoring: select col1,, col2 from table_name; + */ diff --git a/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_2.sql b/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_2.sql new file mode 100644 index 0000000000..8acd8c794b --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_2.sql @@ -0,0 +1,11 @@ +-- snowflake sql: +* + +-- databricks sql: +/* The following issues were detected: + + Unparsed input - ErrorNode encountered + Unparsable text: unexpected extra input '*' while parsing a Snowflake batch + expecting one of: End of batch, Select Statement, Statement, '(', ';', 'CALL', 'COMMENT', 'DECLARE', 'GET', 'LET', 'START', 'WITH'... + Unparsable text: * + */ diff --git a/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_3.sql b/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_3.sql new file mode 100644 index 0000000000..30a2c18378 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_invalid_syntax/syntax_error_3.sql @@ -0,0 +1,19 @@ +-- snowflake sql: +* ; +SELECT 1 ; +SELECT A B FROM C ; + +-- databricks sql: +/* The following issues were detected: + + Unparsed input - ErrorNode encountered + Unparsable text: unexpected extra input '*' while parsing a Snowflake batch + expecting one of: End of batch, Select Statement, Statement, '(', ';', 'CALL', 'COMMENT', 'DECLARE', 'GET', 'LET', 'START', 'WITH'...
+ Unparsable text: * + */ +SELECT + 1; +SELECT + A AS B +FROM + C; diff --git a/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_1.sql b/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_1.sql new file mode 100644 index 0000000000..c87ed298fd --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_1.sql @@ -0,0 +1,10 @@ + +-- snowflake sql: +ALTER SESSION SET QUERY_TAG = 'tag1'; + +-- databricks sql: +/* The following issues were detected: + + Unknown ALTER command variant + ALTER SESSION SET QUERY_TAG = 'tag1' + */ diff --git a/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_5.sql b/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_5.sql new file mode 100644 index 0000000000..1e8f6ea25c --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_5.sql @@ -0,0 +1,10 @@ + +-- snowflake sql: +CREATE STREAM mystream ON TABLE mytable; + +-- databricks sql: +/* The following issues were detected: + + CREATE STREAM UNSUPPORTED + CREATE STREAM mystream ON TABLE mytable + */ diff --git a/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_6.sql b/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_6.sql new file mode 100644 index 0000000000..3470cc6ae9 --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_skip_unsupported_operations/test_skip_unsupported_operations_6.sql @@ -0,0 +1,10 @@ + +-- snowflake sql: +ALTER STREAM mystream SET COMMENT = 'New comment for stream'; + +-- databricks sql: +/* The following issues were detected: + + Unknown ALTER command variant + ALTER STREAM mystream SET COMMENT = 'New comment for stream' + */ diff --git a/tests/resources/functional/snowflake/cte/cte_with_column_list.sql b/tests/resources/functional/snowflake/cte/cte_with_column_list.sql new file mode 100644 index 0000000000..a984432f14 --- /dev/null +++ b/tests/resources/functional/snowflake/cte/cte_with_column_list.sql @@ -0,0 +1,11 @@ +-- +-- A simple CTE, with the column list expressed. +-- + +-- snowflake sql: +WITH a (b, c, d) AS (SELECT 1 AS b, 2 AS c, 3 AS d) +SELECT b, c, d FROM a; + +-- databricks sql: +WITH a (b, c, d) AS (SELECT 1 AS b, 2 AS c, 3 AS d) +SELECT b, c, d FROM a; diff --git a/tests/resources/functional/snowflake/cte/simple_cte.sql b/tests/resources/functional/snowflake/cte/simple_cte.sql new file mode 100644 index 0000000000..cf15cca593 --- /dev/null +++ b/tests/resources/functional/snowflake/cte/simple_cte.sql @@ -0,0 +1,11 @@ +-- +-- Verify a simple CTE. 
+-- + +-- snowflake sql: +WITH a AS (SELECT 1, 2, 3) +SELECT * FROM a; + +-- databricks sql: +WITH a AS (SELECT 1, 2, 3) +SELECT * FROM a; diff --git a/tests/resources/functional/snowflake/ddl/alter/test_alter_1.sql b/tests/resources/functional/snowflake/ddl/alter/test_alter_1.sql new file mode 100644 index 0000000000..842b198836 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/alter/test_alter_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +ALTER TABLE employees ADD COLUMN first_name VARCHAR(50); + +-- databricks sql: +ALTER TABLE employees ADD COLUMN first_name STRING; diff --git a/tests/resources/functional/snowflake/ddl/alter/test_alter_2.sql b/tests/resources/functional/snowflake/ddl/alter/test_alter_2.sql new file mode 100644 index 0000000000..554a58a46d --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/alter/test_alter_2.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +ALTER TABLE employees ADD COLUMN first_name VARCHAR(50) NOT NULL, age INT, hire_date DATE; + +-- databricks sql: +ALTER TABLE employees ADD COLUMN first_name STRING NOT NULL, age DECIMAL(38, 0), hire_date DATE; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_1.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_1.sql new file mode 100644 index 0000000000..31f27b7f93 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_1.sql @@ -0,0 +1,17 @@ +-- snowflake sql: +SELECT + p.info:id AS "ID", + p.info:first AS "First", + p.info:first.b AS C +FROM + (SELECT PARSE_JSON('{"id": {"a":{"c":"102","d":"106"}}, "first": {"b":"105"}}')) AS p(info); + +-- databricks sql: +SELECT + p.info:id AS `ID`, + p.info:first AS `First`, + p.info:first.b AS C +FROM ( + SELECT + PARSE_JSON('{"id": {"a":{"c":"102","d":"106"}}, "first": {"b":"105"}}') +) AS p(info); diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_10.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_10.sql new file mode 100644 index 0000000000..19dd4ebbe6 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_10.sql @@ -0,0 +1,19 @@ +-- snowflake sql: +SELECT + tt.id, + lit.value: details AS details +FROM VALUES + (1, '{"order": {"id": 101,"items": [{"item_id": "A1","quantity": 2,"details": {"color": "red"}},{"item_id": "B2","quantity": 5,"details": {"color": "blue"}}]}}'), + (2, '{"order": {"id": 202,"items": [{"item_id": "C3","quantity": 4,"details": {"color": "green", "size": "L"}},{"item_id": "D4","quantity": 3,"details": {"color": "yellow", "size": "M"}}]}}') +AS tt(id, resp) +, LATERAL FLATTEN(input => PARSE_JSON(tt.resp):order.items) AS lit + +-- databricks sql: +SELECT +tt.id, +lit.value:details AS details +FROM VALUES + (1, '{"order": {"id": 101,"items": [{"item_id": "A1","quantity": 2,"details": {"color": "red"}},{"item_id": "B2","quantity": 5,"details": {"color": "blue"}}]}}'), + (2, '{"order": {"id": 202,"items": [{"item_id": "C3","quantity": 4,"details": {"color": "green", "size": "L"}},{"item_id": "D4","quantity": 3,"details": {"color": "yellow", "size": "M"}}]}}') +AS tt(id, resp) +, LATERAL VARIANT_EXPLODE(PARSE_JSON(tt.resp):order.items) AS lit; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_11.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_11.sql new file mode 100644 index 0000000000..f56ad0d25b --- /dev/null +++ 
b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_11.sql @@ -0,0 +1,16 @@ +-- snowflake sql: +SELECT + demo.level_key:"level_1_key":"level_2_key"['1'] AS col +FROM + ( + SELECT + PARSE_JSON('{"level_1_key": { "level_2_key": { "1": "desired_value" }}}') AS level_key + ) AS demo; + +-- databricks sql: +SELECT + demo.level_key:level_1_key.level_2_key["1"] AS col +FROM ( + SELECT + PARSE_JSON('{"level_1_key": { "level_2_key": { "1": "desired_value" }}}') AS level_key +) AS demo; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_12.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_12.sql new file mode 100644 index 0000000000..33e16cb8d8 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_12.sql @@ -0,0 +1,19 @@ +-- snowflake sql: +SELECT + verticals.index AS index, + verticals.value AS array_val + FROM + ( + select ARRAY_CONSTRUCT('value1', 'value2', 'value3') as col + ) AS sample_data(array_column), + LATERAL FLATTEN(input => sample_data.array_column, OUTER => true) AS verticals; + +-- databricks sql: +SELECT + verticals.index AS index, + verticals.value AS array_val +FROM ( + SELECT + ARRAY('value1', 'value2', 'value3') AS col +) AS sample_data(array_column) + LATERAL VIEW OUTER POSEXPLODE(sample_data.array_column) verticals AS index, value; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_2.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_2.sql new file mode 100644 index 0000000000..73ea10d71e --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_2.sql @@ -0,0 +1,25 @@ +-- snowflake sql: +SELECT + f.value:name AS "Contact", + f.value:first, + CAST(p.col:a:info:id AS DOUBLE) AS "id_parsed", + p.col:b:first, + p.col:a:info +FROM + (SELECT + PARSE_JSON('{"a": {"info": {"id": 101, "first": "John" }, "contact": [{"name": "Alice", "first": "A"}, {"name": "Bob", "first": "B"}]}, "b": {"id": 101, "first": "John"}}') + ) AS p(col) +, LATERAL FLATTEN(input => p.col:a:contact) AS f; + +-- databricks sql: +SELECT + f.value:name AS `Contact`, + f.value:first, + CAST(p.col:a.info.id AS DOUBLE) AS `id_parsed`, + p.col:b.first, + p.col:a.info +FROM ( + SELECT + PARSE_JSON('{"a": {"info": {"id": 101, "first": "John" }, "contact": [{"name": "Alice", "first": "A"}, {"name": "Bob", "first": "B"}]}, "b": {"id": 101, "first": "John"}}') +) AS p(col) +, LATERAL VARIANT_EXPLODE(p.col:a.contact) AS f; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_3.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_3.sql new file mode 100644 index 0000000000..efb6f94120 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_3.sql @@ -0,0 +1,35 @@ +-- snowflake sql: +SELECT + d.col:display_position::NUMBER AS display_position, + i.value:attributes::VARCHAR AS attributes, + CAST(CURRENT_TIMESTAMP() AS TIMESTAMP_NTZ(9)) AS created_at, + i.value:prop::FLOAT AS prop, + d.col:candidates AS candidates +FROM + ( + SELECT + PARSE_JSON('{"display_position": 123, "impressions": [{"attributes": "some_attributes", "prop": 12.34}, {"attributes": "other_attributes", "prop": 56.78}], "candidates": "some_candidates"}') AS col, + '2024-08-28' AS event_date, + 'store.replacements_view' AS event_name + ) AS d, + LATERAL FLATTEN(input => d.col:impressions, outer => true) AS i +WHERE + 
d.event_date = '2024-08-28' + AND d.event_name IN ('store.replacements_view'); + +-- databricks sql: +SELECT + CAST(d.col:display_position AS DECIMAL(38, 0)) AS display_position, + CAST(i.value:attributes AS STRING) AS attributes, + CAST(CURRENT_TIMESTAMP() AS TIMESTAMP_NTZ) AS created_at, + CAST(i.value:prop AS DOUBLE) AS prop, + d.col:candidates AS candidates +FROM ( + SELECT + PARSE_JSON('{"display_position": 123, "impressions": [{"attributes": "some_attributes", "prop": 12.34}, {"attributes": "other_attributes", "prop": 56.78}], "candidates": "some_candidates"}') AS col, + '2024-08-28' AS event_date, + 'store.replacements_view' AS event_name +) AS d +, LATERAL VARIANT_EXPLODE_OUTER(d.col:impressions) AS i +WHERE + d.event_date = '2024-08-28' AND d.event_name IN ('store.replacements_view'); diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_4.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_4.sql new file mode 100644 index 0000000000..3dcb2623d3 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_4.sql @@ -0,0 +1,23 @@ +-- snowflake sql: +SELECT + tt.col:id AS tax_transaction_id, + CAST(tt.col:responseBody.isMpfState AS BOOLEAN) AS is_mpf_state, + REGEXP_REPLACE(tt.col:requestBody.deliveryLocation.city, '-', '') AS delivery_city, + REGEXP_REPLACE(tt.col:requestBody.store.storeAddress.zipCode, '=', '') AS store_zipcode +FROM ( + SELECT + PARSE_JSON('{"id": 1, "responseBody": { "isMpfState": true }, "requestBody": { "deliveryLocation": { "city": "New-York" }, "store": {"storeAddress": {"zipCode": "100=01"}}}}') + AS col +) AS tt; + +-- databricks sql: +SELECT + tt.col:id AS tax_transaction_id, + CAST(tt.col:responseBody.isMpfState AS BOOLEAN) AS is_mpf_state, + REGEXP_REPLACE(tt.col:requestBody.deliveryLocation.city, '-', '') AS delivery_city, + REGEXP_REPLACE(tt.col:requestBody.store.storeAddress.zipCode, '=', '') AS store_zipcode +FROM ( + SELECT + PARSE_JSON('{"id": 1, "responseBody": { "isMpfState": true }, "requestBody": { "deliveryLocation": { "city": "New-York" }, "store": {"storeAddress": {"zipCode": "100=01"}}}}') + AS col +) AS tt; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_5.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_5.sql new file mode 100644 index 0000000000..0b42edbfe1 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_5.sql @@ -0,0 +1,23 @@ +-- snowflake sql: +SELECT + varchar1, + CAST(float1 AS STRING) AS float1_as_string, + CAST(variant1:Loan_Number AS STRING) AS loan_number_as_string +FROM + (SELECT + 'example_varchar' AS varchar1, + 123.456 AS float1, + PARSE_JSON('{"Loan_Number": "LN789"}') AS variant1 + ) AS tmp; + +-- databricks sql: +SELECT + varchar1, + CAST(float1 AS STRING) AS float1_as_string, + CAST(variant1:Loan_Number AS STRING) AS loan_number_as_string +FROM ( + SELECT + 'example_varchar' AS varchar1, + 123.456 AS float1, + PARSE_JSON('{"Loan_Number": "LN789"}') AS variant1 +) AS tmp; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_6.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_6.sql new file mode 100644 index 0000000000..3b3c42fd2e --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_6.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT ARRAY_EXCEPT([{'a': 1, 'b': 2}, 1], [{'a': 1, 'b': 2}, 3]); + +-- 
databricks sql: +SELECT ARRAY_EXCEPT(ARRAY(STRUCT(1 AS a, 2 AS b), 1), ARRAY(STRUCT(1 AS a, 2 AS b), 3)); diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_7.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_7.sql new file mode 100644 index 0000000000..bc206f75ac --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_7.sql @@ -0,0 +1,18 @@ +-- snowflake sql: +SELECT + v, + v:food AS food, + TO_JSON(v) AS v_as_json +FROM ( + SELECT PARSE_JSON('{"food": "apple"}') AS v +) t; + +-- databricks sql: +SELECT + v, + v:food AS food, + TO_JSON(v) AS v_as_json +FROM ( + SELECT + PARSE_JSON('{"food": "apple"}') AS v +) AS t; diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_8.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_8.sql new file mode 100644 index 0000000000..65404c060b --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_8.sql @@ -0,0 +1,12 @@ +-- snowflake sql: +SELECT PARSE_JSON(src.col):c AS c +FROM VALUES + ('{"a": "1", "b": "2", "c": null}'), + ('{"a": "1", "b": "2", "c": "3"}') AS src(col); + +-- databricks sql: +SELECT + PARSE_JSON(src.col):c AS c +FROM VALUES + ('{"a": "1", "b": "2", "c": null}'), + ('{"a": "1", "b": "2", "c": "3"}') AS src(col); diff --git a/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_9.sql b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_9.sql new file mode 100644 index 0000000000..a7e7bcc8f6 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/lateral_struct/test_lateral_struct_9.sql @@ -0,0 +1,44 @@ +-- snowflake sql: +SELECT + los.value:"objectDomain"::STRING AS object_type, + los.value:"objectName"::STRING AS object_name, + cols.value:"columnName"::STRING AS column_name, + COUNT(DISTINCT lah:"query_token"::STRING) AS n_queries, + COUNT(DISTINCT lah:"consumer_account_locator"::STRING) AS n_distinct_consumer_accounts +FROM + (SELECT + PARSE_JSON('{"query_date": "2022-03-02","query_token": "some_token","consumer_account_locator": "CONSUMER_ACCOUNT_LOCATOR","listing_objects_accessed": [{"objectDomain": "Table","objectName": "DATABASE_NAME.SCHEMA_NAME.TABLE_NAME","columns": [{"columnName": "column1"},{"columnName": "column2"}]}]}') AS lah + ) AS src, + LATERAL FLATTEN(input => src.lah:"listing_objects_accessed") AS los, + LATERAL FLATTEN(input => los.value:"columns") AS cols +WHERE + los.value:"objectDomain"::STRING IN ('Table', 'View') AND + src.lah:"query_date"::DATE BETWEEN '2022-03-01' AND '2022-04-30' AND + los.value:"objectName"::STRING = 'DATABASE_NAME.SCHEMA_NAME.TABLE_NAME' AND + src.lah:"consumer_account_locator"::STRING = 'CONSUMER_ACCOUNT_LOCATOR' +GROUP BY 1, 2, 3; + +-- databricks sql: +SELECT + CAST(los.value:objectDomain AS STRING) AS object_type, + CAST(los.value:objectName AS STRING) AS object_name, + CAST(cols.value:columnName AS STRING) AS column_name, + COUNT(DISTINCT CAST(lah:query_token AS STRING)) AS n_queries, + COUNT(DISTINCT CAST(lah:consumer_account_locator AS STRING)) AS n_distinct_consumer_accounts +FROM ( + SELECT + PARSE_JSON( + '{"query_date": "2022-03-02","query_token": "some_token","consumer_account_locator": "CONSUMER_ACCOUNT_LOCATOR","listing_objects_accessed": [{"objectDomain": "Table","objectName": "DATABASE_NAME.SCHEMA_NAME.TABLE_NAME","columns": [{"columnName": "column1"},{"columnName": "column2"}]}]}' + ) AS lah +) AS src + , LATERAL 
VARIANT_EXPLODE(src.lah:listing_objects_accessed) AS los + , LATERAL VARIANT_EXPLODE(los.value:columns) AS cols +WHERE + CAST(los.value:objectDomain AS STRING) IN ('Table', 'View') + AND CAST(src.lah:query_date AS DATE) BETWEEN '2022-03-01' AND '2022-04-30' + AND CAST(los.value:objectName AS STRING) = 'DATABASE_NAME.SCHEMA_NAME.TABLE_NAME' + AND CAST(src.lah:consumer_account_locator AS STRING) = 'CONSUMER_ACCOUNT_LOCATOR' +GROUP BY + 1, + 2, + 3; diff --git a/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_1.sql b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_1.sql new file mode 100644 index 0000000000..c2d74214a2 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT OBJECT_CONSTRUCT('a',1,'b','BBBB', 'c',null); + +-- databricks sql: +SELECT STRUCT(1 AS a, 'BBBB' AS b, NULL AS c); diff --git a/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_2.sql b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_2.sql new file mode 100644 index 0000000000..414f2c5b06 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT OBJECT_CONSTRUCT(*) AS oc FROM demo_table_1 ; + +-- databricks sql: +SELECT STRUCT(*) AS oc FROM demo_table_1; diff --git a/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_3.sql b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_3.sql new file mode 100644 index 0000000000..4e5616e26c --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT OBJECT_CONSTRUCT(*) FROM VALUES(1,'x'), (2,'y'); + +-- databricks sql: +SELECT STRUCT(*) FROM VALUES (1, 'x'), (2, 'y'); diff --git a/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_4.sql b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_4.sql new file mode 100644 index 0000000000..732ba4d38d --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/object_construct/test_object_construct_4.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT OBJECT_CONSTRUCT('Key_One', PARSE_JSON('NULL'), 'Key_Two', NULL, 'Key_Three', 'null') as obj; + +-- databricks sql: +SELECT STRUCT(PARSE_JSON('NULL') AS Key_One, NULL AS Key_Two, 'null' AS Key_Three) AS obj; diff --git a/tests/resources/functional/snowflake/ddl/test_cras_simple.sql b/tests/resources/functional/snowflake/ddl/test_cras_simple.sql new file mode 100644 index 0000000000..8e54df800d --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/test_cras_simple.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +CREATE OR REPLACE TABLE employee as SELECT employee_id, name FROM employee_stage; + +-- databricks sql: +CREATE OR REPLACE TABLE employee as SELECT employee_id, name FROM employee_stage; diff --git a/tests/resources/functional/snowflake/ddl/test_create_ddl_1.sql b/tests/resources/functional/snowflake/ddl/test_create_ddl_1.sql new file mode 100644 index 0000000000..e70a5cc682 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/test_create_ddl_1.sql @@ -0,0 +1,22 @@ +-- snowflake sql: +CREATE TABLE employee (employee_id INT, + first_name VARCHAR(50) NOT NULL, + last_name VARCHAR(50) NOT NULL, + birth_date DATE, + hire_date DATE, + salary DECIMAL(10, 2), + department_id INT, 
+ remarks VARIANT) +; + +-- databricks sql: +CREATE TABLE employee ( + employee_id DECIMAL(38, 0), + first_name STRING NOT NULL, + last_name STRING NOT NULL, + birth_date DATE, + hire_date DATE, + salary DECIMAL(10, 2), + department_id DECIMAL(38, 0), + remarks VARIANT +); diff --git a/tests/resources/functional/snowflake/ddl/test_create_ddl_2.sql b/tests/resources/functional/snowflake/ddl/test_create_ddl_2.sql new file mode 100644 index 0000000000..a09b654752 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/test_create_ddl_2.sql @@ -0,0 +1,13 @@ +-- snowflake sql: +CREATE TABLE employee (employee_id INT DEFAULT 3000, + first_name VARCHAR(50) NOT NULL, + last_name VARCHAR(50) NOT NULL + ) +; + +-- databricks sql: +CREATE TABLE employee ( + employee_id DECIMAL(38, 0) DEFAULT 3000, + first_name STRING NOT NULL, + last_name STRING NOT NULL +); diff --git a/tests/resources/functional/snowflake/ddl/test_create_ddl_identity.sql b/tests/resources/functional/snowflake/ddl/test_create_ddl_identity.sql new file mode 100644 index 0000000000..99d2e38e44 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/test_create_ddl_identity.sql @@ -0,0 +1,20 @@ +-- snowflake sql: + +CREATE OR REPLACE TABLE sales_data ( + sale_id INT AUTOINCREMENT, + product_id INT, + quantity INT, + sale_amount DECIMAL(10, 2), + sale_date DATE, + customer_id INT +); + +-- databricks sql: +CREATE OR REPLACE TABLE sales_data ( + sale_id DECIMAL(38, 0) GENERATED ALWAYS AS IDENTITY, + product_id DECIMAL(38, 0), + quantity DECIMAL(38, 0), + sale_amount DECIMAL(10, 2), + sale_date DATE, + customer_id DECIMAL(38, 0) +); diff --git a/tests/resources/functional/snowflake/ddl/test_ctas_complex.sql b/tests/resources/functional/snowflake/ddl/test_ctas_complex.sql new file mode 100644 index 0000000000..8e6b8c172d --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/test_ctas_complex.sql @@ -0,0 +1,39 @@ +-- snowflake sql: +CREATE TABLE employee_summary AS +SELECT + e.employee_id, + e.first_name, + e.last_name, + e.salary, + d.department_name, + CASE + WHEN e.salary > 100000 THEN 'High' + WHEN e.salary BETWEEN 50000 AND 100000 THEN 'Medium' + ELSE 'Low' + END AS salary_range, + YEAR(e.hire_date) AS hire_year +FROM + employee e +JOIN + department d ON e.department_id = d.department_id +; + +-- databricks sql: +CREATE TABLE employee_summary AS +SELECT + e.employee_id, + e.first_name, + e.last_name, + e.salary, + d.department_name, + CASE + WHEN e.salary > 100000 + THEN 'High' + WHEN e.salary BETWEEN 50000 AND 100000 + THEN 'Medium' + ELSE 'Low' + END AS salary_range, + YEAR(e.hire_date) AS hire_year +FROM employee AS e +JOIN department AS d + ON e.department_id = d.department_id ; diff --git a/tests/resources/functional/snowflake/ddl/test_ctas_simple.sql b/tests/resources/functional/snowflake/ddl/test_ctas_simple.sql new file mode 100644 index 0000000000..a43fdfcabc --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/test_ctas_simple.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +CREATE TABLE employee as SELECT employee_id, name FROM employee_stage; + +-- databricks sql: +CREATE TABLE employee as SELECT employee_id, name FROM employee_stage; diff --git a/tests/resources/functional/snowflake/ddl/test_current_database_1.sql b/tests/resources/functional/snowflake/ddl/test_current_database_1.sql new file mode 100644 index 0000000000..05b2019f27 --- /dev/null +++ b/tests/resources/functional/snowflake/ddl/test_current_database_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT current_database() AS current_database_col1 
FROM tabl; + +-- databricks sql: +SELECT CURRENT_DATABASE() AS current_database_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/dml/insert/test_insert_1.sql b/tests/resources/functional/snowflake/dml/insert/test_insert_1.sql new file mode 100644 index 0000000000..4d331bda0b --- /dev/null +++ b/tests/resources/functional/snowflake/dml/insert/test_insert_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +INSERT INTO foo VALUES (1, 'bar', true), (1, 'qux', false); + +-- databricks sql: +INSERT INTO foo VALUES (1, 'bar', true), (1, 'qux', false); diff --git a/tests/resources/functional/snowflake/dml/insert/test_insert_2.sql b/tests/resources/functional/snowflake/dml/insert/test_insert_2.sql new file mode 100644 index 0000000000..af8e82be20 --- /dev/null +++ b/tests/resources/functional/snowflake/dml/insert/test_insert_2.sql @@ -0,0 +1,8 @@ + + +-- snowflake sql: +INSERT INTO table1 SELECT * FROM table2; + +-- databricks sql: + +INSERT INTO table1 SELECT * FROM table2; diff --git a/tests/resources/functional/snowflake/dml/insert/test_insert_3.sql b/tests/resources/functional/snowflake/dml/insert/test_insert_3.sql new file mode 100644 index 0000000000..3121b2f279 --- /dev/null +++ b/tests/resources/functional/snowflake/dml/insert/test_insert_3.sql @@ -0,0 +1,17 @@ +-- snowflake sql: +INSERT INTO foo (c1, c2, c3) + SELECT x, y, z FROM bar WHERE x > z AND y = 'qux'; + +-- databricks sql: +INSERT INTO foo ( + c1, + c2, + c3 +) +SELECT + x, + y, + z +FROM bar +WHERE + x > z AND y = 'qux'; diff --git a/tests/resources/functional/snowflake/dml/insert/test_insert_overwrite_1.sql b/tests/resources/functional/snowflake/dml/insert/test_insert_overwrite_1.sql new file mode 100644 index 0000000000..439ccb2570 --- /dev/null +++ b/tests/resources/functional/snowflake/dml/insert/test_insert_overwrite_1.sql @@ -0,0 +1,8 @@ + + +-- snowflake sql: +INSERT OVERWRITE INTO foo VALUES (1, 2, 3); + +-- databricks sql: + +INSERT OVERWRITE TABLE foo VALUES (1, 2, 3); diff --git a/tests/resources/functional/snowflake/dml/test_delete.sql b/tests/resources/functional/snowflake/dml/test_delete.sql new file mode 100644 index 0000000000..e15d214b2d --- /dev/null +++ b/tests/resources/functional/snowflake/dml/test_delete.sql @@ -0,0 +1,6 @@ +-- snowflake sql: + +DELETE FROM t1 WHERE t1.c1 > 42; + +-- databricks sql: +DELETE FROM t1 WHERE t1.c1 > 42; diff --git a/tests/resources/functional/snowflake/dml/test_delete_subquery.sql b/tests/resources/functional/snowflake/dml/test_delete_subquery.sql new file mode 100644 index 0000000000..864912774e --- /dev/null +++ b/tests/resources/functional/snowflake/dml/test_delete_subquery.sql @@ -0,0 +1,13 @@ +-- snowflake sql: +DELETE FROM table1 AS t1 USING (SELECT c2 FROM table2 WHERE t2.c3 = 'foo') AS t2 WHERE t1.c1 = t2.c2; + +-- databricks sql: +MERGE INTO table1 AS t1 USING ( + SELECT + c2 + FROM table2 + WHERE + t2.c3 = 'foo' +) AS t2 +ON + t1.c1 = t2.c2 WHEN MATCHED THEN DELETE; diff --git a/tests/resources/functional/snowflake/dml/test_delete_using.sql b/tests/resources/functional/snowflake/dml/test_delete_using.sql new file mode 100644 index 0000000000..f1982b25d2 --- /dev/null +++ b/tests/resources/functional/snowflake/dml/test_delete_using.sql @@ -0,0 +1,6 @@ +-- snowflake sql: + +DELETE FROM t1 USING t2 WHERE t1.c1 = t2.c2; + +-- databricks sql: +MERGE INTO t1 USING t2 ON t1.c1 = t2.c2 WHEN MATCHED THEN DELETE; diff --git a/tests/resources/functional/snowflake/dml/test_delete_using_where.sql b/tests/resources/functional/snowflake/dml/test_delete_using_where.sql new file 
mode 100644 index 0000000000..29d6ae2f8f --- /dev/null +++ b/tests/resources/functional/snowflake/dml/test_delete_using_where.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +DELETE FROM table1 USING table2 WHERE table1.id = table2.id; + +-- databricks sql: +MERGE INTO table1 USING table2 ON table1.id = table2.id WHEN MATCHED THEN DELETE; diff --git a/tests/resources/functional/snowflake/dml/test_delete_where.sql b/tests/resources/functional/snowflake/dml/test_delete_where.sql new file mode 100644 index 0000000000..027b22c99b --- /dev/null +++ b/tests/resources/functional/snowflake/dml/test_delete_where.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +DELETE FROM employee WHERE employee_id = 1; + +-- databricks sql: +DELETE FROM employee WHERE employee_id = 1; diff --git a/tests/resources/functional/snowflake/dml/update/test_update_from_dml_1.sql b/tests/resources/functional/snowflake/dml/update/test_update_from_dml_1.sql new file mode 100644 index 0000000000..214826f50b --- /dev/null +++ b/tests/resources/functional/snowflake/dml/update/test_update_from_dml_1.sql @@ -0,0 +1,10 @@ + +-- snowflake sql: +UPDATE t1 + SET column1 = t1.column1 + t2.column1, column3 = 'success' + FROM t2 + WHERE t1.key = t2.t1_key and t1.column1 < 10; + +-- databricks sql: +MERGE INTO t1 USING t2 ON t1.key = t2.t1_key and t1.column1 < 10 WHEN MATCHED THEN UPDATE SET column1 = t1.column1 + t2.column1, +column3 = 'success'; diff --git a/tests/resources/functional/snowflake/dml/update/test_update_from_dml_2.sql b/tests/resources/functional/snowflake/dml/update/test_update_from_dml_2.sql new file mode 100644 index 0000000000..b29fad59d5 --- /dev/null +++ b/tests/resources/functional/snowflake/dml/update/test_update_from_dml_2.sql @@ -0,0 +1,18 @@ +-- snowflake sql: +UPDATE target +SET v = b.v +FROM (SELECT k, MIN(v) v FROM src GROUP BY k) b +WHERE target.k = b.k; + +-- databricks sql: +MERGE INTO target +USING ( + SELECT + k, + MIN(v) AS v + FROM src + GROUP BY + k +) AS b +ON + target.k = b.k WHEN MATCHED THEN UPDATE SET v = b.v; diff --git a/tests/resources/functional/snowflake/dml/update/test_update_from_dml_3.sql b/tests/resources/functional/snowflake/dml/update/test_update_from_dml_3.sql new file mode 100644 index 0000000000..ad87606ed7 --- /dev/null +++ b/tests/resources/functional/snowflake/dml/update/test_update_from_dml_3.sql @@ -0,0 +1,16 @@ +-- snowflake sql: +UPDATE orders t1 +SET order_status = 'returned' +WHERE EXISTS (SELECT oid FROM returned_orders WHERE t1.oid = oid); + +-- databricks sql: +UPDATE orders AS t1 SET order_status = 'returned' +WHERE + EXISTS( + SELECT + oid + FROM returned_orders + WHERE + t1.oid = oid + ) +; diff --git a/tests/resources/functional/snowflake/functions/conversion/is_integer_1.sql b/tests/resources/functional/snowflake/functions/conversion/is_integer_1.sql new file mode 100644 index 0000000000..f4b6c17e44 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/is_integer_1.sql @@ -0,0 +1,11 @@ +-- snowflake sql: +select IS_INTEGER(col); + +-- databricks sql: +SELECT + + CASE + WHEN col IS NULL THEN NULL + WHEN col RLIKE '^-?[0-9]+$' AND TRY_CAST(col AS INT) IS NOT NULL THEN true + ELSE false + END; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_1.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_1.sql new file mode 100644 index 0000000000..e25a88d647 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_1.sql @@ -0,0 
+1,6 @@ + +-- snowflake sql: +SELECT TO_DECIMAL('$345', '$999.00') AS col1; + +-- databricks sql: +SELECT TO_NUMBER('$345', '$999.00') AS col1; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_2.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_2.sql new file mode 100644 index 0000000000..5e72502446 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_NUMERIC('$345', '$999.99') AS num; + +-- databricks sql: +SELECT TO_NUMBER('$345', '$999.99') AS num; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_3.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_3.sql new file mode 100644 index 0000000000..e2188339b2 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_NUMBER('$345', '$999.99') AS num; + +-- databricks sql: +SELECT TO_NUMBER('$345', '$999.99') AS num; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_4.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_4.sql new file mode 100644 index 0000000000..0b0669edf4 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_4.sql @@ -0,0 +1,8 @@ + +-- snowflake sql: +SELECT TO_DECIMAL(col1, '$999.099'), + TO_NUMERIC(tbl.col2, '$999,099.99') FROM dummy tbl; + +-- databricks sql: +SELECT TO_NUMBER(col1, '$999.099'), + TO_NUMBER(tbl.col2, '$999,099.99') FROM dummy AS tbl; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_5.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_5.sql new file mode 100644 index 0000000000..57e069837b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_5.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_NUMERIC('$345', '$999.99', 5, 2) AS num_with_scale; + +-- databricks sql: +SELECT CAST(TO_NUMBER('$345', '$999.99') AS DECIMAL(5, 2)) AS num_with_scale; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_6.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_6.sql new file mode 100644 index 0000000000..fddd8f67d4 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_6.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_DECIMAL('$755', '$999.00', 15, 5) AS num_with_scale; + +-- databricks sql: +SELECT CAST(TO_NUMBER('$755', '$999.00') AS DECIMAL(15, 5)) AS num_with_scale; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_7.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_7.sql new file mode 100644 index 0000000000..7e8a07ab1c --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_7.sql @@ -0,0 +1,8 @@ + +-- snowflake sql: +SELECT TO_NUMERIC(sm.col1, '$999.00', 15, 5) AS col1, + TO_NUMBER(sm.col2, '$99.00', 15, 5) AS col2 FROM sales_reports sm; + +-- databricks sql: +SELECT CAST(TO_NUMBER(sm.col1, '$999.00') AS DECIMAL(15, 5)) AS col1, + 
CAST(TO_NUMBER(sm.col2, '$99.00') AS DECIMAL(15, 5)) AS col2 FROM sales_reports AS sm; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_8.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_8.sql new file mode 100644 index 0000000000..4868208c75 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_8.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_NUMERIC(col1, 15, 5) AS col1 FROM sales_reports; + +-- databricks sql: +SELECT CAST(col1 AS DECIMAL(15, 5)) AS col1 FROM sales_reports; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_9.sql b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_9.sql new file mode 100644 index 0000000000..b7cf06e0fa --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_to_number/test_to_number_9.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select TO_NUMBER(EXPR) from test_tbl; + +-- databricks sql: +SELECT CAST(EXPR AS DECIMAL(38, 0)) FROM test_tbl; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_cast/test_try_cast_1.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_cast/test_try_cast_1.sql new file mode 100644 index 0000000000..2f2bc3a077 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_cast/test_try_cast_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT try_cast('10' AS INT); + +-- databricks sql: +SELECT TRY_CAST('10' AS DECIMAL(38, 0)); diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_cast/test_try_cast_2.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_cast/test_try_cast_2.sql new file mode 100644 index 0000000000..9cbc141123 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_cast/test_try_cast_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT try_cast(col1 AS FLOAT) AS try_cast_col1 FROM tabl; + +-- databricks sql: +SELECT TRY_CAST(col1 AS DOUBLE) AS try_cast_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_date/test_try_to_date_2.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_date/test_try_to_date_2.sql new file mode 100644 index 0000000000..dd534dd1ba --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_date/test_try_to_date_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_DATE('2023-25-09', 'yyyy-dd-MM'); + +-- databricks sql: +SELECT DATE(TRY_TO_TIMESTAMP('2023-25-09', 'yyyy-dd-MM')); diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_1.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_1.sql new file mode 100644 index 0000000000..9b51202dab --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_DECIMAL('$345', '$999.00') AS col1; + +-- databricks sql: +SELECT CAST(TRY_TO_NUMBER('$345', '$999.00') AS DECIMAL(38, 0)) AS col1; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_2.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_2.sql new file mode 100644 
index 0000000000..9c3d1628d0 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_NUMERIC('$345', '$999.99') AS num; + +-- databricks sql: +SELECT CAST(TRY_TO_NUMBER('$345', '$999.99') AS DECIMAL(38, 0)) AS num; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_3.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_3.sql new file mode 100644 index 0000000000..91e8b9bfe5 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_NUMBER('$345', '$999.99') AS num; + +-- databricks sql: +SELECT CAST(TRY_TO_NUMBER('$345', '$999.99') AS DECIMAL(38, 0)) AS num; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_4.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_4.sql new file mode 100644 index 0000000000..d83096118b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_4.sql @@ -0,0 +1,8 @@ + +-- snowflake sql: +SELECT TRY_TO_DECIMAL(col1, '$999.099'), + TRY_TO_NUMERIC(tbl.col2, '$999,099.99') FROM dummy tbl; + +-- databricks sql: +SELECT CAST(TRY_TO_NUMBER(col1, '$999.099') AS DECIMAL(38, 0)), + CAST(TRY_TO_NUMBER(tbl.col2, '$999,099.99') AS DECIMAL(38, 0)) FROM dummy AS tbl; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_5.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_5.sql new file mode 100644 index 0000000000..db5d337473 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_5.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_NUMERIC('$345', '$999.99', 5, 2) AS num_with_scale; + +-- databricks sql: +SELECT CAST(TRY_TO_NUMBER('$345', '$999.99') AS DECIMAL(5, 2)) AS num_with_scale; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_6.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_6.sql new file mode 100644 index 0000000000..f9912a5d23 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_6.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_DECIMAL('$755', '$999.00', 15, 5) AS num_with_scale; + +-- databricks sql: +SELECT CAST(TRY_TO_NUMBER('$755', '$999.00') AS DECIMAL(15, 5)) AS num_with_scale; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_7.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_7.sql new file mode 100644 index 0000000000..3fd9563fa1 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_7.sql @@ -0,0 +1,8 @@ + +-- snowflake sql: +SELECT TRY_TO_NUMERIC(sm.col1, '$999.00', 15, 5) AS col1, + TRY_TO_NUMBER(sm.col2, '$99.00', 15, 5) AS col2 FROM sales_reports sm; + +-- databricks sql: +SELECT CAST(TRY_TO_NUMBER(sm.col1, '$999.00') AS DECIMAL(15, 5)) AS col1, + CAST(TRY_TO_NUMBER(sm.col2, '$99.00') AS DECIMAL(15, 5)) AS col2 FROM sales_reports AS 
sm; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_8.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_8.sql new file mode 100644 index 0000000000..7a5078671c --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_number/test_try_to_number_8.sql @@ -0,0 +1,12 @@ + +-- snowflake sql: +SELECT TRY_TO_DECIMAL('$345') AS str_col, + TRY_TO_DECIMAL(99.56854634) AS num_col, + TRY_TO_DECIMAL(-4.35) AS num_col1, + TRY_TO_DECIMAL(col1) AS col1; + +-- databricks sql: +SELECT CAST('$345' AS DECIMAL(38, 0)) AS str_col, + CAST(99.56854634 AS DECIMAL(38, 0)) AS num_col, + CAST(-4.35 AS DECIMAL(38, 0)) AS num_col1, + CAST(col1 AS DECIMAL(38, 0)) AS col1; diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_timestamp/test_try_to_timestamp_1.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_timestamp/test_try_to_timestamp_1.sql new file mode 100644 index 0000000000..896a28b998 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_timestamp/test_try_to_timestamp_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_TIMESTAMP('2016-12-31 00:12:00'); + +-- databricks sql: +SELECT TRY_TO_TIMESTAMP('2016-12-31 00:12:00'); diff --git a/tests/resources/functional/snowflake/functions/conversion/test_try_to_timestamp/test_try_to_timestamp_2.sql b/tests/resources/functional/snowflake/functions/conversion/test_try_to_timestamp/test_try_to_timestamp_2.sql new file mode 100644 index 0000000000..28ebc6f5fa --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/test_try_to_timestamp/test_try_to_timestamp_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TRY_TO_TIMESTAMP('2018-05-15', 'yyyy-MM-dd'); + +-- databricks sql: +SELECT TRY_TO_TIMESTAMP('2018-05-15', 'yyyy-MM-dd'); diff --git a/tests/resources/functional/snowflake/functions/conversion/to_array_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_array_1.sql new file mode 100644 index 0000000000..f90f771f56 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_array_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT to_array(col1) AS ary_col; + +-- databricks sql: +SELECT IF(col1 IS NULL, NULL, ARRAY(col1)) AS ary_col; diff --git a/tests/resources/functional/snowflake/functions/conversion/to_array_2.sql b/tests/resources/functional/snowflake/functions/conversion/to_array_2.sql new file mode 100644 index 0000000000..9cd615ee63 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_array_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT to_array(col1,'STRING') AS ary_col; + +-- databricks sql: +SELECT IF(col1 IS NULL, NULL, ARRAY(col1)) AS ary_col; diff --git a/tests/resources/functional/snowflake/functions/conversion/to_boolean/test_to_boolean_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_boolean/test_to_boolean_1.sql new file mode 100644 index 0000000000..02666a99a6 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_boolean/test_to_boolean_1.sql @@ -0,0 +1,24 @@ + +-- snowflake sql: +select TO_BOOLEAN(col1); + +-- databricks sql: +SELECT + + CASE + WHEN col1 IS NULL THEN NULL + WHEN TYPEOF(col1) = 'boolean' THEN BOOLEAN(col1) + WHEN TYPEOF(col1) = 'string' THEN + CASE + WHEN LOWER(col1) IN ('true', 't', 'yes', 'y', 'on', '1') THEN true + WHEN LOWER(col1) IN ('false', 'f', 'no', 
'n', 'off', '0') THEN false + ELSE RAISE_ERROR('Boolean value of x is not recognized by TO_BOOLEAN') + END + WHEN TRY_CAST(col1 AS DOUBLE) IS NOT NULL THEN + CASE + WHEN ISNAN(CAST(col1 AS DOUBLE)) OR CAST(col1 AS DOUBLE) = DOUBLE('infinity') THEN + RAISE_ERROR('Invalid parameter type for TO_BOOLEAN') + ELSE CAST(col1 AS DOUBLE) != 0.0 + END + ELSE RAISE_ERROR('Invalid parameter type for TO_BOOLEAN') + END; diff --git a/tests/resources/functional/snowflake/functions/conversion/to_boolean/test_try_to_boolean_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_boolean/test_try_to_boolean_1.sql new file mode 100644 index 0000000000..ae7fae678b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_boolean/test_try_to_boolean_1.sql @@ -0,0 +1,23 @@ +-- snowflake sql: +select TRY_TO_BOOLEAN(1); + +-- databricks sql: +SELECT + + CASE + WHEN 1 IS NULL THEN NULL + WHEN TYPEOF(1) = 'boolean' THEN BOOLEAN(1) + WHEN TYPEOF(1) = 'string' THEN + CASE + WHEN LOWER(1) IN ('true', 't', 'yes', 'y', 'on', '1') THEN true + WHEN LOWER(1) IN ('false', 'f', 'no', 'n', 'off', '0') THEN false + ELSE RAISE_ERROR('Boolean value of x is not recognized by TO_BOOLEAN') + END + WHEN TRY_CAST(1 AS DOUBLE) IS NOT NULL THEN + CASE + WHEN ISNAN(CAST(1 AS DOUBLE)) OR CAST(1 AS DOUBLE) = DOUBLE('infinity') THEN + RAISE_ERROR('Invalid parameter type for TO_BOOLEAN') + ELSE CAST(1 AS DOUBLE) != 0.0 + END + ELSE NULL + END; diff --git a/tests/resources/functional/snowflake/functions/conversion/to_double_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_double_1.sql new file mode 100644 index 0000000000..2096205ec2 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_double_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_DOUBLE('HELLO'); + +-- databricks sql: +SELECT DOUBLE('HELLO'); diff --git a/tests/resources/functional/snowflake/functions/conversion/to_json_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_json_1.sql new file mode 100644 index 0000000000..dbfb3cc092 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_json_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT to_json(col1) AS to_json_col1 FROM tabl; + +-- databricks sql: +SELECT TO_JSON(col1) AS to_json_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/conversion/to_object_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_object_1.sql new file mode 100644 index 0000000000..d265374e0e --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_object_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT to_object(k) FROM tabl; + +-- databricks sql: +SELECT TO_JSON(k) FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/conversion/to_rlike_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_rlike_1.sql new file mode 100644 index 0000000000..e86289a27e --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/to_rlike_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT RLIKE('800-456-7891','[2-9]d{2}-d{3}-d{4}') AS matches_phone_number;; + +-- databricks sql: +SELECT '800-456-7891' RLIKE '[2-9]d{2}-d{3}-d{4}' AS matches_phone_number; diff --git a/tests/resources/functional/snowflake/functions/conversion/to_variant_1.sql b/tests/resources/functional/snowflake/functions/conversion/to_variant_1.sql new file mode 100644 index 0000000000..c4b285e196 --- /dev/null +++ 
b/tests/resources/functional/snowflake/functions/conversion/to_variant_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT to_variant(col1) AS json_col1 FROM dummy; + +-- databricks sql: +SELECT TO_JSON(col1) AS json_col1 FROM dummy; diff --git a/tests/resources/functional/snowflake/functions/conversion/tochar_1.sql b/tests/resources/functional/snowflake/functions/conversion/tochar_1.sql new file mode 100644 index 0000000000..5caad95910 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/tochar_1.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +select to_char(column1, '">"$99.0"<"') as D2_1, + to_char(column1, '">"B9,999.0"<"') as D4_1 FROM table; + +-- databricks sql: +SELECT TO_CHAR(column1, '">"$99.0"<"') AS D2_1, TO_CHAR(column1, '">"B9,999.0"<"') AS D4_1 FROM table; diff --git a/tests/resources/functional/snowflake/functions/conversion/zeroifnull_1.sql b/tests/resources/functional/snowflake/functions/conversion/zeroifnull_1.sql new file mode 100644 index 0000000000..77ea0282c9 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/conversion/zeroifnull_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT zeroifnull(col1) AS pcol1 FROM tabl; + +-- databricks sql: +SELECT IF(col1 IS NULL, 0, col1) AS pcol1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_1.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_1.sql new file mode 100644 index 0000000000..3d1ce22883 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_1.sql @@ -0,0 +1,10 @@ +-- snowflake sql: +SELECT datediff(yrs, TIMESTAMP'2021-02-28 12:00:00', TIMESTAMP'2021-03-28 12:00:00'); + +-- databricks sql: +SELECT + DATEDIFF( + year, + CAST('2021-02-28 12:00:00' AS TIMESTAMP), + CAST('2021-03-28 12:00:00' AS TIMESTAMP) + ); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_10.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_10.sql new file mode 100644 index 0000000000..8daaa2ddfc --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_10.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff(quarters, 'start', 'end'); + +-- databricks sql: +SELECT DATEDIFF(quarter, 'start', 'end'); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_11.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_11.sql new file mode 100644 index 0000000000..fa81a4b7c3 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_11.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT DATEDIFF('DAY', start, end); + +-- databricks sql: +SELECT DATEDIFF(day, start, end); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_2.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_2.sql new file mode 100644 index 0000000000..a394de9787 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_2.sql @@ -0,0 +1,10 @@ +-- snowflake sql: +SELECT datediff(years, TIMESTAMP'2021-02-28 12:00:00', TIMESTAMP'2021-03-28 12:00:00'); + +-- databricks sql: +SELECT + DATEDIFF( + year, + CAST('2021-02-28 12:00:00' AS TIMESTAMP), + CAST('2021-03-28 12:00:00' AS TIMESTAMP) + ); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_3.sql 
b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_3.sql new file mode 100644 index 0000000000..d47a9b2d25 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff(mm, DATE'2021-02-28', DATE'2021-03-28'); + +-- databricks sql: +SELECT DATEDIFF(month, CAST('2021-02-28' AS DATE), CAST('2021-03-28' AS DATE)); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_4.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_4.sql new file mode 100644 index 0000000000..a3a89f2f3d --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_4.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff(mons, DATE'2021-02-28', DATE'2021-03-28'); + +-- databricks sql: +SELECT DATEDIFF(month, CAST('2021-02-28' AS DATE), CAST('2021-03-28' AS DATE)); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_5.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_5.sql new file mode 100644 index 0000000000..37c2682f52 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_5.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff('days', 'start', 'end'); + +-- databricks sql: +SELECT DATEDIFF(day, 'start', 'end'); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_6.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_6.sql new file mode 100644 index 0000000000..a4000e0e00 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_6.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff(dayofmonth, 'start', 'end'); + +-- databricks sql: +SELECT DATEDIFF(day, 'start', 'end'); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_7.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_7.sql new file mode 100644 index 0000000000..c86a88c0de --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_7.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff(wk, 'start', 'end'); + +-- databricks sql: +SELECT DATEDIFF(week, 'start', 'end'); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_8.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_8.sql new file mode 100644 index 0000000000..580d94cf0e --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_8.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff('woy', 'start', 'end'); + +-- databricks sql: +SELECT DATEDIFF(week, 'start', 'end'); diff --git a/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_9.sql b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_9.sql new file mode 100644 index 0000000000..dac0c29aa3 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/datediff/test_datediff_9.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT datediff('qtrs', 'start', 'end'); + +-- databricks sql: +SELECT DATEDIFF(quarter, 'start', 'end'); diff --git a/tests/resources/functional/snowflake/functions/dates/dayname/test_dayname_2.sql b/tests/resources/functional/snowflake/functions/dates/dayname/test_dayname_2.sql new file mode 100644 index 0000000000..de305e0e7f --- 
/dev/null +++ b/tests/resources/functional/snowflake/functions/dates/dayname/test_dayname_2.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT DAYNAME(TO_DATE('2015-05-01')) AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT(cast('2015-05-01' as DATE), 'E') AS MONTH; diff --git a/tests/resources/functional/snowflake/functions/dates/dayname/test_dayname_3.sql b/tests/resources/functional/snowflake/functions/dates/dayname/test_dayname_3.sql new file mode 100644 index 0000000000..f8c8cbdccf --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/dayname/test_dayname_3.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT DAYNAME('2015-04-03 10:00') AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT('2015-04-03 10:00', 'E') AS MONTH; diff --git a/tests/resources/functional/snowflake/functions/dates/last_day_1.sql b/tests/resources/functional/snowflake/functions/dates/last_day_1.sql new file mode 100644 index 0000000000..75fbc38012 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/last_day_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT last_day(col1) AS last_day_col1 FROM tabl; + +-- databricks sql: +SELECT LAST_DAY(col1) AS last_day_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_2.sql b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_2.sql new file mode 100644 index 0000000000..733de222bd --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_2.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT MONTHNAME(TO_DATE('2015-05-01')) AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT(cast('2015-05-01' as DATE), 'MMM') AS MONTH; diff --git a/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_3.sql b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_3.sql new file mode 100644 index 0000000000..f609a2dc1e --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_3.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT MONTHNAME(TO_DATE('2020-01-01')) AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT(cast('2020-01-01' as DATE), 'MMM') AS MONTH; diff --git a/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_4.sql b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_4.sql new file mode 100644 index 0000000000..c78d6e4949 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_4.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT MONTHNAME('2015-04-03 10:00') AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT('2015-04-03 10:00', 'MMM') AS MONTH; diff --git a/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_5.sql b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_5.sql new file mode 100644 index 0000000000..68f41a3b29 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_5.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT d, MONTHNAME(d) FROM dates; + +-- databricks sql: +SELECT d, DATE_FORMAT(d, 'MMM') FROM dates; diff --git a/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_6.sql b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_6.sql new file mode 100644 index 0000000000..22a9025430 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_6.sql @@ -0,0 +1,5 @@ +-- 
snowflake sql: +SELECT MONTHNAME('2015-03-04') AS MON; + +-- databricks sql: +SELECT DATE_FORMAT('2015-03-04', 'MMM') AS MON; diff --git a/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_7.sql b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_7.sql new file mode 100644 index 0000000000..38b1b0a493 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/monthname/test_monthname_7.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT TO_DATE('2015.03.04', 'yyyy.dd.MM') AS MON; + +-- databricks sql: +SELECT TO_DATE('2015.03.04', 'yyyy.dd.MM') AS MON; diff --git a/tests/resources/functional/snowflake/functions/dates/next_day_1.sql b/tests/resources/functional/snowflake/functions/dates/next_day_1.sql new file mode 100644 index 0000000000..3f40beb137 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/next_day_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT next_day('2015-01-14', 'TU') AS next_day_col1 FROM tabl; + +-- databricks sql: +SELECT NEXT_DAY('2015-01-14', 'TU') AS next_day_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/dates/test_add_months_1.sql b/tests/resources/functional/snowflake/functions/dates/test_add_months_1.sql new file mode 100644 index 0000000000..6db43df302 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_add_months_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT add_months(col1,1) AS add_months_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 1) AS add_months_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_1.sql b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_1.sql new file mode 100644 index 0000000000..8f9314728a --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_1.sql @@ -0,0 +1,8 @@ + +-- snowflake sql: +SELECT + CONVERT_TIMEZONE('America/Los_Angeles', 'America/New_York', '2019-01-01 14:00:00'::timestamp_ntz) + AS conv; + +-- databricks sql: +SELECT CONVERT_TIMEZONE( 'America/Los_Angeles', 'America/New_York', CAST('2019-01-01 14:00:00' AS TIMESTAMP_NTZ) ) AS conv; diff --git a/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_2.sql b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_2.sql new file mode 100644 index 0000000000..17dae6393b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_2.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +SELECT CONVERT_TIMEZONE('America/Los_Angeles', '2018-04-05 12:00:00 +02:00') + AS conv; + +-- databricks sql: +SELECT CONVERT_TIMEZONE('America/Los_Angeles', '2018-04-05 12:00:00 +02:00') AS conv; diff --git a/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_3.sql b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_3.sql new file mode 100644 index 0000000000..14dfcfde75 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT a.col1, CONVERT_TIMEZONE('IST', a.ts_col) AS conv_ts FROM dummy a; + +-- databricks sql: +SELECT a.col1, CONVERT_TIMEZONE('IST', a.ts_col) AS conv_ts FROM dummy AS a; diff --git a/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_4.sql b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_4.sql new file mode 100644 index 0000000000..5815939572 --- 
/dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_4.sql @@ -0,0 +1,12 @@ + +-- snowflake sql: +SELECT CURRENT_TIMESTAMP() AS now_in_la, + CONVERT_TIMEZONE('America/New_York', CURRENT_TIMESTAMP()) AS now_in_nyc, + CONVERT_TIMEZONE('Europe/Paris', CURRENT_TIMESTAMP()) AS now_in_paris, + CONVERT_TIMEZONE('Asia/Tokyo', CURRENT_TIMESTAMP()) AS now_in_tokyo; + +-- databricks sql: +SELECT + CURRENT_TIMESTAMP() AS now_in_la, CONVERT_TIMEZONE('America/New_York', CURRENT_TIMESTAMP()) AS now_in_nyc, + CONVERT_TIMEZONE('Europe/Paris', CURRENT_TIMESTAMP()) AS now_in_paris, + CONVERT_TIMEZONE('Asia/Tokyo', CURRENT_TIMESTAMP()) AS now_in_tokyo; diff --git a/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_5.sql b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_5.sql new file mode 100644 index 0000000000..2c69a98996 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_convert_timezone_5.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +SELECT + CONVERT_TIMEZONE('Europe/Warsaw', 'UTC', '2019-01-01 00:00:00 +03:00'::timestamp_ntz); + +-- databricks sql: +SELECT CONVERT_TIMEZONE('Europe/Warsaw', 'UTC', CAST('2019-01-01 00:00:00 +03:00' AS TIMESTAMP_NTZ)); diff --git a/tests/resources/functional/snowflake/functions/dates/test_current_timestamp_1.sql b/tests/resources/functional/snowflake/functions/dates/test_current_timestamp_1.sql new file mode 100644 index 0000000000..59367d9b26 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_current_timestamp_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT current_timestamp() AS current_timestamp_col1 FROM tabl; + +-- databricks sql: +SELECT CURRENT_TIMESTAMP() AS current_timestamp_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_1.sql b/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_1.sql new file mode 100644 index 0000000000..9e8ec5056c --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select date_from_parts(1992, 6, 1); + +-- databricks sql: +SELECT MAKE_DATE(1992, 6, 1); diff --git a/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_2.sql b/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_2.sql new file mode 100644 index 0000000000..e513b065ec --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select date_from_parts(2023, 10, 3), date_from_parts(2020, 4, 4); + +-- databricks sql: +SELECT MAKE_DATE(2023, 10, 3), MAKE_DATE(2020, 4, 4); diff --git a/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_3.sql b/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_3.sql new file mode 100644 index 0000000000..c8bb8025c2 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_date_from_parts_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select datefromparts(2023, 10, 3), datefromparts(2020, 4, 4); + +-- databricks sql: +SELECT MAKE_DATE(2023, 10, 3), MAKE_DATE(2020, 4, 4); diff --git a/tests/resources/functional/snowflake/functions/dates/test_date_part_1.sql b/tests/resources/functional/snowflake/functions/dates/test_date_part_1.sql new file mode 100644 index 0000000000..155237f795 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_date_part_1.sql 
@@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT date_part('seconds', col1) AS date_part_col1 FROM tabl; + +-- databricks sql: +SELECT EXTRACT(second FROM col1) AS date_part_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/dates/test_date_trunc_1.sql b/tests/resources/functional/snowflake/functions/dates/test_date_trunc_1.sql new file mode 100644 index 0000000000..033f3c0d52 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_date_trunc_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT date_trunc('YEAR', '2015-03-05T09:32:05.359'); + +-- databricks sql: +SELECT DATE_TRUNC('YEAR', '2015-03-05T09:32:05.359'); diff --git a/tests/resources/functional/snowflake/functions/dates/test_date_trunc_4.sql b/tests/resources/functional/snowflake/functions/dates/test_date_trunc_4.sql new file mode 100644 index 0000000000..c4d59631d8 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_date_trunc_4.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT date_trunc('MM', col1) AS date_trunc_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_TRUNC('MONTH', col1) AS date_trunc_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/dates/test_dateadd_1.sql b/tests/resources/functional/snowflake/functions/dates/test_dateadd_1.sql new file mode 100644 index 0000000000..ac72b82ebe --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_dateadd_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select dateadd('day', 3, '2020-02-03'::date); + +-- databricks sql: +SELECT DATEADD(day, 3, CAST('2020-02-03' AS DATE)); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestamp_from_parts/test_timestamp_from_parts_1.sql b/tests/resources/functional/snowflake/functions/dates/test_timestamp_from_parts/test_timestamp_from_parts_1.sql new file mode 100644 index 0000000000..b4c3940d5f --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestamp_from_parts/test_timestamp_from_parts_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select timestamp_from_parts(1992, 6, 1, 12, 35, 12); + +-- databricks sql: +SELECT MAKE_TIMESTAMP(1992, 6, 1, 12, 35, 12); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestamp_from_parts/test_timestamp_from_parts_2.sql b/tests/resources/functional/snowflake/functions/dates/test_timestamp_from_parts/test_timestamp_from_parts_2.sql new file mode 100644 index 0000000000..9d97ace1e9 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestamp_from_parts/test_timestamp_from_parts_2.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +select TIMESTAMP_FROM_PARTS(2023, 10, 3, 14, 10, 45), + timestamp_from_parts(2020, 4, 4, 4, 5, 6); + +-- databricks sql: +SELECT MAKE_TIMESTAMP(2023, 10, 3, 14, 10, 45), MAKE_TIMESTAMP(2020, 4, 4, 4, 5, 6); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_1.sql b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_1.sql new file mode 100644 index 0000000000..ee0aa11301 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_1.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +SELECT timestampadd('hour', -1, bp.ts) AND timestampadd('day', 2, bp.ts) + FROM base_prep AS bp; + +-- databricks sql: +SELECT DATEADD(hour, -1, bp.ts) AND DATEADD(day, 2, bp.ts) FROM base_prep AS bp; diff --git 
a/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_2.sql b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_2.sql new file mode 100644 index 0000000000..ac72b82ebe --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select dateadd('day', 3, '2020-02-03'::date); + +-- databricks sql: +SELECT DATEADD(day, 3, CAST('2020-02-03' AS DATE)); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_3.sql b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_3.sql new file mode 100644 index 0000000000..83ad689d82 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select timestampadd(year, -3, '2023-02-03'); + +-- databricks sql: +SELECT DATEADD(year, -3, '2023-02-03'); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_4.sql b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_4.sql new file mode 100644 index 0000000000..1dadb209c5 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_4.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select timestampadd(year, -3, '2023-02-03 01:02'::timestamp); + +-- databricks sql: +SELECT DATEADD(year, -3, CAST('2023-02-03 01:02' AS TIMESTAMP)); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_5.sql b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_5.sql new file mode 100644 index 0000000000..bb738e6802 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampadd/test_timestampadd_5.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select timeadd(year, -3, '2023-02-03 01:02'::timestamp); + +-- databricks sql: +SELECT DATEADD(year, -3, CAST('2023-02-03 01:02' AS TIMESTAMP)); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_1.sql b/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_1.sql new file mode 100644 index 0000000000..a5b6db5947 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select timestampDIFF(month, '2021-01-01'::timestamp, '2021-02-28'::timestamp); + +-- databricks sql: +SELECT DATEDIFF(month, CAST('2021-01-01' AS TIMESTAMP), CAST('2021-02-28' AS TIMESTAMP)); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_2.sql b/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_2.sql new file mode 100644 index 0000000000..9bb85ed62a --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select timestampDIFF('month', '2021-01-01'::timestamp, '2021-02-28'::timestamp); + +-- databricks sql: +SELECT DATEDIFF(month, CAST('2021-01-01' AS TIMESTAMP), CAST('2021-02-28' AS TIMESTAMP)); diff --git a/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_3.sql 
b/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_3.sql new file mode 100644 index 0000000000..60e45aefc1 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/dates/test_timestampdiff/test_timestampdiff_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select datediff('day', '2020-02-03'::timestamp, '2023-10-26'::timestamp); + +-- databricks sql: +SELECT DATEDIFF(day, CAST('2020-02-03' AS TIMESTAMP), CAST('2023-10-26' AS TIMESTAMP)); diff --git a/tests/resources/functional/snowflake/functions/initcap_1.sql b/tests/resources/functional/snowflake/functions/initcap_1.sql new file mode 100644 index 0000000000..10543b57b4 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/initcap_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT initcap(col1) AS initcap_col1 FROM tabl; + +-- databricks sql: +SELECT INITCAP(col1) AS initcap_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/left_1.sql b/tests/resources/functional/snowflake/functions/left_1.sql new file mode 100644 index 0000000000..559faad13b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/left_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT left(col1, 3) AS left_col1 FROM tabl; + +-- databricks sql: +SELECT LEFT(col1, 3) AS left_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/ln_1.sql b/tests/resources/functional/snowflake/functions/math/ln_1.sql new file mode 100644 index 0000000000..25d813daff --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/ln_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ln(col1) AS ln_col1 FROM tabl; + +-- databricks sql: +SELECT LN(col1) AS ln_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/log_1.sql b/tests/resources/functional/snowflake/functions/math/log_1.sql new file mode 100644 index 0000000000..fbc1fc4689 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/log_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT log(x, y) AS log_col1 FROM tabl; + +-- databricks sql: +SELECT LOG(x, y) AS log_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/mod_1.sql b/tests/resources/functional/snowflake/functions/math/mod_1.sql new file mode 100644 index 0000000000..e300d4a2ec --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/mod_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT mod(col1, col2) AS mod_col1 FROM tabl; + +-- databricks sql: +SELECT MOD(col1, col2) AS mod_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/mode_1.sql b/tests/resources/functional/snowflake/functions/math/mode_1.sql new file mode 100644 index 0000000000..3a37001f70 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/mode_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT mode(col1) AS mode_col1 FROM tabl; + +-- databricks sql: +SELECT MODE(col1) AS mode_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/pi_1.sql b/tests/resources/functional/snowflake/functions/math/pi_1.sql new file mode 100644 index 0000000000..79e819d41a --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/pi_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT PI() AS pi_col1 FROM tabl; + +-- databricks sql: +SELECT PI() AS pi_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/radians_1.sql b/tests/resources/functional/snowflake/functions/math/radians_1.sql new file mode 100644 index 
0000000000..fe7b0dbbbd --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/radians_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT radians(col1) AS radians_col1 FROM tabl; + +-- databricks sql: +SELECT RADIANS(col1) AS radians_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/random_1.sql b/tests/resources/functional/snowflake/functions/math/random_1.sql new file mode 100644 index 0000000000..c557a8f0cd --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/random_1.sql @@ -0,0 +1,9 @@ +-- TODO: Fix this test, it's currently incorrect because RANDOM() can't be simply passed through as-is. +-- See: https://github.com/databrickslabs/remorph/issues/1280 +-- Reference: https://docs.snowflake.com/en/sql-reference/functions/random.html + +-- snowflake sql: +SELECT random(), random(col1) FROM tabl; + +-- databricks sql: +SELECT RANDOM(), RANDOM(col1) FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/round_1.sql b/tests/resources/functional/snowflake/functions/math/round_1.sql new file mode 100644 index 0000000000..7ab6b491b2 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/round_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT round(100.123) AS rounded; + +-- databricks sql: +SELECT ROUND(100.123) AS rounded; diff --git a/tests/resources/functional/snowflake/functions/math/sign_1.sql b/tests/resources/functional/snowflake/functions/math/sign_1.sql new file mode 100644 index 0000000000..9d4dc7df56 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/sign_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT sign(col1) AS sign_col1 FROM tabl; + +-- databricks sql: +SELECT SIGN(col1) AS sign_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/sin_1.sql b/tests/resources/functional/snowflake/functions/math/sin_1.sql new file mode 100644 index 0000000000..50308dd23f --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/sin_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT sin(col1) AS sin_col1 FROM tabl; + +-- databricks sql: +SELECT SIN(col1) AS sin_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/sqrt_1.sql b/tests/resources/functional/snowflake/functions/math/sqrt_1.sql new file mode 100644 index 0000000000..3831cd6415 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/sqrt_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT sqrt(col1) AS sqrt_col1 FROM tabl; + +-- databricks sql: +SELECT SQRT(col1) AS sqrt_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/square/test_square_1.sql b/tests/resources/functional/snowflake/functions/math/square/test_square_1.sql new file mode 100644 index 0000000000..3c6303e28c --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/square/test_square_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select square(1); + +-- databricks sql: +SELECT POWER(1, 2); diff --git a/tests/resources/functional/snowflake/functions/math/square/test_square_2.sql b/tests/resources/functional/snowflake/functions/math/square/test_square_2.sql new file mode 100644 index 0000000000..8cc9f679ef --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/square/test_square_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select square(-2); + +-- databricks sql: +SELECT POWER(-2, 2); diff --git a/tests/resources/functional/snowflake/functions/math/square/test_square_3.sql 
b/tests/resources/functional/snowflake/functions/math/square/test_square_3.sql new file mode 100644 index 0000000000..87518ce743 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/square/test_square_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select SQUARE(3.15); + +-- databricks sql: +SELECT POWER(3.15, 2); diff --git a/tests/resources/functional/snowflake/functions/math/square/test_square_4.sql b/tests/resources/functional/snowflake/functions/math/square/test_square_4.sql new file mode 100644 index 0000000000..a0c2f273a0 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/square/test_square_4.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select SQUARE(null); + +-- databricks sql: +SELECT POWER(NULL, 2); diff --git a/tests/resources/functional/snowflake/functions/math/sum_1.sql b/tests/resources/functional/snowflake/functions/math/sum_1.sql new file mode 100644 index 0000000000..da25c203ff --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/sum_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT sum(col1) AS sum_col1 FROM tabl; + +-- databricks sql: +SELECT SUM(col1) AS sum_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/test_abs_1.sql b/tests/resources/functional/snowflake/functions/math/test_abs_1.sql new file mode 100644 index 0000000000..579e1561f1 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_abs_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT abs(col1) AS abs_col1 FROM tabl; + +-- databricks sql: +SELECT ABS(col1) AS abs_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/test_asin_1.sql b/tests/resources/functional/snowflake/functions/math/test_asin_1.sql new file mode 100644 index 0000000000..8098cb3df2 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_asin_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT asin(col1) AS asin_col1 FROM tabl; + +-- databricks sql: +SELECT ASIN(col1) AS asin_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/test_atan2_1.sql b/tests/resources/functional/snowflake/functions/math/test_atan2_1.sql new file mode 100644 index 0000000000..1d0b9d49f4 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_atan2_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT atan2(exprY, exprX) AS atan2_col1 FROM tabl; + +-- databricks sql: +SELECT ATAN2(exprY, exprX) AS atan2_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/test_ceil_1.sql b/tests/resources/functional/snowflake/functions/math/test_ceil_1.sql new file mode 100644 index 0000000000..63e10076ff --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_ceil_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ceil(col1) AS ceil_col1 FROM tabl; + +-- databricks sql: +SELECT CEIL(col1) AS ceil_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/test_cos_1.sql b/tests/resources/functional/snowflake/functions/math/test_cos_1.sql new file mode 100644 index 0000000000..630ee8c559 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_cos_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT cos(col1) AS cos_col1 FROM tabl; + +-- databricks sql: +SELECT COS(col1) AS cos_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/test_div0_1.sql b/tests/resources/functional/snowflake/functions/math/test_div0_1.sql new file mode 100644 index 0000000000..285a550667 --- 
/dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_div0_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT DIV0(a, b); + +-- databricks sql: +SELECT IF(b = 0, 0, a / b); diff --git a/tests/resources/functional/snowflake/functions/math/test_div0null_1.sql b/tests/resources/functional/snowflake/functions/math/test_div0null_1.sql new file mode 100644 index 0000000000..5a41058989 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_div0null_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT DIV0NULL(a, b); + +-- databricks sql: +SELECT IF(b = 0 OR b IS NULL, 0, a / b); diff --git a/tests/resources/functional/snowflake/functions/math/test_exp_1.sql b/tests/resources/functional/snowflake/functions/math/test_exp_1.sql new file mode 100644 index 0000000000..c62778ec7c --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_exp_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT exp(col1) AS exp_col1 FROM tabl; + +-- databricks sql: +SELECT EXP(col1) AS exp_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/math/test_floor_1.sql b/tests/resources/functional/snowflake/functions/math/test_floor_1.sql new file mode 100644 index 0000000000..780a24f56b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/math/test_floor_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT floor(col1) AS floor_col1 FROM tabl; + +-- databricks sql: +SELECT FLOOR(col1) AS floor_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_1.sql b/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_1.sql new file mode 100644 index 0000000000..d2302891b2 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT JSON_EXTRACT_PATH_TEXT(json_data, 'level_1_key.level_2_key[1]') FROM demo1; + +-- databricks sql: +SELECT GET_JSON_OBJECT(json_data, '$.level_1_key.level_2_key[1]') FROM demo1; diff --git a/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_2.sql b/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_2.sql new file mode 100644 index 0000000000..046871aef8 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT JSON_EXTRACT_PATH_TEXT(json_data, path_col) FROM demo1; + +-- databricks sql: +SELECT GET_JSON_OBJECT(json_data, CONCAT('$.', path_col)) FROM demo1; diff --git a/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_3.sql b/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_3.sql new file mode 100644 index 0000000000..b343f2e3df --- /dev/null +++ b/tests/resources/functional/snowflake/functions/parse_json/extract_path_text/test_parse_json_extract_path_text_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT JSON_EXTRACT_PATH_TEXT('{}', path_col) FROM demo1; + +-- databricks sql: +SELECT GET_JSON_OBJECT('{}', CONCAT('$.', path_col)) FROM demo1; diff --git a/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_1.sql 
b/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_1.sql new file mode 100644 index 0000000000..911f283397 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT tt.id, PARSE_JSON(tt.details) FROM prod.public.table tt; + +-- databricks sql: +SELECT tt.id, PARSE_JSON(tt.details) FROM prod.public.table AS tt; diff --git a/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_2.sql b/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_2.sql new file mode 100644 index 0000000000..b9a159d6a4 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_2.sql @@ -0,0 +1,8 @@ +-- snowflake sql: +SELECT col1, TRY_PARSE_JSON(col2) FROM tabl; + +-- databricks sql: +SELECT col1, PARSE_JSON(col2) FROM tabl; + + + diff --git a/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_3.sql b/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_3.sql new file mode 100644 index 0000000000..5dd3f25d8b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/parse_json/test_parse_json_3.sql @@ -0,0 +1,33 @@ +-- snowflake sql: +WITH users AS ( + SELECT + 1 AS user_id, + '[{"id":1,"name":"A"},{"id":2,"name":"B"}]' AS json_data + UNION ALL + SELECT + 2 AS user_id, + '[{"id":3,"name":"C"},{"id":4,"name":"D"}]' AS json_data +) +SELECT + user_id, + value AS json_item +FROM + users, + LATERAL FLATTEN(input => PARSE_JSON(json_data)) as value; + +-- databricks sql: +WITH users AS ( + SELECT + 1 AS user_id, + '[{"id":1,"name":"A"},{"id":2,"name":"B"}]' AS json_data + UNION ALL + SELECT + 2 AS user_id, + '[{"id":3,"name":"C"},{"id":4,"name":"D"}]' AS json_data +) +SELECT + user_id, + value AS json_item +FROM + users , + LATERAL VARIANT_EXPLODE(PARSE_JSON(json_data)) AS value diff --git a/tests/resources/functional/snowflake/functions/stats/median_1.sql b/tests/resources/functional/snowflake/functions/stats/median_1.sql new file mode 100644 index 0000000000..18729de956 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/median_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT median(col1) AS median_col1 FROM tabl; + +-- databricks sql: +SELECT MEDIAN(col1) AS median_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/nvl_1.sql b/tests/resources/functional/snowflake/functions/stats/nvl_1.sql new file mode 100644 index 0000000000..7ed6fdf331 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/nvl_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT nvl(col1, col2) AS nvl_col FROM tabl; + +-- databricks sql: +SELECT COALESCE(col1, col2) AS nvl_col FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/regr_intercept_1.sql b/tests/resources/functional/snowflake/functions/stats/regr_intercept_1.sql new file mode 100644 index 0000000000..0b690c992f --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/regr_intercept_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regr_intercept(v, v2) AS regr_intercept_col1 FROM tabl; + +-- databricks sql: +SELECT REGR_INTERCEPT(v, v2) AS regr_intercept_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/regr_r2_1.sql b/tests/resources/functional/snowflake/functions/stats/regr_r2_1.sql new file mode 100644 index 0000000000..d309be5c24 --- /dev/null +++ 
b/tests/resources/functional/snowflake/functions/stats/regr_r2_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regr_r2(v, v2) AS regr_r2_col1 FROM tabl; + +-- databricks sql: +SELECT REGR_R2(v, v2) AS regr_r2_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/regr_slope_1.sql b/tests/resources/functional/snowflake/functions/stats/regr_slope_1.sql new file mode 100644 index 0000000000..e9c6fbb608 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/regr_slope_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regr_slope(v, v2) AS regr_slope_col1 FROM tabl; + +-- databricks sql: +SELECT REGR_SLOPE(v, v2) AS regr_slope_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/stddev_1.sql b/tests/resources/functional/snowflake/functions/stats/stddev_1.sql new file mode 100644 index 0000000000..72d1ee3591 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/stddev_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT stddev(col1) AS stddev_col1 FROM tabl; + +-- databricks sql: +SELECT STDDEV(col1) AS stddev_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/stddev_pop_1.sql b/tests/resources/functional/snowflake/functions/stats/stddev_pop_1.sql new file mode 100644 index 0000000000..7b70860336 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/stddev_pop_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT stddev_pop(col1) AS stddev_pop_col1 FROM tabl; + +-- databricks sql: +SELECT STDDEV_POP(col1) AS stddev_pop_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/stddev_samp_1.sql b/tests/resources/functional/snowflake/functions/stats/stddev_samp_1.sql new file mode 100644 index 0000000000..9fbc1d777b --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/stddev_samp_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT stddev_samp(col1) AS stddev_samp_col1 FROM tabl; + +-- databricks sql: +SELECT STDDEV_SAMP(col1) AS stddev_samp_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/test_approx_percentile_1.sql b/tests/resources/functional/snowflake/functions/stats/test_approx_percentile_1.sql new file mode 100644 index 0000000000..570894af09 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/test_approx_percentile_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT approx_percentile(col1, 0.5) AS approx_percentile_col1 FROM tabl; + +-- databricks sql: +SELECT APPROX_PERCENTILE(col1, 0.5) AS approx_percentile_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/test_approx_top_k_1.sql b/tests/resources/functional/snowflake/functions/stats/test_approx_top_k_1.sql new file mode 100644 index 0000000000..4bca27d97d --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/test_approx_top_k_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT approx_top_k(col1) AS approx_top_k_col1 FROM tabl; + +-- databricks sql: +SELECT APPROX_TOP_K(col1) AS approx_top_k_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/test_avg_1.sql b/tests/resources/functional/snowflake/functions/stats/test_avg_1.sql new file mode 100644 index 0000000000..9eb80584f4 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/test_avg_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT avg(col1) AS avg_col1 FROM tabl; + +-- databricks sql: +SELECT AVG(col1) AS avg_col1 FROM tabl; diff --git 
a/tests/resources/functional/snowflake/functions/stats/test_corr_1.sql b/tests/resources/functional/snowflake/functions/stats/test_corr_1.sql new file mode 100644 index 0000000000..67e7424f3a --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/test_corr_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT CORR(v, v2) AS corr_col1 FROM tabl; + +-- databricks sql: +SELECT CORR(v, v2) AS corr_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/stats/test_cume_dist_1.sql b/tests/resources/functional/snowflake/functions/stats/test_cume_dist_1.sql new file mode 100644 index 0000000000..6eda8d1731 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/stats/test_cume_dist_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT cume_dist() AS cume_dist_col1 FROM tabl; + +-- databricks sql: +SELECT CUME_DIST() AS cume_dist_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/lower_1.sql b/tests/resources/functional/snowflake/functions/strings/lower_1.sql new file mode 100644 index 0000000000..655d75dfe5 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/lower_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT lower(col1) AS lower_col1 FROM tabl; + +-- databricks sql: +SELECT LOWER(col1) AS lower_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/lpad_1.sql b/tests/resources/functional/snowflake/functions/strings/lpad_1.sql new file mode 100644 index 0000000000..2292d4b06c --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/lpad_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT lpad('hi', 5, 'ab'); + +-- databricks sql: +SELECT LPAD('hi', 5, 'ab'); diff --git a/tests/resources/functional/snowflake/functions/strings/ltrim_1.sql b/tests/resources/functional/snowflake/functions/strings/ltrim_1.sql new file mode 100644 index 0000000000..5ce08e090f --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/ltrim_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT ltrim(col1) AS ltrim_col1 FROM tabl; + +-- databricks sql: +SELECT LTRIM(col1) AS ltrim_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/parse_url_1.sql b/tests/resources/functional/snowflake/functions/strings/parse_url_1.sql new file mode 100644 index 0000000000..c448272919 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/parse_url_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT parse_url(col1) AS parse_url_col1 FROM tabl; + +-- databricks sql: +SELECT PARSE_URL(col1) AS parse_url_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/regexp_count_1.sql b/tests/resources/functional/snowflake/functions/strings/regexp_count_1.sql new file mode 100644 index 0000000000..ae447d86a8 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/regexp_count_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_count(col1, patt1) AS regexp_count_col1 FROM tabl; + +-- databricks sql: +SELECT REGEXP_COUNT(col1, patt1) AS regexp_count_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/regexp_instr_1.sql b/tests/resources/functional/snowflake/functions/strings/regexp_instr_1.sql new file mode 100644 index 0000000000..f9e1d645b0 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/regexp_instr_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_instr(col1, '.') AS regexp_instr_col1 FROM tabl; + +-- databricks sql: 
+SELECT REGEXP_INSTR(col1, '.') AS regexp_instr_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/regexp_like_1.sql b/tests/resources/functional/snowflake/functions/strings/regexp_like_1.sql new file mode 100644 index 0000000000..e0cfc09d6a --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/regexp_like_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_like(col1, 'Users.*') AS regexp_like_col1 FROM tabl; + +-- databricks sql: +SELECT col1 RLIKE 'Users.*' AS regexp_like_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/regexp_replace_1.sql b/tests/resources/functional/snowflake/functions/strings/regexp_replace_1.sql new file mode 100644 index 0000000000..cf6b036978 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/regexp_replace_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_replace(col1, '(d+)', '***') AS regexp_replace_col1 FROM tabl; + +-- databricks sql: +SELECT REGEXP_REPLACE(col1, '(d+)', '***') AS regexp_replace_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/regexp_substr_1.sql b/tests/resources/functional/snowflake/functions/strings/regexp_substr_1.sql new file mode 100644 index 0000000000..f58f597001 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/regexp_substr_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_substr(col1, '(E|e)rror') AS regexp_substr_col1 FROM tabl; + +-- databricks sql: +SELECT REGEXP_EXTRACT(col1, '(E|e)rror', 0) AS regexp_substr_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/regexp_substr_2.sql b/tests/resources/functional/snowflake/functions/strings/regexp_substr_2.sql new file mode 100644 index 0000000000..fbbedd2865 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/regexp_substr_2.sql @@ -0,0 +1,10 @@ +-- snowflake sql: +select id, string1, + regexp_substr(string1, 'the\\W+\\w+') as "SUBSTRING" + from demo2; + +-- databricks sql: + +select id, string1, + REGEXP_EXTRACT(string1, 'the\\W+\\w+', 0) as `SUBSTRING` + from demo2; diff --git a/tests/resources/functional/snowflake/functions/strings/repeat_1.sql b/tests/resources/functional/snowflake/functions/strings/repeat_1.sql new file mode 100644 index 0000000000..62af863948 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/repeat_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT repeat(col1, 5) AS repeat_col1 FROM tabl; + +-- databricks sql: +SELECT REPEAT(col1, 5) AS repeat_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/replace_1.sql b/tests/resources/functional/snowflake/functions/strings/replace_1.sql new file mode 100644 index 0000000000..8a0ca54f53 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/replace_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT replace('ABC_abc', 'abc', 'DEF'); + +-- databricks sql: +SELECT REPLACE('ABC_abc', 'abc', 'DEF'); diff --git a/tests/resources/functional/snowflake/functions/strings/reverse_1.sql b/tests/resources/functional/snowflake/functions/strings/reverse_1.sql new file mode 100644 index 0000000000..18c1eed7bb --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/reverse_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT reverse(col1) AS reverse_col1 FROM tabl; + +-- databricks sql: +SELECT REVERSE(col1) AS reverse_col1 FROM tabl; diff --git 
a/tests/resources/functional/snowflake/functions/strings/right_1.sql b/tests/resources/functional/snowflake/functions/strings/right_1.sql new file mode 100644 index 0000000000..6add2ccc0e --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/right_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT right(col1, 5) AS right_col1 FROM tabl; + +-- databricks sql: +SELECT RIGHT(col1, 5) AS right_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/rpad_1.sql b/tests/resources/functional/snowflake/functions/strings/rpad_1.sql new file mode 100644 index 0000000000..e9bdf90ec9 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/rpad_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT rpad('hi', 5, 'ab'); + +-- databricks sql: +SELECT RPAD('hi', 5, 'ab'); diff --git a/tests/resources/functional/snowflake/functions/strings/rtrim_1.sql b/tests/resources/functional/snowflake/functions/strings/rtrim_1.sql new file mode 100644 index 0000000000..3badda83ab --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/rtrim_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT rtrim(col1) AS rtrim_col1 FROM tabl; + +-- databricks sql: +SELECT RTRIM(col1) AS rtrim_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_1.sql b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_1.sql new file mode 100644 index 0000000000..531c41c4e1 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT SPLIT_PART(col1, ',', 0); + +-- databricks sql: +SELECT SPLIT_PART(col1, ',', 1); diff --git a/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_2.sql b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_2.sql new file mode 100644 index 0000000000..fc4277a210 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT SPLIT_PART(NULL, ',', 0); + +-- databricks sql: +SELECT SPLIT_PART(NULL, ',', 1); diff --git a/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_3.sql b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_3.sql new file mode 100644 index 0000000000..ac873315ce --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT SPLIT_PART(col1, ',', 5); + +-- databricks sql: +SELECT SPLIT_PART(col1, ',', 5); diff --git a/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_4.sql b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_4.sql new file mode 100644 index 0000000000..d73b50f71e --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_4.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT SPLIT_PART('lit_string', ',', 1); + +-- databricks sql: +SELECT SPLIT_PART('lit_string', ',', 1); diff --git a/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_5.sql b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_5.sql new file mode 100644 index 0000000000..c7151b27e2 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_5.sql @@ -0,0 +1,6 @@ + +-- 
snowflake sql: +SELECT SPLIT_PART('lit_string', '', 1); + +-- databricks sql: +SELECT SPLIT_PART('lit_string', '', 1); diff --git a/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_6.sql b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_6.sql new file mode 100644 index 0000000000..f626a551e0 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/split_part/test_split_part_6.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT SPLIT_PART(col1, 'delim', len('abc')); + +-- databricks sql: +SELECT SPLIT_PART(col1, 'delim', IF(LENGTH('abc') = 0, 1, LENGTH('abc'))); diff --git a/tests/resources/functional/snowflake/functions/strings/startswith_1.sql b/tests/resources/functional/snowflake/functions/strings/startswith_1.sql new file mode 100644 index 0000000000..39235a3ce1 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/startswith_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT startswith(col1, 'Spark') AS startswith_col1 FROM tabl; + +-- databricks sql: +SELECT STARTSWITH(col1, 'Spark') AS startswith_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_1.sql b/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_1.sql new file mode 100644 index 0000000000..8a93e20515 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select STRTOK('my text is divided'); + +-- databricks sql: +SELECT SPLIT_PART('my text is divided', ' ', 1); diff --git a/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_2.sql b/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_2.sql new file mode 100644 index 0000000000..b65b6637ba --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select STRTOK('a_b_c'), STRTOK(tbl.col123, '.', 3) FROM table tbl; + +-- databricks sql: +SELECT SPLIT_PART('a_b_c', ' ', 1), SPLIT_PART(tbl.col123, '.', 3) FROM table AS tbl; diff --git a/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_3.sql b/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_3.sql new file mode 100644 index 0000000000..e6baed6b1f --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/strtok/test_strtok_3.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +select STRTOK('user@example.com', '@.', 2), + SPLIT_PART(col123, '.', 1) FROM table tbl; + +-- databricks sql: +SELECT SPLIT_PART('user@example.com', '@.', 2), SPLIT_PART(col123, '.', 1) FROM table AS tbl; diff --git a/tests/resources/functional/snowflake/functions/strings/test_base64_decode_1.sql b/tests/resources/functional/snowflake/functions/strings/test_base64_decode_1.sql new file mode 100644 index 0000000000..726b5764b7 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/test_base64_decode_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT BASE64_DECODE_STRING(BASE64_ENCODE('HELLO')), TRY_BASE64_DECODE_STRING(BASE64_ENCODE('HELLO')); + +-- databricks sql: +SELECT UNBASE64(BASE64('HELLO')), UNBASE64(BASE64('HELLO')); diff --git a/tests/resources/functional/snowflake/functions/strings/test_base64_encode_1.sql b/tests/resources/functional/snowflake/functions/strings/test_base64_encode_1.sql new file mode 100644 index 0000000000..ce29cf8a5f --- /dev/null +++ 
b/tests/resources/functional/snowflake/functions/strings/test_base64_encode_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT BASE64_ENCODE('HELLO'), BASE64_ENCODE('HELLO'); + +-- databricks sql: +SELECT BASE64('HELLO'), BASE64('HELLO'); diff --git a/tests/resources/functional/snowflake/functions/strings/test_charindex_1.sql b/tests/resources/functional/snowflake/functions/strings/test_charindex_1.sql new file mode 100644 index 0000000000..0925f3de77 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/test_charindex_1.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +select charindex('an', 'banana', 3), + charindex('ab', 'abababab'), n, h, CHARINDEX(n, h) FROM pos; + +-- databricks sql: +SELECT CHARINDEX('an', 'banana', 3), CHARINDEX('ab', 'abababab'), n, h, CHARINDEX(n, h) FROM pos; diff --git a/tests/resources/functional/snowflake/functions/strings/test_collate_1.sql b/tests/resources/functional/snowflake/functions/strings/test_collate_1.sql new file mode 100644 index 0000000000..3c49ea8196 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/test_collate_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT ENDSWITH(COLLATE('ñn', 'sp'), COLLATE('n', 'sp')); + +-- databricks sql: +SELECT ENDSWITH(COLLATE('ñn', 'sp'), COLLATE('n', 'sp')); diff --git a/tests/resources/functional/snowflake/functions/strings/test_collate_2.sql b/tests/resources/functional/snowflake/functions/strings/test_collate_2.sql new file mode 100644 index 0000000000..c7fa5f7948 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/test_collate_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT v, COLLATION(v), COLLATE(v, 'sp-upper'), COLLATION(COLLATE(v, 'sp-upper')) FROM collation1; + +-- databricks sql: +SELECT v, COLLATION(v), COLLATE(v, 'sp-upper'), COLLATION(COLLATE(v, 'sp-upper')) FROM collation1; diff --git a/tests/resources/functional/snowflake/functions/strings/test_decode_1.sql b/tests/resources/functional/snowflake/functions/strings/test_decode_1.sql new file mode 100644 index 0000000000..ff20d8f136 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/test_decode_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT decode(column1, 1, 'one', 2, 'two', NULL, '-NULL-', 'other') AS decode_result; + +-- databricks sql: +SELECT CASE WHEN column1 = 1 THEN 'one' WHEN column1 = 2 THEN 'two' WHEN column1 IS NULL THEN '-NULL-' ELSE 'other' END AS decode_result; diff --git a/tests/resources/functional/snowflake/functions/strings/test_editdistance.sql b/tests/resources/functional/snowflake/functions/strings/test_editdistance.sql new file mode 100644 index 0000000000..54ad4631d6 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/test_editdistance.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT s, t, EDITDISTANCE(s, t), EDITDISTANCE(t, s), EDITDISTANCE(s, t, 3), EDITDISTANCE(s, t, -1) FROM ed; + +-- databricks sql: +SELECT s, t, LEVENSHTEIN(s, t), LEVENSHTEIN(t, s), LEVENSHTEIN(s, t, 3), LEVENSHTEIN(s, t, -1) FROM ed; diff --git a/tests/resources/functional/snowflake/functions/strings/test_endswith_1.sql b/tests/resources/functional/snowflake/functions/strings/test_endswith_1.sql new file mode 100644 index 0000000000..3c5265a4ab --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/test_endswith_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT endswith('SparkSQL', 'SQL'); + +-- databricks sql: +SELECT ENDSWITH('SparkSQL', 'SQL'); diff --git 
a/tests/resources/functional/snowflake/functions/strings/trim_1.sql b/tests/resources/functional/snowflake/functions/strings/trim_1.sql new file mode 100644 index 0000000000..0412525784 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/trim_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT trim(col1) AS trim_col1 FROM tabl; + +-- databricks sql: +SELECT TRIM(col1) AS trim_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/trunc_1.sql b/tests/resources/functional/snowflake/functions/strings/trunc_1.sql new file mode 100644 index 0000000000..fc2f739d3a --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/trunc_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT trunc(col1, 'YEAR') AS trunc_col1 FROM tabl; + +-- databricks sql: +SELECT TRUNC(col1, 'YEAR') AS trunc_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/strings/upper_1.sql b/tests/resources/functional/snowflake/functions/strings/upper_1.sql new file mode 100644 index 0000000000..814f299bfa --- /dev/null +++ b/tests/resources/functional/snowflake/functions/strings/upper_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT upper(col1) AS upper_col1 FROM tabl; + +-- databricks sql: +SELECT UPPER(col1) AS upper_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/struct_1.sql b/tests/resources/functional/snowflake/functions/struct_1.sql new file mode 100644 index 0000000000..2067e2c13c --- /dev/null +++ b/tests/resources/functional/snowflake/functions/struct_1.sql @@ -0,0 +1,8 @@ + +-- snowflake sql: +SELECT {'a': 1, 'b': 2}, [{'c': 11, 'd': 22}, 3]; + +-- databricks sql: +SELECT STRUCT(1 AS a, 2 AS b), ARRAY(STRUCT(11 AS c, 22 AS d), 3); + + diff --git a/tests/resources/functional/snowflake/functions/sysdate_1.sql b/tests/resources/functional/snowflake/functions/sysdate_1.sql new file mode 100644 index 0000000000..9280942ca0 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/sysdate_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT SYSDATE(), CURRENT_TIMESTAMP(); + +-- databricks sql: +SELECT CURRENT_TIMESTAMP(), CURRENT_TIMESTAMP(); diff --git a/tests/resources/functional/snowflake/functions/tan_1.sql b/tests/resources/functional/snowflake/functions/tan_1.sql new file mode 100644 index 0000000000..f04b82145a --- /dev/null +++ b/tests/resources/functional/snowflake/functions/tan_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT tan(col1) AS tan_col1 FROM tabl; + +-- databricks sql: +SELECT TAN(col1) AS tan_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/test_nvl2/test_nvl2_1.sql b/tests/resources/functional/snowflake/functions/test_nvl2/test_nvl2_1.sql new file mode 100644 index 0000000000..d9cedea260 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_nvl2/test_nvl2_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT nvl2(NULL, 2, 1); + +-- databricks sql: +SELECT NVL2(NULL, 2, 1); diff --git a/tests/resources/functional/snowflake/functions/test_nvl2/test_nvl2_2.sql b/tests/resources/functional/snowflake/functions/test_nvl2/test_nvl2_2.sql new file mode 100644 index 0000000000..9e65e46058 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_nvl2/test_nvl2_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT nvl2(cond, col1, col2) AS nvl2_col1 FROM tabl; + +-- databricks sql: +SELECT NVL2(cond, col1, col2) AS nvl2_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake/functions/test_object_keys/test_object_keys_1.sql 
b/tests/resources/functional/snowflake/functions/test_object_keys/test_object_keys_1.sql new file mode 100644 index 0000000000..b9bc723913 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_object_keys/test_object_keys_1.sql @@ -0,0 +1,10 @@ + +-- snowflake sql: + + SELECT OBJECT_KEYS(object1), OBJECT_KEYS(variant1) + FROM objects_1 + ORDER BY id; + ; + +-- databricks sql: +SELECT JSON_OBJECT_KEYS(object1), JSON_OBJECT_KEYS(variant1) FROM objects_1 ORDER BY id NULLS LAST; diff --git a/tests/resources/functional/snowflake/functions/test_object_keys/test_object_keys_2.sql b/tests/resources/functional/snowflake/functions/test_object_keys/test_object_keys_2.sql new file mode 100644 index 0000000000..94dd52c473 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_object_keys/test_object_keys_2.sql @@ -0,0 +1,11 @@ +-- snowflake sql: +SELECT OBJECT_KEYS (PARSE_JSON (column1)) AS keys +FROM table +ORDER BY 1; + +-- databricks sql: +SELECT + JSON_OBJECT_KEYS(PARSE_JSON(column1)) AS keys +FROM table +ORDER BY + 1 NULLS LAST; diff --git a/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_1.sql b/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_1.sql new file mode 100644 index 0000000000..3558a6a9b7 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_1.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select STRTOK_TO_ARRAY('my text is divided'); + +-- databricks sql: +SELECT SPLIT('my text is divided','[ ]'); diff --git a/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_2.sql b/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_2.sql new file mode 100644 index 0000000000..289aa7c308 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select STRTOK_TO_ARRAY('v_p_n', '_'), STRTOK_TO_ARRAY(col123, '.') FROM table tbl; + +-- databricks sql: +SELECT SPLIT('v_p_n','[_]'), SPLIT(col123,'[.]') FROM table AS tbl; diff --git a/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_3.sql b/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_3.sql new file mode 100644 index 0000000000..5f4a91c04f --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_strtok_to_array/test_strtok_to_array_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +select STRTOK_TO_ARRAY('a@b.c', '.@'); + +-- databricks sql: +SELECT SPLIT('a@b.c','[.@]'); diff --git a/tests/resources/functional/snowflake/functions/test_uuid_string_1.sql b/tests/resources/functional/snowflake/functions/test_uuid_string_1.sql new file mode 100644 index 0000000000..a44aa1f3c5 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/test_uuid_string_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +SELECT UUID_STRING(); + +-- databricks sql: +SELECT UUID(); diff --git a/tests/resources/functional/snowflake/functions/tovarchar_1.sql b/tests/resources/functional/snowflake/functions/tovarchar_1.sql new file mode 100644 index 0000000000..bafef2cd51 --- /dev/null +++ b/tests/resources/functional/snowflake/functions/tovarchar_1.sql @@ -0,0 +1,7 @@ + +-- snowflake sql: +select to_varchar(-12454.8, '99,999.9S'), + '>' || to_char(col1, '00000.00') || '<' FROM dummy; + +-- databricks sql: +SELECT TO_CHAR(-12454.8, '99,999.9S'), '>' || TO_CHAR(col1, 
'00000.00') || '<' FROM dummy;
diff --git a/tests/resources/functional/snowflake/functions/translate_1.sql b/tests/resources/functional/snowflake/functions/translate_1.sql
new file mode 100644
index 0000000000..6ddcbf6929
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/translate_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT translate('AaBbCc', 'abc', '123') AS translate_col1 FROM tabl;
+
+-- databricks sql:
+SELECT TRANSLATE('AaBbCc', 'abc', '123') AS translate_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/functions/typeof_1.sql b/tests/resources/functional/snowflake/functions/typeof_1.sql
new file mode 100644
index 0000000000..73958be10c
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/typeof_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT typeof(col1) AS typeof_col1 FROM tabl;
+
+-- databricks sql:
+SELECT TYPEOF(col1) AS typeof_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/lag_1.sql b/tests/resources/functional/snowflake/functions/window/lag_1.sql
new file mode 100644
index 0000000000..837188f2ad
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/lag_1.sql
@@ -0,0 +1,19 @@
+-- snowflake sql:
+SELECT
+ lag(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2
+ ) AS lag_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ LAG(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 ASC NULLS LAST
+ ) AS lag_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/lag_2.sql b/tests/resources/functional/snowflake/functions/window/lag_2.sql
new file mode 100644
index 0000000000..58c10bfb57
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/lag_2.sql
@@ -0,0 +1,21 @@
+-- snowflake sql:
+SELECT
+ lag(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 DESC RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS lag_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ LAG(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS lag_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/lead_1.sql b/tests/resources/functional/snowflake/functions/window/lead_1.sql
new file mode 100644
index 0000000000..7e1a241dba
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/lead_1.sql
@@ -0,0 +1,19 @@
+-- snowflake sql:
+SELECT
+ lead(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2
+ ) AS lead_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ LEAD(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 ASC NULLS LAST
+ ) AS lead_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/lead_2.sql b/tests/resources/functional/snowflake/functions/window/lead_2.sql
new file mode 100644
index 0000000000..2fc59e391d
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/lead_2.sql
@@ -0,0 +1,21 @@
+-- snowflake sql:
+SELECT
+ lead(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 DESC RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS lead_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ LEAD(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS lead_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/nth_value_1.sql b/tests/resources/functional/snowflake/functions/window/nth_value_1.sql
new file mode 100644
index 0000000000..c74df8e3d9
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/nth_value_1.sql
@@ -0,0 +1,11 @@
+-- snowflake sql:
+SELECT
+ nth_value(col1, 42) AS nth_value_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ NTH_VALUE(col1, 42) AS nth_value_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/nth_value_2.sql b/tests/resources/functional/snowflake/functions/window/nth_value_2.sql
new file mode 100644
index 0000000000..aef11fd550
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/nth_value_2.sql
@@ -0,0 +1,20 @@
+-- snowflake sql:
+SELECT
+ nth_value(col1, 42) over (
+ partition by col1
+ order by
+ col2
+ ) AS nth_value_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ NTH_VALUE(col1, 42) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 ASC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING
+ AND UNBOUNDED FOLLOWING
+ ) AS nth_value_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/nth_value_3.sql b/tests/resources/functional/snowflake/functions/window/nth_value_3.sql
new file mode 100644
index 0000000000..fcebc6d746
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/nth_value_3.sql
@@ -0,0 +1,35 @@
+-- snowflake sql:
+SELECT
+ taba.col_a,
+ taba.col_b,
+ nth_value(
+ CASE
+ WHEN taba.col_c IN ('xyz', 'abc') THEN taba.col_d
+ END,
+ 42
+ ) ignore nulls OVER (
+ partition BY taba.col_e
+ ORDER BY
+ taba.col_f DESC RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS derived_col_a
+FROM
+ schema_a.table_a taba;
+
+-- databricks sql:
+SELECT
+ taba.col_a,
+ taba.col_b,
+ NTH_VALUE(
+ CASE
+ WHEN taba.col_c IN ('xyz', 'abc') THEN taba.col_d
+ END,
+ 42
+ ) IGNORE NULLS OVER (
+ PARTITION BY taba.col_e
+ ORDER BY
+ taba.col_f DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS derived_col_a
+FROM
+ schema_a.table_a AS taba;
diff --git a/tests/resources/functional/snowflake/functions/window/ntile_1.sql b/tests/resources/functional/snowflake/functions/window/ntile_1.sql
new file mode 100644
index 0000000000..2a851b35d2
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/ntile_1.sql
@@ -0,0 +1,20 @@
+-- snowflake sql:
+SELECT
+ ntile(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2
+ ) AS ntile_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ NTILE(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 ASC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING
+ AND UNBOUNDED FOLLOWING
+ ) AS ntile_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/ntile_2.sql b/tests/resources/functional/snowflake/functions/window/ntile_2.sql
new file mode 100644
index 0000000000..a91a36e4f9
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/ntile_2.sql
@@ -0,0 +1,21 @@
+-- snowflake sql:
+SELECT
+ ntile(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 DESC RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS ntile_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ NTILE(col1) OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS ntile_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/percent_rank_1.sql b/tests/resources/functional/snowflake/functions/window/percent_rank_1.sql
new file mode 100644
index 0000000000..1cd6ef7a94
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/percent_rank_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT percent_rank() AS percent_rank_col1 FROM tabl;
+
+-- databricks sql:
+SELECT PERCENT_RANK() AS percent_rank_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/percentile_cont_1.sql b/tests/resources/functional/snowflake/functions/window/percentile_cont_1.sql
new file mode 100644
index 0000000000..ed893223fa
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/percentile_cont_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT percentile_cont(col1) AS percentile_cont_col1 FROM tabl;
+
+-- databricks sql:
+SELECT PERCENTILE_CONT(col1) AS percentile_cont_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/percentile_disc_1.sql b/tests/resources/functional/snowflake/functions/window/percentile_disc_1.sql
new file mode 100644
index 0000000000..c8e5f298b1
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/percentile_disc_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT percentile_disc(col1) AS percentile_disc_col1 FROM tabl;
+
+-- databricks sql:
+SELECT PERCENTILE_DISC(col1) AS percentile_disc_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/position_1.sql b/tests/resources/functional/snowflake/functions/window/position_1.sql
new file mode 100644
index 0000000000..71b36e305b
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/position_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT position('exc', col1) AS position_col1 FROM tabl;
+
+-- databricks sql:
+SELECT LOCATE('exc', col1) AS position_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/rank_1.sql b/tests/resources/functional/snowflake/functions/window/rank_1.sql
new file mode 100644
index 0000000000..58bfc85e7d
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/rank_1.sql
@@ -0,0 +1,19 @@
+-- snowflake sql:
+SELECT
+ rank() over (
+ partition by col1
+ order by
+ col2
+ ) AS rank_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ RANK() OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 ASC NULLS LAST
+ ) AS rank_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/rank_2.sql b/tests/resources/functional/snowflake/functions/window/rank_2.sql
new file mode 100644
index 0000000000..8be920dbd5
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/rank_2.sql
@@ -0,0 +1,21 @@
+-- snowflake sql:
+SELECT
+ rank() over (
+ partition by col1
+ order by
+ col2 DESC RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS rank_col1
+FROM
+ tabl;
+
+-- databricks sql:
+SELECT
+ RANK() OVER (
+ PARTITION BY col1
+ ORDER BY
+ col2 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS rank_col1
+FROM
+ tabl;
diff --git a/tests/resources/functional/snowflake/functions/window/row_number_1.sql b/tests/resources/functional/snowflake/functions/window/row_number_1.sql
new file mode 100644
index 0000000000..e55228eeb2
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/row_number_1.sql
@@ -0,0 +1,25 @@
+-- snowflake sql:
+SELECT
+ symbol,
+ exchange,
+ shares,
+ ROW_NUMBER() OVER (
+ PARTITION BY exchange
+ ORDER BY
+ shares
+ ) AS row_number
+FROM
+ trades;
+
+-- databricks sql:
+SELECT
+ symbol,
+ exchange,
+ shares,
+ ROW_NUMBER() OVER (
+ PARTITION BY exchange
+ ORDER BY
+ shares ASC NULLS LAST
+ ) AS row_number
+FROM
+ trades;
diff --git a/tests/resources/functional/snowflake/functions/window/row_number_2.sql b/tests/resources/functional/snowflake/functions/window/row_number_2.sql
new file mode 100644
index 0000000000..fe374f6369
--- /dev/null
+++ b/tests/resources/functional/snowflake/functions/window/row_number_2.sql
@@ -0,0 +1,27 @@
+-- snowflake sql:
+SELECT
+ symbol,
+ exchange,
+ shares,
+ ROW_NUMBER() OVER (
+ PARTITION BY exchange
+ ORDER BY
+ shares DESC RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS row_number
+FROM
+ trades;
+
+-- databricks sql:
+SELECT
+ symbol,
+ exchange,
+ shares,
+ ROW_NUMBER() OVER (
+ PARTITION BY exchange
+ ORDER BY
+ shares DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING
+ AND CURRENT ROW
+ ) AS row_number
+FROM
+ trades;
diff --git a/tests/resources/functional/snowflake/joins/test_join_1.sql b/tests/resources/functional/snowflake/joins/test_join_1.sql
new file mode 100644
index 0000000000..278e48b0e7
--- /dev/null
+++ b/tests/resources/functional/snowflake/joins/test_join_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT t1.c1, t2.c2 FROM t1 JOIN t2 USING (c3);
+
+-- databricks sql:
+SELECT t1.c1, t2.c2 FROM t1 JOIN t2 USING (c3);
diff --git a/tests/resources/functional/snowflake/joins/test_join_2.sql b/tests/resources/functional/snowflake/joins/test_join_2.sql
new file mode 100644
index 0000000000..fe068bcf75
--- /dev/null
+++ b/tests/resources/functional/snowflake/joins/test_join_2.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT * FROM table1, table2 WHERE table1.column_name = table2.column_name;
+
+-- databricks sql:
+SELECT * FROM table1, table2 WHERE table1.column_name = table2.column_name;
diff --git a/tests/resources/functional/snowflake/lca/lca_cte.sql b/tests/resources/functional/snowflake/lca/lca_cte.sql
new file mode 100644
index 0000000000..3ee9c89a81
--- /dev/null
+++ b/tests/resources/functional/snowflake/lca/lca_cte.sql
@@ -0,0 +1,11 @@
+-- snowflake sql:
+WITH cte AS (SELECT column_a as customer_id
+ FROM my_table
+ WHERE customer_id = '123')
+SELECT * FROM cte;
+
+-- databricks sql:
+WITH cte AS (SELECT column_a as customer_id
+ FROM my_table
+ WHERE column_a = '123')
+SELECT * FROM cte;
diff --git a/tests/resources/functional/snowflake/lca/lca_derived.sql b/tests/resources/functional/snowflake/lca/lca_derived.sql
new file mode 100644
index 0000000000..cf41b3072b
--- /dev/null
+++ b/tests/resources/functional/snowflake/lca/lca_derived.sql
@@ -0,0 +1,8 @@
+-- snowflake sql:
+SELECT column_a as cid
+FROM (select column_x as column_a, column_y as y from my_table where y = '456')
+WHERE cid = '123';
+-- databricks sql:
+SELECT column_a as cid
+FROM (select column_x as column_a, column_y as y from my_table where column_y = '456')
+WHERE column_a = '123';
diff --git a/tests/resources/functional/snowflake/lca/lca_in.sql b/tests/resources/functional/snowflake/lca/lca_in.sql
new file mode 100644
index 0000000000..b1ff47ba9d
--- /dev/null
+++ b/tests/resources/functional/snowflake/lca/lca_in.sql
@@ -0,0 +1,5 @@
+-- snowflake sql:
+SELECT t.col1, t.col2, t.col3 AS ca FROM table1 t WHERE ca in ('v1', 'v2');
+
+-- databricks sql:
+SELECT t.col1, t.col2, t.col3 AS ca FROM table1 AS t WHERE t.col3 in ('v1', 'v2');
diff --git a/tests/resources/functional/snowflake/lca/lca_nested.sql b/tests/resources/functional/snowflake/lca/lca_nested.sql
new file mode 100644
index 0000000000..d3b2c2e5a3
--- /dev/null
+++ b/tests/resources/functional/snowflake/lca/lca_nested.sql
@@ -0,0 +1,13 @@
+-- snowflake sql:
+SELECT
+ b * c as new_b,
+ a - new_b as ab_diff
+FROM my_table
+WHERE ab_diff >= 0;
+
+-- databricks sql:
+SELECT
+ b * c as new_b,
+ a - new_b as ab_diff
+FROM my_table
+WHERE a - b * c >= 0;
diff --git a/tests/resources/functional/snowflake/lca/lca_subquery.sql b/tests/resources/functional/snowflake/lca/lca_subquery.sql
new file mode 100644
index 0000000000..8fc77adcdc
--- /dev/null
+++ b/tests/resources/functional/snowflake/lca/lca_subquery.sql
@@ -0,0 +1,8 @@
+-- snowflake sql:
+SELECT column_a as cid
+FROM my_table
+WHERE cid in (select cid as customer_id from customer_table where customer_id = '123');
+-- databricks sql:
+SELECT column_a as cid
+FROM my_table
+WHERE column_a in (select cid as customer_id from customer_table where cid = '123');
diff --git a/tests/resources/functional/snowflake/lca/lca_where.sql b/tests/resources/functional/snowflake/lca/lca_where.sql
new file mode 100644
index 0000000000..920915f1df
--- /dev/null
+++ b/tests/resources/functional/snowflake/lca/lca_where.sql
@@ -0,0 +1,5 @@
+-- snowflake sql:
+SELECT column_a as alias_a FROM table_a where alias_a = '123';
+
+-- databricks sql:
+SELECT column_a as alias_a FROM table_a where column_a = '123';
diff --git a/tests/resources/functional/snowflake/lca/lca_window.sql b/tests/resources/functional/snowflake/lca/lca_window.sql
new file mode 100644
index 0000000000..66d804732d
--- /dev/null
+++ b/tests/resources/functional/snowflake/lca/lca_window.sql
@@ -0,0 +1,15 @@
+-- snowflake sql:
+SELECT
+ t.col1,
+ t.col2,
+ t.col3 AS ca,
+ ROW_NUMBER() OVER (PARTITION by ca ORDER BY t.col2 DESC) AS rn
+ FROM table1 t;
+
+-- databricks sql:
+SELECT
+ t.col1,
+ t.col2,
+ t.col3 AS ca,
+ ROW_NUMBER() OVER (PARTITION by t.col3 ORDER BY t.col2 DESC NULLS FIRST) AS rn
+ FROM table1 AS t;
diff --git a/tests/resources/functional/snowflake/misc/test_any_value_1.sql b/tests/resources/functional/snowflake/misc/test_any_value_1.sql
new file mode 100644
index 0000000000..98648f2766
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_any_value_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT customer.id , ANY_VALUE(customer.name) , SUM(orders.value) FROM customer JOIN orders ON customer.id = orders.customer_id GROUP BY customer.id;
+
+-- databricks sql:
+SELECT customer.id, ANY_VALUE(customer.name), SUM(orders.value) FROM customer JOIN orders ON customer.id = orders.customer_id GROUP BY customer.id;
diff --git a/tests/resources/functional/snowflake/misc/test_coalesce_1.sql b/tests/resources/functional/snowflake/misc/test_coalesce_1.sql
new file mode 100644
index 0000000000..421dd6258a
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_coalesce_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT coalesce(col1, col2) AS coalesce_col FROM tabl;
+
+-- databricks sql:
+SELECT COALESCE(col1, col2) AS coalesce_col FROM tabl;
diff --git a/tests/resources/functional/snowflake/misc/test_contains_1.sql b/tests/resources/functional/snowflake/misc/test_contains_1.sql
new file mode 100644
index 0000000000..0bc56eb55e
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_contains_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT contains('SparkSQL', 'Spark');
+
+-- databricks sql:
+SELECT CONTAINS('SparkSQL', 'Spark');
diff --git a/tests/resources/functional/snowflake/misc/test_equal_null_1.sql b/tests/resources/functional/snowflake/misc/test_equal_null_1.sql
new file mode 100644
index 0000000000..a56ebe6143
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_equal_null_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT equal_null(2, 2) AS equal_null_col1 FROM tabl;
+
+-- databricks sql:
+SELECT EQUAL_NULL(2, 2) AS equal_null_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/misc/test_hash_1.sql b/tests/resources/functional/snowflake/misc/test_hash_1.sql
new file mode 100644
index 0000000000..3b2daa784f
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_hash_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT hash(col1) AS hash_col1 FROM tabl;
+
+-- databricks sql:
+SELECT HASH(col1) AS hash_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/misc/test_iff_1.sql b/tests/resources/functional/snowflake/misc/test_iff_1.sql
new file mode 100644
index 0000000000..40992d519e
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_iff_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT iff(cond, col1, col2) AS iff_col1 FROM tabl;
+
+-- databricks sql:
+SELECT IF(cond, col1, col2) AS iff_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/misc/test_ifnull_1.sql b/tests/resources/functional/snowflake/misc/test_ifnull_1.sql
new file mode 100644
index 0000000000..e02a064fd2
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_ifnull_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT ifnull(col1, 'NA') AS ifnull_col1 FROM tabl;
+
+-- databricks sql:
+SELECT COALESCE(col1, 'NA') AS ifnull_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/misc/test_ifnull_2.sql b/tests/resources/functional/snowflake/misc/test_ifnull_2.sql
new file mode 100644
index 0000000000..ab998fedf2
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_ifnull_2.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT ifnull(col1) AS ifnull_col1 FROM tabl;
+
+-- databricks sql:
+SELECT COALESCE(col1) AS ifnull_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/misc/test_sha2_1.sql b/tests/resources/functional/snowflake/misc/test_sha2_1.sql
new file mode 100644
index 0000000000..e3b1a9a7ee
--- /dev/null
+++ b/tests/resources/functional/snowflake/misc/test_sha2_1.sql
@@ -0,0 +1,9 @@
+-- snowflake sql:
+select sha2(test_col), sha2(test_col, 256), sha2(test_col, 224) from test_tbl;
+
+-- databricks sql:
+SELECT
+ SHA2(test_col, 256),
+ SHA2(test_col, 256),
+ SHA2(test_col, 224)
+FROM test_tbl;
diff --git a/tests/resources/functional/snowflake/nested_query_with_json_1.sql b/tests/resources/functional/snowflake/nested_query_with_json_1.sql
new file mode 100644
index 0000000000..e83ff2bcde
--- /dev/null
+++ b/tests/resources/functional/snowflake/nested_query_with_json_1.sql
@@ -0,0 +1,32 @@
+-- snowflake sql:
+SELECT A.COL1, A.COL2, B.COL3, B.COL4 FROM
+ (SELECT COL1, COL2 FROM TABLE1) A,
+ (SELECT VALUE:PRICE::FLOAT AS COL3, COL4 FROM
+ (SELECT * FROM TABLE2 ) AS K
+ ) B
+ WHERE A.COL1 = B.COL4;
+
+-- databricks sql:
+SELECT
+ A.COL1,
+ A.COL2,
+ B.COL3,
+ B.COL4
+FROM (
+ SELECT
+ COL1,
+ COL2
+ FROM TABLE1
+) AS A, (
+ SELECT
+ CAST(VALUE:PRICE AS DOUBLE) AS COL3,
+ COL4
+ FROM (
+ SELECT
+ *
+ FROM TABLE2
+ ) AS K
+) AS B
+WHERE
+ A.COL1 = B.COL4;
+
diff --git a/tests/resources/functional/snowflake/nulls/test_nullif_1.sql b/tests/resources/functional/snowflake/nulls/test_nullif_1.sql
new file mode 100644
index 0000000000..e5c18de906
--- /dev/null
+++ b/tests/resources/functional/snowflake/nulls/test_nullif_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT nullif(col1, col2) AS nullif_col1 FROM tabl;
+
+-- databricks sql:
+SELECT NULLIF(col1, col2) AS nullif_col1 FROM tabl;
diff --git a/tests/resources/functional/snowflake/nulls/test_nullifzero_1.sql b/tests/resources/functional/snowflake/nulls/test_nullifzero_1.sql
new file mode 100644
index 0000000000..593fc2a5c0
--- /dev/null
+++ b/tests/resources/functional/snowflake/nulls/test_nullifzero_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT t.n, nullifzero(t.n) AS pcol1 FROM tbl t;
+
+-- databricks sql:
+SELECT t.n, IF(t.n = 0, NULL, t.n) AS pcol1 FROM tbl AS t;
diff --git a/tests/resources/functional/snowflake/nulls/test_nullsafe_eq_1.sql b/tests/resources/functional/snowflake/nulls/test_nullsafe_eq_1.sql
new file mode 100644
index 0000000000..92ca93b2e8
--- /dev/null
+++ b/tests/resources/functional/snowflake/nulls/test_nullsafe_eq_1.sql
@@ -0,0 +1,14 @@
+-- snowflake sql:
+SELECT A.COL1, B.COL2 FROM TABL1 A JOIN TABL2 B ON (A.COL1 = B.COL1 OR (A.COL1 IS NULL AND B.COL1 IS NULL));
+
+-- databricks sql:
+SELECT
+ A.COL1,
+ B.COL2
+FROM TABL1 AS A
+JOIN TABL2 AS B
+ ON (
+ A.COL1 = B.COL1 OR (
+ A.COL1 IS NULL AND B.COL1 IS NULL
+ )
+ );
diff --git a/tests/resources/functional/snowflake/operators/test_bitor_agg_1.sql b/tests/resources/functional/snowflake/operators/test_bitor_agg_1.sql
new file mode 100644
index 0000000000..3159861509
--- /dev/null
+++ b/tests/resources/functional/snowflake/operators/test_bitor_agg_1.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+select bitor_agg(k) from bitwise_example;
+
+-- databricks sql:
+SELECT BIT_OR(k) FROM bitwise_example;
diff --git a/tests/resources/functional/snowflake/operators/test_bitor_agg_2.sql b/tests/resources/functional/snowflake/operators/test_bitor_agg_2.sql
new file mode 100644
index 0000000000..f446228d9b
--- /dev/null
+++ b/tests/resources/functional/snowflake/operators/test_bitor_agg_2.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+select s2, bitor_agg(k) from bitwise_example group by s2;
+
+-- databricks sql:
+SELECT s2, BIT_OR(k) FROM bitwise_example GROUP BY s2;
diff --git a/tests/resources/functional/snowflake/sqlglot-incorrect/test_current_date_1.sql b/tests/resources/functional/snowflake/sqlglot-incorrect/test_current_date_1.sql
new file mode 100644
index 0000000000..b2e046a02f
--- /dev/null
+++ b/tests/resources/functional/snowflake/sqlglot-incorrect/test_current_date_1.sql
@@ -0,0 +1,5 @@
+-- snowflake sql:
+SELECT CURRENT_DATE() FROM tabl;
+
+-- databricks sql:
+SELECT CURRENT_DATE() FROM tabl;
diff --git a/tests/resources/functional/snowflake/sqlglot-incorrect/test_uuid_string_2.sql b/tests/resources/functional/snowflake/sqlglot-incorrect/test_uuid_string_2.sql
new file mode 100644
index 0000000000..b062e5b50c
--- /dev/null
+++ b/tests/resources/functional/snowflake/sqlglot-incorrect/test_uuid_string_2.sql
@@ -0,0 +1,6 @@
+
+-- snowflake sql:
+SELECT UUID_STRING('fe971b24-9572-4005-b22f-351e9c09274d','foo');
+
+-- databricks sql:
+SELECT UUID('fe971b24-9572-4005-b22f-351e9c09274d', 'foo');
diff --git a/tests/resources/functional/snowflake/tablesample/test_tablesample_1.sql b/tests/resources/functional/snowflake/tablesample/test_tablesample_1.sql
new file mode 100644
index 0000000000..de90bb1805
--- /dev/null
+++ b/tests/resources/functional/snowflake/tablesample/test_tablesample_1.sql
@@ -0,0 +1,7 @@
+-- snowflake sql:
+select * from (select * from example_table) sample (1) seed (99);
+
+-- databricks sql:
+SELECT * FROM (
+ SELECT * FROM example_table
+) TABLESAMPLE (1 PERCENT) REPEATABLE (99);
diff --git a/tests/resources/functional/snowflake/tablesample/test_tablesample_2.sql b/tests/resources/functional/snowflake/tablesample/test_tablesample_2.sql
new file mode 100644
index 0000000000..0a683be7ac
--- /dev/null
+++ b/tests/resources/functional/snowflake/tablesample/test_tablesample_2.sql
@@ -0,0 +1,7 @@
+-- snowflake sql:
+select * from 
(select * from example_table) tablesample (1) seed (99); + +-- databricks sql: +SELECT * FROM ( + SELECT * FROM example_table +) TABLESAMPLE (1 PERCENT) REPEATABLE (99); diff --git a/tests/resources/functional/snowflake/tablesample/test_tablesample_3.sql b/tests/resources/functional/snowflake/tablesample/test_tablesample_3.sql new file mode 100644 index 0000000000..ba3f2e085d --- /dev/null +++ b/tests/resources/functional/snowflake/tablesample/test_tablesample_3.sql @@ -0,0 +1,14 @@ +-- snowflake sql: + + select * + from ( + select * + from t1 join t2 + on t1.a = t2.c + ) sample (1); + ; + +-- databricks sql: +SELECT * FROM ( + SELECT * FROM t1 JOIN t2 ON t1.a = t2.c +) TABLESAMPLE (1 PERCENT); diff --git a/tests/resources/functional/snowflake/test_command/test_command_2.sql b/tests/resources/functional/snowflake/test_command/test_command_2.sql new file mode 100644 index 0000000000..9379f18c84 --- /dev/null +++ b/tests/resources/functional/snowflake/test_command/test_command_2.sql @@ -0,0 +1,6 @@ +-- snowflake sql: +SELECT !(2 = 2) AS always_false + + +-- databricks sql: +SELECT !(2 = 2) AS always_false diff --git a/tests/resources/functional/snowflake/test_command/test_command_3.sql b/tests/resources/functional/snowflake/test_command/test_command_3.sql new file mode 100644 index 0000000000..100e361372 --- /dev/null +++ b/tests/resources/functional/snowflake/test_command/test_command_3.sql @@ -0,0 +1,9 @@ + +-- snowflake sql: +!set exit_on_error = true; +SELECT !(2 = 2) AS always_false; + + +-- databricks sql: +-- !set exit_on_error = true; +SELECT !(2 = 2) AS always_false diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_10.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_10.sql new file mode 100644 index 0000000000..0ab08e25ec --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_10.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +EXECUTE TASK mytask; + +-- databricks sql: +-- EXECUTE TASK mytask; diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_11.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_11.sql new file mode 100644 index 0000000000..f4d78086af --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_11.sql @@ -0,0 +1,48 @@ +-- snowflake sql: +SELECT DISTINCT + dst.CREATED_DATE + , dst.task_id + , dd.delivery_id::TEXT AS delivery_id + , dst.ISSUE_CATEGORY + , dst.issue + , dd.store_id + FROM proddb.public.dimension_salesforce_tasks dst + JOIN edw.finance.dimension_deliveries dd + ON dd.delivery_id = dst.delivery_id + WHERE dst.mto_flag = 1 + AND dst.customer_type IN ('Consumer') + AND dd.STORE_ID IN (SELECT store_id FROM foo.bar.cng_stores_stage) + AND dd.is_test = false + AND dst.origin IN ('Chat') + AND dd_agent_id IS NOT NULL + AND dst.CREATED_DATE > CURRENT_DATE - 7 + ORDER BY 1 DESC, 3 DESC; + ; + +-- databricks sql: +SELECT DISTINCT + dst.CREATED_DATE, + dst.task_id, + CAST(dd.delivery_id AS STRING) AS delivery_id, + dst.ISSUE_CATEGORY, + dst.issue, + dd.store_id + FROM proddb.public.dimension_salesforce_tasks AS dst + JOIN edw.finance.dimension_deliveries AS dd + ON dd.delivery_id = dst.delivery_id + WHERE + dst.mto_flag = 1 + AND dst.customer_type IN ('Consumer') + AND dd.STORE_ID IN ( + SELECT + 
store_id + FROM foo.bar.cng_stores_stage + ) + AND dd.is_test = false + AND dst.origin IN ('Chat') + AND dd_agent_id IS NOT NULL + AND dst.CREATED_DATE > CURRENT_DATE() - 7 + ORDER BY + 1 DESC NULLS FIRST, + 3 DESC NULLS FIRST; + diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_2.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_2.sql new file mode 100644 index 0000000000..7660368601 --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_2.sql @@ -0,0 +1,5 @@ + +-- snowflake sql: +BEGIN; + +-- databricks sql: diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_3.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_3.sql new file mode 100644 index 0000000000..f0761574db --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_3.sql @@ -0,0 +1,5 @@ + +-- snowflake sql: +ROLLBACK; + +-- databricks sql: diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_4.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_4.sql new file mode 100644 index 0000000000..27972c1a8e --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_4.sql @@ -0,0 +1,5 @@ + +-- snowflake sql: +COMMIT; + +-- databricks sql: diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_7.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_7.sql new file mode 100644 index 0000000000..9e7984430f --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_7.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SHOW STREAMS LIKE 'line%' IN tpch.public; + +-- databricks sql: +-- SHOW STREAMS LIKE 'line%' IN tpch.public; diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_8.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_8.sql new file mode 100644 index 0000000000..dfa2a7dcfd --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_8.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +ALTER TABLE tab1 ADD COLUMN c2 NUMBER; + +-- databricks sql: +ALTER TABLE tab1 ADD COLUMN c2 DECIMAL(38, 0); diff --git a/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_9.sql b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_9.sql new file mode 100644 index 0000000000..a7d87855f7 --- /dev/null +++ b/tests/resources/functional/snowflake/test_skip_unsupported_operations/test_skip_unsupported_operations_9.sql @@ -0,0 +1,13 @@ + +-- snowflake sql: + + CREATE TASK t1 + SCHEDULE = '30 MINUTE' + TIMESTAMP_INPUT_FORMAT = 'YYYY-MM-DD HH24' + USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = 'XSMALL' + AS + INSERT INTO mytable(ts) VALUES(CURRENT_TIMESTAMP); + ; + +-- databricks sql: +-- CREATE TASK t1 SCHEDULE = '30 MINUTE' TIMESTAMP_INPUT_FORMAT = 'YYYY-MM-DD HH24' 
USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = 'XSMALL' AS INSERT INTO mytable(ts) VALUES(CURRENT_TIMESTAMP); diff --git a/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_1.sql b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_1.sql new file mode 100644 index 0000000000..3b42445228 --- /dev/null +++ b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_1.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +select emp_id from abc.emp where emp_id = &ids; + +-- databricks sql: +select emp_id from abc.emp where emp_id = $ids; diff --git a/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_2.sql b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_2.sql new file mode 100644 index 0000000000..ec07bee841 --- /dev/null +++ b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_2.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +select emp_id from abc.emp where emp_id = $ids; + +-- databricks sql: +select emp_id from abc.emp where emp_id = $ids; diff --git a/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_3.sql b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_3.sql new file mode 100644 index 0000000000..ed8e90455d --- /dev/null +++ b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_3.sql @@ -0,0 +1,5 @@ +-- snowflake sql: +select count(*) from &TEST_USER.EMP_TBL; + +-- databricks sql: +select count(*) from $TEST_USER.EMP_TBL; diff --git a/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_4.sql b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_4.sql new file mode 100644 index 0000000000..a8a71b2bfb --- /dev/null +++ b/tests/resources/functional/snowflake/test_tokens_parameter/test_tokens_parameter_4.sql @@ -0,0 +1,11 @@ +-- Note that we cannot enable the WHERE clause here as the Python +-- tests will fail because SqlGlot cannot translate variable references +-- within strings. 
Enable the WHERE clause when we are not limited by Python + +-- WHERE EMP_ID = '&empNo' => WHERE EMP_ID = '${empNo}' + +-- snowflake sql: +select count(*) from &TEST_USER.EMP_TBL; + +-- databricks sql: +select count(*) from $TEST_USER.EMP_TBL; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_approx_percentile_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_approx_percentile_2.sql new file mode 100644 index 0000000000..0f32aae892 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_approx_percentile_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT approx_percentile(col1) AS approx_percentile_col1 FROM tabl; + +-- databricks sql: +SELECT APPROX_PERCENTILE(col1) AS approx_percentile_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_array_contains_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_array_contains_2.sql new file mode 100644 index 0000000000..70b62d8645 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_array_contains_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_contains(arr_col) AS array_contains_col1 FROM tabl; + +-- databricks sql: +SELECT ARRAY_CONTAINS(arr_col) AS array_contains_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_array_slice_3.sql b/tests/resources/functional/snowflake_expected_exceptions/test_array_slice_3.sql new file mode 100644 index 0000000000..00963f129f --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_array_slice_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT array_slice(array_construct(90,91,92,93,94,95,96), -4, -1); + +-- databricks sql: +SELECT SLICE(ARRAY(90, 91, 92, 93, 94, 95, 96), -4, -1); diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_date_part_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_date_part_2.sql new file mode 100644 index 0000000000..4ce19bcec3 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_date_part_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT date_part(col1) AS date_part_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_PART(col1) AS date_part_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_date_trunc_5.sql b/tests/resources/functional/snowflake_expected_exceptions/test_date_trunc_5.sql new file mode 100644 index 0000000000..0b0cb9e5bb --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_date_trunc_5.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT date_trunc(col1) AS date_trunc_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_TRUNC(col1) AS date_trunc_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_dayname_4.sql b/tests/resources/functional/snowflake_expected_exceptions/test_dayname_4.sql new file mode 100644 index 0000000000..8610daf168 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_dayname_4.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT DAYNAME('2015-04-03 10:00', 'EEE') AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT('2015-04-03 10:00', 'E') AS MONTH; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_extract_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_extract_2.sql new file mode 100644 index 0000000000..e98ac7a9b7 --- /dev/null +++ 
b/tests/resources/functional/snowflake_expected_exceptions/test_extract_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT extract(col1) AS extract_col1 FROM tabl; + +-- databricks sql: +SELECT EXTRACT(col1) AS extract_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_iff_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_iff_2.sql new file mode 100644 index 0000000000..7263354c2d --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_iff_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT iff(col1) AS iff_col1 FROM tabl; + +-- databricks sql: +SELECT IFF(col1) AS iff_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_left_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_left_2.sql new file mode 100644 index 0000000000..8c9f3a90fd --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_left_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT left(col1) AS left_col1 FROM tabl; + +-- databricks sql: +SELECT LEFT(col1) AS left_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_monthname_8.sql b/tests/resources/functional/snowflake_expected_exceptions/test_monthname_8.sql new file mode 100644 index 0000000000..8a2b45949d --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_monthname_8.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT MONTHNAME('2015-04-03 10:00', 'MMM') AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT('2015-04-03 10:00', 'MMM') AS MONTH; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_monthname_9.sql b/tests/resources/functional/snowflake_expected_exceptions/test_monthname_9.sql new file mode 100644 index 0000000000..2572d2c7f9 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_monthname_9.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT DAYNAME('2015-04-03 10:00', 'MMM') AS MONTH; + +-- databricks sql: +SELECT DATE_FORMAT('2015-04-03 10:00', 'MMM') AS MONTH; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_nullif_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_nullif_2.sql new file mode 100644 index 0000000000..154b69f247 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_nullif_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT nullif(col1) AS nullif_col1 FROM tabl; + +-- databricks sql: +SELECT NULLIF(col1) AS nullif_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_nvl2_3.sql b/tests/resources/functional/snowflake_expected_exceptions/test_nvl2_3.sql new file mode 100644 index 0000000000..b91edf80a4 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_nvl2_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT nvl2(col1) AS nvl2_col1 FROM tabl; + +-- databricks sql: +SELECT NVL2(col1) AS nvl2_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_parse_json_extract_path_text_4.sql b/tests/resources/functional/snowflake_expected_exceptions/test_parse_json_extract_path_text_4.sql new file mode 100644 index 0000000000..80499218ea --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_parse_json_extract_path_text_4.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT JSON_EXTRACT_PATH_TEXT('{}') FROM demo1; + +-- databricks sql: +SELECT GET_JSON_OBJECT('{}', CONCAT('$.', path_col)) 
FROM demo1; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_position_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_position_2.sql new file mode 100644 index 0000000000..9fb1b9f0cd --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_position_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT position(col1) AS position_col1 FROM tabl; + +-- databricks sql: +SELECT POSITION(col1) AS position_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_regexp_like_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_regexp_like_2.sql new file mode 100644 index 0000000000..d8b903bde8 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_regexp_like_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_like(col1) AS regexp_like_col1 FROM tabl; + +-- databricks sql: +SELECT REGEXP_LIKE(col1) AS regexp_like_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_regexp_replace_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_regexp_replace_2.sql new file mode 100644 index 0000000000..26565cce06 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_regexp_replace_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_replace(col1) AS regexp_replace_col1 FROM tabl; + +-- databricks sql: +SELECT REGEXP_REPLACE(col1) AS regexp_replace_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_regexp_substr_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_regexp_substr_2.sql new file mode 100644 index 0000000000..fa2fc7266e --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_regexp_substr_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT regexp_substr(col1) AS regexp_substr_col1 FROM tabl; + +-- databricks sql: +SELECT REGEXP_SUBSTR(col1) AS regexp_substr_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_repeat_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_repeat_2.sql new file mode 100644 index 0000000000..45ac0f0937 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_repeat_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT repeat(col1) AS repeat_col1 FROM tabl; + +-- databricks sql: +SELECT REPEAT(col1) AS repeat_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_right_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_right_2.sql new file mode 100644 index 0000000000..7a1e7346d2 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_right_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT right(col1) AS left_col1 FROM tabl; + +-- databricks sql: +SELECT RIGHT(col1) AS left_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_split_part_7.sql b/tests/resources/functional/snowflake_expected_exceptions/test_split_part_7.sql new file mode 100644 index 0000000000..84454d3765 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_split_part_7.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT SPLIT_PART('lit_string', ','); + +-- databricks sql: +SELECT SPLIT_PART('lit_string', ',', 5); diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_split_part_8.sql 
b/tests/resources/functional/snowflake_expected_exceptions/test_split_part_8.sql new file mode 100644 index 0000000000..f16f54dde3 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_split_part_8.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT split_part(col1) AS split_part_col1 FROM tabl; + +-- databricks sql: +SELECT SPLIT_PART(col1) AS split_part_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_startswith_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_startswith_2.sql new file mode 100644 index 0000000000..649413e621 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_startswith_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT startswith(col1) AS startswith_col1 FROM tabl; + +-- databricks sql: +SELECT STARTSWITH(col1) AS startswith_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_timestampadd_6.sql b/tests/resources/functional/snowflake_expected_exceptions/test_timestampadd_6.sql new file mode 100644 index 0000000000..d7726f1040 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_timestampadd_6.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT timestampadd(col1) AS timestampadd_col1 FROM tabl; + +-- databricks sql: +SELECT DATEADD(col1) AS timestampadd_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_to_number_10.sql b/tests/resources/functional/snowflake_expected_exceptions/test_to_number_10.sql new file mode 100644 index 0000000000..a66887a6c7 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_to_number_10.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT TO_NUMERIC('$345', '$999.99', 5, 2, 1) AS num_with_scale; + +-- databricks sql: +SELECT CAST(TO_NUMBER('$345', '$999.99') AS DECIMAL(5, 2)) AS num_with_scale; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_trunc_2.sql b/tests/resources/functional/snowflake_expected_exceptions/test_trunc_2.sql new file mode 100644 index 0000000000..e7944c19c8 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_trunc_2.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT trunc(col1) AS trunc_col1 FROM tabl; + +-- databricks sql: +SELECT TRUNC(col1) AS trunc_col1 FROM tabl; diff --git a/tests/resources/functional/snowflake_expected_exceptions/test_try_cast_3.sql b/tests/resources/functional/snowflake_expected_exceptions/test_try_cast_3.sql new file mode 100644 index 0000000000..05379d31f6 --- /dev/null +++ b/tests/resources/functional/snowflake_expected_exceptions/test_try_cast_3.sql @@ -0,0 +1,6 @@ + +-- snowflake sql: +SELECT try_cast(col1) AS try_cast_col1 FROM tabl; + +-- databricks sql: +SELECT TRY_CAST(col1) AS try_cast_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/core_engine/dbt/nofmt_twisted_input.sql b/tests/resources/functional/tsql/core_engine/dbt/nofmt_twisted_input.sql new file mode 100644 index 0000000000..10309ea2ce --- /dev/null +++ b/tests/resources/functional/tsql/core_engine/dbt/nofmt_twisted_input.sql @@ -0,0 +1,27 @@ +-- tsql sql: +{%- set payment_methods = dbt_utils.get_column_values( + table=ref('raw_payments'), + column='payment_method' +) -%} + +select + order_id, + {%- for payment_method in payment_methods %} + sum(case when payment_method = '{{payment_method}}' then amount end) as {{payment_method}}_amount + {%- if not loop.last %},{% endif -%} + {% endfor %} + from {{ ref('raw_payments') }} 
+ group by 1 +-- databricks sql: +{%- set payment_methods = dbt_utils.get_column_values( + table=ref('raw_payments'), + column='payment_method' +) -%} + +SELECT order_id, + {%- for payment_method in payment_methods %} + SUM(CASE WHEN payment_method = '{{payment_method}}' THEN amount END) AS {{payment_method}}_amount + {%- if not loop.last %},{% endif -%} + {% endfor %} + FROM {{ ref('raw_payments') }} + GROUP BY 1; diff --git a/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_1.sql b/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_1.sql new file mode 100644 index 0000000000..6912e2a8f2 --- /dev/null +++ b/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_1.sql @@ -0,0 +1,19 @@ +-- Note that here we have two commas in the select clause and the TSQL grammar, not +-- quite as bad as the Snowflake grammar, is able to see that it can delete the extra comma + +-- tsql sql: +select col1,, col2 from table_name; + +-- databricks sql: +SELECT + col1, +/* The following issues were detected: + + Unparsed input - ErrorNode encountered + Unparsable text: unexpected extra input ',' while parsing a SELECT statement + expecting one of: $Currency, 'String', @@Reference, @Local, Float, Identifier, Integer, Operator, Real, Statement, '$ACTION', '$NODE_ID'... + Unparsable text: , + */ + col2 +FROM + table_name; diff --git a/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_2.sql b/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_2.sql new file mode 100644 index 0000000000..b1bd7ee603 --- /dev/null +++ b/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_2.sql @@ -0,0 +1,11 @@ +-- tsql sql: +* + +-- databricks sql: +/* The following issues were detected: + + Unparsed input - ErrorNode encountered + Unparsable text: unexpected extra input '*' while parsing a T-SQL batch + expecting one of: End of batch, Identifier, Select Statement, Statement, '$NODE_ID', '(', ';', 'ACCOUNTADMIN', 'ALERT', 'ARRAY', 'BODY', 'BULK'... + Unparsable text: * + */ diff --git a/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_3.sql b/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_3.sql new file mode 100644 index 0000000000..efcaa7c067 --- /dev/null +++ b/tests/resources/functional/tsql/core_engine/test_invalid_syntax/syntax_error_3.sql @@ -0,0 +1,23 @@ +-- The TSql parser is much better than the Snowflake one because it does not have the +-- super ambiguous LET statement that makes it impossible for batch level queries to +-- recover from syntax errors. Note though that the TSQL grammar is ambiguous here, as A and B +-- could be columns with a missing ',' but we cannot know. +-- tsql sql: +* ; +SELECT 1 ; +SELECT A B FROM C ; + +-- databricks sql: +/* The following issues were detected: + + Unparsed input - ErrorNode encountered + Unparsable text: unexpected extra input '*' while parsing a T-SQL batch + expecting one of: End of batch, Identifier, Select Statement, Statement, '$NODE_ID', '(', ';', 'ACCOUNTADMIN', 'ALERT', 'ARRAY', 'BODY', 'BULK'... 
+ Unparsable text: * + */ +SELECT + 1; +SELECT + A AS B +FROM + C; diff --git a/tests/resources/functional/tsql/cte/cte_set_operation_precedence.sql b/tests/resources/functional/tsql/cte/cte_set_operation_precedence.sql new file mode 100644 index 0000000000..50980b98a3 --- /dev/null +++ b/tests/resources/functional/tsql/cte/cte_set_operation_precedence.sql @@ -0,0 +1,17 @@ +-- +-- CTEs are visible to all the SELECT queries within a subsequent sequence of set operations. +-- + +-- tsql sql: +WITH a AS (SELECT 1, 2, 3) + +SELECT 4, 5, 6 +UNION +SELECT * FROM a; + +-- databricks sql: +WITH a AS (SELECT 1, 2, 3) + +(SELECT 4, 5, 6) +UNION +(SELECT * FROM a); diff --git a/tests/resources/functional/tsql/cte/cte_with_column_list.sql b/tests/resources/functional/tsql/cte/cte_with_column_list.sql new file mode 100644 index 0000000000..a61e7f2e58 --- /dev/null +++ b/tests/resources/functional/tsql/cte/cte_with_column_list.sql @@ -0,0 +1,11 @@ +-- +-- A simple CTE, with the column list expressed. +-- + +-- tsql sql: +WITH a (b, c, d) AS (SELECT 1 AS b, 2 AS c, 3 AS d) +SELECT b, c, d FROM a; + +-- databricks sql: +WITH a (b, c, d) AS (SELECT 1 AS b, 2 AS c, 3 AS d) +SELECT b, c, d FROM a; diff --git a/tests/resources/functional/tsql/cte/multiple_cte.sql b/tests/resources/functional/tsql/cte/multiple_cte.sql new file mode 100644 index 0000000000..ba7c2ea66e --- /dev/null +++ b/tests/resources/functional/tsql/cte/multiple_cte.sql @@ -0,0 +1,19 @@ +-- +-- Verify a few CTEs that include multiple expressions. +-- + +-- tsql sql: +WITH a AS (SELECT 1, 2, 3), + b AS (SELECT 4, 5, 6), + c AS (SELECT * FROM a) +SELECT * from b +UNION +SELECT * FROM c; + +-- databricks sql: +WITH a AS (SELECT 1, 2, 3), + b AS (SELECT 4, 5, 6), + c AS (SELECT * FROM a) +(SELECT * from b) +UNION +(SELECT * FROM c); diff --git a/tests/resources/functional/tsql/cte/nested_set_operation.sql b/tests/resources/functional/tsql/cte/nested_set_operation.sql new file mode 100644 index 0000000000..7d1e0f380f --- /dev/null +++ b/tests/resources/functional/tsql/cte/nested_set_operation.sql @@ -0,0 +1,19 @@ +-- +-- Verify a CTE that includes set operations. +-- + +-- tsql sql: +WITH a AS ( + SELECT 1, 2, 3 + UNION + SELECT 4, 5, 6 + ) +SELECT * FROM a; + +-- databricks sql: +WITH a AS ( + (SELECT 1, 2, 3) + UNION + (SELECT 4, 5, 6) +) +SELECT * FROM a; diff --git a/tests/resources/functional/tsql/cte/simple_cte.sql b/tests/resources/functional/tsql/cte/simple_cte.sql new file mode 100644 index 0000000000..c393f15d36 --- /dev/null +++ b/tests/resources/functional/tsql/cte/simple_cte.sql @@ -0,0 +1,11 @@ +-- +-- Verify a simple CTE. 
+-- + +-- tsql sql: +WITH a AS (SELECT 1, 2, 3) +SELECT * FROM a; + +-- databricks sql: +WITH a AS (SELECT 1, 2, 3) +SELECT * FROM a; diff --git a/tests/resources/functional/tsql/functions/test_aadbts_1.sql b/tests/resources/functional/tsql/functions/test_aadbts_1.sql new file mode 100644 index 0000000000..a5d703b921 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aadbts_1.sql @@ -0,0 +1,9 @@ +-- ##@@DBTS +-- +-- The @@DBTS function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@DBTS; + +-- databricks sql: +SELECT @@DBTS; diff --git a/tests/resources/functional/tsql/functions/test_aalangid1.sql b/tests/resources/functional/tsql/functions/test_aalangid1.sql new file mode 100644 index 0000000000..0e271219fd --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aalangid1.sql @@ -0,0 +1,9 @@ +-- ##@@LANGID +-- +-- The @@LANGID function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@LANGID; + +-- databricks sql: +SELECT @@LANGID; diff --git a/tests/resources/functional/tsql/functions/test_aalanguage_1.sql b/tests/resources/functional/tsql/functions/test_aalanguage_1.sql new file mode 100644 index 0000000000..52a5ddd58f --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aalanguage_1.sql @@ -0,0 +1,9 @@ +-- ##@@LANGUAGE +-- +-- The @@LANGUAGE function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@LANGUAGE; + +-- databricks sql: +SELECT @@LANGUAGE; diff --git a/tests/resources/functional/tsql/functions/test_aalock_timeout_1.sql b/tests/resources/functional/tsql/functions/test_aalock_timeout_1.sql new file mode 100644 index 0000000000..e571d6b676 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aalock_timeout_1.sql @@ -0,0 +1,9 @@ +-- ##@@LOCKTIMEOUT +-- +-- The @@LOCKTIMEOUT function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@LOCKTIMEOUT; + +-- databricks sql: +SELECT @@LOCKTIMEOUT; diff --git a/tests/resources/functional/tsql/functions/test_aamax_connections_1.sql b/tests/resources/functional/tsql/functions/test_aamax_connections_1.sql new file mode 100644 index 0000000000..7e10a44447 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aamax_connections_1.sql @@ -0,0 +1,9 @@ +-- ##@@MAX_CONNECTIONS +-- +-- The @@MAX_CONNECTIONS function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@MAX_CONNECTIONS; + +-- databricks sql: +SELECT @@MAX_CONNECTIONS; diff --git a/tests/resources/functional/tsql/functions/test_aamax_precision_1.sql b/tests/resources/functional/tsql/functions/test_aamax_precision_1.sql new file mode 100644 index 0000000000..ff54d9b474 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aamax_precision_1.sql @@ -0,0 +1,9 @@ +-- ##@@MAX_PRECISION +-- +-- The @@MAX_PRECISION function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@MAX_PRECISION; + +-- databricks sql: +SELECT @@MAX_PRECISION; diff --git a/tests/resources/functional/tsql/functions/test_aaoptions_1.sql b/tests/resources/functional/tsql/functions/test_aaoptions_1.sql new file mode 100644 index 0000000000..fc70013be4 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aaoptions_1.sql @@ -0,0 +1,9 @@ +-- ##@@OPTIONS +-- +-- The @@OPTIONS function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@OPTIONS; + +-- databricks sql: +SELECT @@OPTIONS; diff --git a/tests/resources/functional/tsql/functions/test_aaremserver_1.sql b/tests/resources/functional/tsql/functions/test_aaremserver_1.sql new file mode 100644 index 
0000000000..c4a2d7e053 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aaremserver_1.sql @@ -0,0 +1,9 @@ +-- ##@@REMSERVER +-- +-- The @@REMSERVER function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@REMSERVER; + +-- databricks sql: +SELECT @@REMSERVER; diff --git a/tests/resources/functional/tsql/functions/test_aaservername_1.sql b/tests/resources/functional/tsql/functions/test_aaservername_1.sql new file mode 100644 index 0000000000..88ffeff5c8 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aaservername_1.sql @@ -0,0 +1,9 @@ +-- ##@@SERVERNAME +-- +-- The @@SERVERNAME function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@SERVERNAME; + +-- databricks sql: +SELECT @@SERVERNAME; diff --git a/tests/resources/functional/tsql/functions/test_aaservicename_1.sql b/tests/resources/functional/tsql/functions/test_aaservicename_1.sql new file mode 100644 index 0000000000..0b7e9651d8 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aaservicename_1.sql @@ -0,0 +1,9 @@ +-- ##@@SERVICENAME +-- +-- The @@SERVICENAME function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@SERVICENAME; + +-- databricks sql: +SELECT @@SERVICENAME; diff --git a/tests/resources/functional/tsql/functions/test_aaspid_1.sql b/tests/resources/functional/tsql/functions/test_aaspid_1.sql new file mode 100644 index 0000000000..229f5897ed --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aaspid_1.sql @@ -0,0 +1,9 @@ +-- ##@@SPID +-- +-- The @@SPID function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@SPID; + +-- databricks sql: +SELECT @@SPID; diff --git a/tests/resources/functional/tsql/functions/test_aatextsize_1.sql b/tests/resources/functional/tsql/functions/test_aatextsize_1.sql new file mode 100644 index 0000000000..4c27303445 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aatextsize_1.sql @@ -0,0 +1,9 @@ +-- ##@@TEXTSIZE +-- +-- The @@TEXTSIZE function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@TEXTSIZE; + +-- databricks sql: +SELECT @@TEXTSIZE; diff --git a/tests/resources/functional/tsql/functions/test_aaversion_1.sql b/tests/resources/functional/tsql/functions/test_aaversion_1.sql new file mode 100644 index 0000000000..cf276afc58 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_aaversion_1.sql @@ -0,0 +1,9 @@ +-- ##@@VERSION +-- +-- The @@VERSION function is unsupported in Databricks SQL +-- +-- tsql sql: +SELECT @@VERSION; + +-- databricks sql: +SELECT @@VERSION; diff --git a/tests/resources/functional/tsql/functions/test_abs_1.sql b/tests/resources/functional/tsql/functions/test_abs_1.sql new file mode 100644 index 0000000000..524fa7b53c --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_abs_1.sql @@ -0,0 +1,9 @@ +-- ##ABS +-- +-- The ABS function is identical in TSql and Databricks. +-- +-- tsql sql: +SELECT abs(col1) AS abs_col1 FROM tabl; + +-- databricks sql: +SELECT ABS(col1) AS abs_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_approx_count_distinct.sql b/tests/resources/functional/tsql/functions/test_approx_count_distinct.sql new file mode 100644 index 0000000000..d59dc81b68 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_approx_count_distinct.sql @@ -0,0 +1,11 @@ +-- ##APPROX_COUNT_DISTINCT +-- +-- This function is identical to the APPROX_COUNT_DISTINCT function in Databricks. 
Though +-- the syntax is the same, the results may differ slightly due to the difference in the implementations +-- and the fact that it is an approximation. + +-- tsql sql: +SELECT APPROX_COUNT_DISTINCT(col1) AS approx_count_distinct_col1 FROM tabl; + +-- databricks sql: +SELECT APPROX_COUNT_DISTINCT(col1) AS approx_count_distinct_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_approx_percentile_cont_1.sql b/tests/resources/functional/tsql/functions/test_approx_percentile_cont_1.sql new file mode 100644 index 0000000000..1ca5ad3e76 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_approx_percentile_cont_1.sql @@ -0,0 +1,12 @@ +-- ##APPROX_PERCENTILE_CONT +-- +-- Note that TSQL uses a continuous distribution model and requires an ORDER BY clause. +-- Databricks uses an approximate distribution algorithm, and does not require an ORDER BY clause. +-- The results may differ slightly due to the difference, but as the result from both is an approximation, +-- the difference is unlikely to be significant. + +-- tsql sql: +SELECT APPROX_PERCENTILE_CONT(col1) WITHIN GROUP (ORDER BY something) AS approx_percentile_col1 FROM tabl; + +-- databricks sql: +SELECT APPROX_PERCENTILE(col1) AS approx_percentile_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_approx_percentile_disc_1.sql b/tests/resources/functional/tsql/functions/test_approx_percentile_disc_1.sql new file mode 100644 index 0000000000..701b16c052 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_approx_percentile_disc_1.sql @@ -0,0 +1,10 @@ +-- ## APPROX_PERCENTILE_DISC +-- +-- This function has no direct equivalent in Databricks. The closest equivalent is the PERCENTILE function. +-- Approximations are generally faster than exact calculations, so performance may be something to explore. + +-- tsql sql: +SELECT APPROX_PERCENTILE_DISC(0.5) WITHIN GROUP(ORDER BY col1) AS percent50 FROM tabl; + +-- databricks sql: +SELECT PERCENTILE(col1, 0.5) AS percent50 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_avg_1.sql b/tests/resources/functional/tsql/functions/test_avg_1.sql new file mode 100644 index 0000000000..fe2fd6c292 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_avg_1.sql @@ -0,0 +1,9 @@ +-- ## AVG +-- +-- This function is directly equivalent in Databricks SQL. + +-- tsql sql: +SELECT AVG(col1) AS vagcol1 FROM tabl; + +-- databricks sql: +SELECT AVG(col1) AS vagcol1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_avg_2.sql b/tests/resources/functional/tsql/functions/test_avg_2.sql new file mode 100644 index 0000000000..51571bb834 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_avg_2.sql @@ -0,0 +1,9 @@ +-- ## AVG with DISTINCT clause +-- +-- This function is directly equivalent in Databricks. + +-- tsql sql: +SELECT AVG(DISTINCT col1) AS vagcol1 FROM tabl; + +-- databricks sql: +SELECT AVG(DISTINCT col1) AS vagcol1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_avg_3.sql b/tests/resources/functional/tsql/functions/test_avg_3.sql new file mode 100644 index 0000000000..6b0ac4d7ef --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_avg_3.sql @@ -0,0 +1,11 @@ +-- ## AVG with ALL clause +-- +-- This function is directly equivalent in Databricks. +-- However, as ALL does not change the result, it is not necessary to include it in the Databricks SQL and +-- it is elided. 
+ +-- tsql sql: +SELECT AVG(ALL col1) AS vagcol1 FROM tabl; + +-- databricks sql: +SELECT AVG(col1) AS vagcol1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_avg_4.sql b/tests/resources/functional/tsql/functions/test_avg_4.sql new file mode 100644 index 0000000000..45533f580c --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_avg_4.sql @@ -0,0 +1,9 @@ +-- ## AVG with the DISTINCT and OVER clauses +-- +-- This function is directly equivalent in Databricks when used with the DISTINCT clause. + +-- tsql sql: +SELECT AVG(DISTINCT col1) OVER (PARTITION BY col1 ORDER BY col1 ASC) AS avgcol1 FROM tabl; + +-- databricks sql: +SELECT AVG(DISTINCT col1) OVER (PARTITION BY col1 ORDER BY col1 ASC) AS avgcol1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_avg_5.sql b/tests/resources/functional/tsql/functions/test_avg_5.sql new file mode 100644 index 0000000000..43db9a03a3 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_avg_5.sql @@ -0,0 +1,11 @@ +-- ## AVG with the DISTINCT and OVER clauses +-- +-- This function is directly equivalent in Databricks, when used with the DISTINCT clause. +-- +-- tsql sql: +SELECT AVG(DISTINCT col1) OVER (PARTITION BY col1 ORDER BY col1 ASC) AS avgcol1, + AVG(DISTINCT col2) FROM tabl; + +-- databricks sql: +SELECT AVG(DISTINCT col1) OVER (PARTITION BY col1 ORDER BY col1 ASC) AS avgcol1, + AVG(DISTINCT col2) FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_bit_count_1.sql b/tests/resources/functional/tsql/functions/test_bit_count_1.sql new file mode 100644 index 0000000000..9c8cd17ce4 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_bit_count_1.sql @@ -0,0 +1,9 @@ +-- ## BIT_COUNT +-- +-- The BIT_COUNT function is identical in TSql and Databricks. +-- +-- tsql sql: +SELECT BIT_COUNT(42); + +-- databricks sql: +SELECT BIT_COUNT(42); diff --git a/tests/resources/functional/tsql/functions/test_checksum_agg_1.sql b/tests/resources/functional/tsql/functions/test_checksum_agg_1.sql new file mode 100644 index 0000000000..780c5a23bf --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_checksum_agg_1.sql @@ -0,0 +1,10 @@ +-- ## CHECKSUM_AGG +-- +-- There is no direct equivalent of CHECKSUM_AGG in Databricks SQL. The following +-- conversion is a suggestion and may not be perfectly functional. + +-- tsql sql: +SELECT CHECKSUM_AGG(col1) FROM t1; + +-- databricks sql: +SELECT MD5(CONCAT_WS(',', ARRAY_AGG(col1))) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_collationproperty_1.sql b/tests/resources/functional/tsql/functions/test_collationproperty_1.sql new file mode 100644 index 0000000000..c61695ca99 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_collationproperty_1.sql @@ -0,0 +1,10 @@ +-- ##COLLATIONPROPERTY +-- +-- The COLLATIONPROPERTY function is unsupported in Databricks SQL as collation +-- tends to be a function of the underlying system. +-- +-- tsql sql: +SELECT COLLATIONPROPERTY('somelocale', 'someproperty'); + +-- databricks sql: +SELECT COLLATIONPROPERTY('somelocale', 'someproperty'); diff --git a/tests/resources/functional/tsql/functions/test_count_1.sql b/tests/resources/functional/tsql/functions/test_count_1.sql new file mode 100644 index 0000000000..7ea94f6d7a --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_count_1.sql @@ -0,0 +1,9 @@ +-- ## COUNT +-- +-- The TSql COUNT function and the Databricks COUNT function are equivalent. 
+-- +-- tsql sql: +SELECT COUNT(*) FROM t1; + +-- databricks sql: +SELECT COUNT(*) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_count_2.sql b/tests/resources/functional/tsql/functions/test_count_2.sql new file mode 100644 index 0000000000..16cecab7a9 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_count_2.sql @@ -0,0 +1,9 @@ +-- ## COUNT +-- +-- The TSql COUNT function and the Databricks COUNT function are equivalent. +-- +-- tsql sql: +SELECT COUNT(DISTINCT col1) FROM t1; + +-- databricks sql: +SELECT COUNT(DISTINCT col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_cume_dist_1.sql b/tests/resources/functional/tsql/functions/test_cume_dist_1.sql new file mode 100644 index 0000000000..96115f8c55 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_cume_dist_1.sql @@ -0,0 +1,9 @@ +-- ## CUME_DIST +-- +-- The CUME_DIST function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, cume_dist() OVER (PARTITION BY col1 ORDER BY col2) AS cume_dist FROM tabl; + +-- databricks sql: +SELECT col1, col2, cume_dist() OVER (PARTITION BY col1 ORDER BY col2) AS cume_dist FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_day_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_day_1.sql new file mode 100644 index 0000000000..425d0c6ffe --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_day_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the DAY keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD, as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. + +-- tsql sql: +SELECT DATEADD(day, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_day_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_day_2.sql new file mode 100644 index 0000000000..3084c2f5a2 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_day_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the DD keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD, as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. + +-- tsql sql: +SELECT DATEADD(dd, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_day_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_day_3.sql new file mode 100644 index 0000000000..03f603b434 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_day_3.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the D keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD, as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. 
+ +-- tsql sql: +SELECT DATEADD(d, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_1.sql new file mode 100644 index 0000000000..3ee1b8379e --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the DAYOFYEAR keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD, as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. + +-- tsql sql: +SELECT DATEADD(dayofyear, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_2.sql new file mode 100644 index 0000000000..71a7c0a0b2 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the DY keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD, as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. + +-- tsql sql: +SELECT DATEADD(dy, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_3.sql new file mode 100644 index 0000000000..1d2a2ab880 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_dayofyear_3.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the Y keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD, as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. 
+ +-- tsql sql: +SELECT DATEADD(y, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_hour_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_hour_1.sql new file mode 100644 index 0000000000..f5f4ab8694 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_hour_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the HOUR keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment HOUR + +-- tsql sql: +SELECT DATEADD(hour, 7, col1) AS add_hours_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 HOUR AS add_hours_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_hour_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_hour_2.sql new file mode 100644 index 0000000000..77d69cd154 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_hour_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the HH keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment HOUR + +-- tsql sql: +SELECT DATEADD(hh, 7, col1) AS add_hours_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 HOUR AS add_hours_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_microsecond_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_microsecond_1.sql new file mode 100644 index 0000000000..c66b2a17f5 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_microsecond_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the MICROSECOND keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment MICROSECOND + +-- tsql sql: +SELECT DATEADD(MICROSECOND, 7, col1) AS add_microsecond_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 MICROSECOND AS add_microsecond_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_microsecond_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_microsecond_2.sql new file mode 100644 index 0000000000..ed5bd991cc --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_microsecond_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the MCS keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment MICROSECOND + +-- tsql sql: +SELECT DATEADD(mcs, 7, col1) AS add_microsecond_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 MICROSECOND AS add_microsecond_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_millisecond_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_millisecond_1.sql new file mode 100644 index 0000000000..22af9c8265 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_millisecond_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the MILLISECOND keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment MILLISECOND + +-- tsql sql: +SELECT DATEADD(millisecond, 7, col1) AS add_minutes_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 MILLISECOND AS add_minutes_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_millisecond_2.sql 
b/tests/resources/functional/tsql/functions/test_dateadd_millisecond_2.sql new file mode 100644 index 0000000000..bf04f38770 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_millisecond_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the MS keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment MILLISECOND + +-- tsql sql: +SELECT DATEADD(ms, 7, col1) AS add_milliseconds_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 MILLISECOND AS add_milliseconds_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_minute_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_minute_1.sql new file mode 100644 index 0000000000..85de3a7c9d --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_minute_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the MINUTE keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment MINUTE + +-- tsql sql: +SELECT DATEADD(minute, 7, col1) AS add_minutes_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 MINUTE AS add_minutes_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_minute_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_minute_2.sql new file mode 100644 index 0000000000..51b564a178 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_minute_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the MI keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment MINUTE + +-- tsql sql: +SELECT DATEADD(mi, 7, col1) AS add_minutes_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 MINUTE AS add_minutes_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_minute_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_minute_3.sql new file mode 100644 index 0000000000..fba1d5e8b8 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_minute_3.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the N keyword +-- +-- Databricks SQL does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment MINUTE + +-- tsql sql: +SELECT DATEADD(n, 7, col1) AS add_minutes_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 MINUTE AS add_minutes_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_month_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_month_1.sql new file mode 100644 index 0000000000..ec2287bb08 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_month_1.sql @@ -0,0 +1,11 @@ +-- ## DATEADD with the MONTH keyword +-- +-- Databricks SQL does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function. 
+ + +-- tsql sql: +SELECT DATEADD(MONTH, 1, col1) AS add_months_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 1) AS add_months_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_month_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_month_2.sql new file mode 100644 index 0000000000..4f8868a1d9 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_month_2.sql @@ -0,0 +1,11 @@ +-- ## DATEADD with the MM keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function. + + +-- tsql sql: +SELECT DATEADD(mm, 1, col1) AS add_months_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 1) AS add_months_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_month_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_month_3.sql new file mode 100644 index 0000000000..61cc95f0d5 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_month_3.sql @@ -0,0 +1,11 @@ +-- ## DATEADD with the M keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function. + + +-- tsql sql: +SELECT DATEADD(m, 1, col1) AS add_months_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 1) AS add_months_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_nanosecond_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_nanosecond_1.sql new file mode 100644 index 0000000000..688c836591 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_nanosecond_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the NANOSECOND keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment NANOSECOND + +-- tsql sql: +SELECT DATEADD(NANOSECOND, 7, col1) AS add_nanoseconds_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 NANOSECOND AS add_nanoseconds_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_nanosecond_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_nanosecond_2.sql new file mode 100644 index 0000000000..63886749a7 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_nanosecond_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the NANOSECOND keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment NS + +-- tsql sql: +SELECT DATEADD(NS, 7, col1) AS add_minutes_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 NANOSECOND AS add_minutes_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_quarter_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_quarter_1.sql new file mode 100644 index 0000000000..5d8451c6a5 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_quarter_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the QUARTER keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function with the number of months multiplied by 3. 
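+-- As an illustrative generalisation of this rule (not covered by this particular test), DATEADD(QUARTER, 4, col1) +-- would be expected to become ADD_MONTHS(col1, 4 * 3).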
+ +-- tsql sql: +SELECT DATEADD(QUARTER, 2, col1) AS add_quarters_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 2 * 3) AS add_quarters_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_quarter_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_quarter_2.sql new file mode 100644 index 0000000000..a34dca06d0 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_quarter_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the QQ keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function with the number of months multiplied by 3. + +-- tsql sql: +SELECT DATEADD(qq, 2, col1) AS add_quarters_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 2 * 3) AS add_quarters_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_quarter_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_quarter_3.sql new file mode 100644 index 0000000000..2c3f4ef436 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_quarter_3.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the Q keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function with the number of months multiplied by 3. + +-- tsql sql: +SELECT DATEADD(q, 2, col1) AS add_quarters_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 2 * 3) AS add_quarters_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_second_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_second_1.sql new file mode 100644 index 0000000000..7c9a326599 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_second_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the SECOND keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment SECOND + +-- tsql sql: +SELECT DATEADD(second, 7, col1) AS add_seconds_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 SECOND AS add_seconds_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_second_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_second_2.sql new file mode 100644 index 0000000000..0f2829dceb --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_second_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the SS keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment SECOND + +-- tsql sql: +SELECT DATEADD(ss, 7, col1) AS add_seconds_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 SECOND AS add_seconds_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_second_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_second_3.sql new file mode 100644 index 0000000000..0853177f0b --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_second_3.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the S keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- INTERVAL increment SECOND + +-- tsql sql: +SELECT DATEADD(s, 7, col1) AS add_minutes_col1 FROM tabl; + +-- databricks sql: +SELECT col1 + INTERVAL 7 SECOND AS add_minutes_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_week_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_week_1.sql 
new file mode 100644 index 0000000000..05991b882b --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_week_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the WEEK keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD with the number of weeks multiplied by 7. + +-- tsql sql: +SELECT DATEADD(week, 2, col1) AS add_weeks_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2 * 7) AS add_weeks_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_week_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_week_2.sql new file mode 100644 index 0000000000..9a132835a1 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_week_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the WK keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD with the number of weeks multiplied by 7. + +-- tsql sql: +SELECT DATEADD(wk, 2, col1) AS add_weeks_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2 * 7) AS add_weeks_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_week_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_week_3.sql new file mode 100644 index 0000000000..ca8ecbcc96 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_week_3.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the WW keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD with the number of weeks multiplied by 7. + +-- tsql sql: +SELECT DATEADD(ww, 2, col1) AS add_weeks_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2 * 7) AS add_weeks_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_weekday_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_weekday_1.sql new file mode 100644 index 0000000000..e7bb477482 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_weekday_1.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the WEEKDAY keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. + +-- tsql sql: +SELECT DATEADD(weekday, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_weekday_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_weekday_2.sql new file mode 100644 index 0000000000..6d232dd5d9 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_weekday_2.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the DW keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. 
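+-- Put differently, DATEADD(weekday, n, col1), DATEADD(DW, n, col1) and DATEADD(W, n, col1) all translate to the +-- same DATE_ADD(col1, n) form shown below.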
+ +-- tsql sql: +SELECT DATEADD(DW, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_weekday_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_weekday_3.sql new file mode 100644 index 0000000000..cac0617c56 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_weekday_3.sql @@ -0,0 +1,10 @@ +-- ## DATEADD with the W keyword +-- +-- Databricks SQl does not directly support `DATEADD`, so it is translated to the equivalent +-- DATE_ADD as in the context of `DATEADD`, `day`, `dayofyear` and `weekday` are equivalent. + +-- tsql sql: +SELECT DATEADD(W, 2, col1) AS add_days_col1 FROM tabl; + +-- databricks sql: +SELECT DATE_ADD(col1, 2) AS add_days_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_year_1.sql b/tests/resources/functional/tsql/functions/test_dateadd_year_1.sql new file mode 100644 index 0000000000..07d50f6cf7 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_year_1.sql @@ -0,0 +1,11 @@ +-- ## DATEADD with the YEAR keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function with the number of months multiplied by 12. + + +-- tsql sql: +SELECT DATEADD(YEAR, 2, col1) AS add_years_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 2 * 12) AS add_years_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_year_2.sql b/tests/resources/functional/tsql/functions/test_dateadd_year_2.sql new file mode 100644 index 0000000000..257d66609b --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_year_2.sql @@ -0,0 +1,11 @@ +-- ## DATEADD with the YYYY keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function with the number of months multiplied by 12. + + +-- tsql sql: +SELECT DATEADD(yyyy, 2, col1) AS add_years_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 2 * 12) AS add_years_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_dateadd_year_3.sql b/tests/resources/functional/tsql/functions/test_dateadd_year_3.sql new file mode 100644 index 0000000000..fd16646f9d --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_dateadd_year_3.sql @@ -0,0 +1,11 @@ +-- ## DATEADD with the YY keyword +-- +-- Databricks SQl does not directly support DATEADD, so it is translated to the equivalent +-- ADD_MONTHS function with the number of months multiplied by 12. + + +-- tsql sql: +SELECT DATEADD(yy, 2, col1) AS add_years_col1 FROM tabl; + +-- databricks sql: +SELECT ADD_MONTHS(col1, 2 * 12) AS add_years_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_first_value_1.sql b/tests/resources/functional/tsql/functions/test_first_value_1.sql new file mode 100644 index 0000000000..7fe13fcef3 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_first_value_1.sql @@ -0,0 +1,9 @@ +-- ## FIRST_VALUE +-- +-- The FIRST_VALUE function is identical in TSql and Databricks. 
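+-- The window specification (ORDER BY direction, PARTITION BY and any frame clause) is carried across unchanged, +-- as this and the following FIRST_VALUE examples show.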
+ +-- tsql sql: +SELECT col1, col2, FIRST_VALUE(col1) OVER (ORDER BY col2 DESC) AS first_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, FIRST_VALUE(col1) OVER (ORDER BY col2 DESC) AS first_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_first_value_2.sql b/tests/resources/functional/tsql/functions/test_first_value_2.sql new file mode 100644 index 0000000000..6e2241b840 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_first_value_2.sql @@ -0,0 +1,9 @@ +-- ## FIRST_VALUE +-- +-- The FIRST_VALUE function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, FIRST_VALUE(col1) OVER (ORDER BY col2 ASC) AS first_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, FIRST_VALUE(col1) OVER (ORDER BY col2 ASC) AS first_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_first_value_3.sql b/tests/resources/functional/tsql/functions/test_first_value_3.sql new file mode 100644 index 0000000000..1cfad5486c --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_first_value_3.sql @@ -0,0 +1,9 @@ +-- ## FIRST_VALUE over PARTITIONS +-- +-- The FIRST_VALUE function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, col3, FIRST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 DESC) AS first_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, col3, FIRST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 DESC) AS first_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_first_value_4.sql b/tests/resources/functional/tsql/functions/test_first_value_4.sql new file mode 100644 index 0000000000..480d859d63 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_first_value_4.sql @@ -0,0 +1,9 @@ +-- ## FIRST_VALUE over PARTITIONS +-- +-- The FIRST_VALUE function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, col3, FIRST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 ASC ROWS UNBOUNDED PRECEDING) AS first_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, col3, FIRST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 ASC ROWS UNBOUNDED PRECEDING) AS first_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_get_bit_1.sql b/tests/resources/functional/tsql/functions/test_get_bit_1.sql new file mode 100644 index 0000000000..010ade0130 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_get_bit_1.sql @@ -0,0 +1,10 @@ +-- ## GET_BIT +-- +-- The GET_BIT function is not supported in Databricks SQL. The following example +-- shows how to convert it to a Databricks equivalent +-- +-- tsql sql: +SELECT GET_BIT(42, 7); + +-- databricks sql: +SELECT GETBIT(42, 7); diff --git a/tests/resources/functional/tsql/functions/test_grouping_1.sql b/tests/resources/functional/tsql/functions/test_grouping_1.sql new file mode 100644 index 0000000000..76eb11bbcc --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_grouping_1.sql @@ -0,0 +1,15 @@ +-- ## GROUPING +-- +-- The TSql GROUPING() function is directly equivalent in Databricks SQL. There are however +-- some differences that should be accounted for. +-- +-- - TSql supports GROUPING on column aliases, while Databricks SQL does not. 
+-- - TSql allows GROUPING to be used in GROUPING SETS, The GROUPING function in Databricks +-- does not support GROUPING SETS +-- - TSql returns a 1 or 0, whereas Databricks returns a boolean + +-- tsql sql: +SELECT GROUPING(col1) As g FROM t1 GROUP BY g WITH ROLLUP; + +-- databricks sql: +SELECT GROUPING(col1) as g FROM t1 GROUP BY col1 WITH ROLLUP; diff --git a/tests/resources/functional/tsql/functions/test_grouping_id_1.sql b/tests/resources/functional/tsql/functions/test_grouping_id_1.sql new file mode 100644 index 0000000000..26bc39cec5 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_grouping_id_1.sql @@ -0,0 +1,9 @@ +-- ## GROUPING_ID +-- +-- GROUPING_ID is directly equivalent in Databricks SQL and TSQL. + +-- tsql sql: +SELECT GROUPING_ID(col1, col2) FROM t1 GROUP BY CUBE(col1, col2); + +-- databricks sql: +SELECT GROUPING_ID(col1, col2) FROM t1 GROUP BY CUBE(col1, col2); diff --git a/tests/resources/functional/tsql/functions/test_grouping_id_2.sql b/tests/resources/functional/tsql/functions/test_grouping_id_2.sql new file mode 100644 index 0000000000..dd14143e74 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_grouping_id_2.sql @@ -0,0 +1,9 @@ +-- ## GROUPING_ID +-- +-- GROUPING_ID is directly equivalent in Databricks SQL and TSQL. + +-- tsql sql: +SELECT GROUPING_ID(col1, col2) As someAlias FROM t1 GROUP BY CUBE(col1, col2); + +-- databricks sql: +SELECT GROUPING_ID(col1, col2) AS someAlias FROM t1 GROUP BY CUBE(col1, col2); diff --git a/tests/resources/functional/tsql/functions/test_isnull_1.sql b/tests/resources/functional/tsql/functions/test_isnull_1.sql new file mode 100644 index 0000000000..0dbd457710 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_isnull_1.sql @@ -0,0 +1,13 @@ +-- ## ISNULL +-- +-- In TSQL ISNULL is a function that returns the first expression if it is not NULL, +-- otherwise it returns the second expression. +-- +-- In Databricks ISNULL is a function that returns boolean true if the single argument is NULL, +-- so it is replaced with IFNULL, which is the equivalent function in Databricks. +-- +-- tsql sql: +SELECT ISNULL(col1, 0) AS pcol1 FROM table; + +-- databricks sql: +SELECT IFNULL(col1, 0) AS pcol1 FROM table; diff --git a/tests/resources/functional/tsql/functions/test_lag_1.sql b/tests/resources/functional/tsql/functions/test_lag_1.sql new file mode 100644 index 0000000000..de01e502b0 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_lag_1.sql @@ -0,0 +1,9 @@ +-- ## LAG +-- +-- The LAG function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, LAG(col2, 1, 0) OVER (ORDER BY col2 DESC) AS lag_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, LAG(col2, 1, 0) OVER (ORDER BY col2 DESC) AS lag_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_lag_2.sql b/tests/resources/functional/tsql/functions/test_lag_2.sql new file mode 100644 index 0000000000..75d9612037 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_lag_2.sql @@ -0,0 +1,9 @@ +-- ## LAG ignoring NULL values +-- +-- The LAG function is identical in TSql and Databricks when IGNORING or RESPECTING NULLS (default). 
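+-- The IGNORE NULLS modifier is kept in place in the Databricks output; RESPECT NULLS, being the default in both +-- dialects, would presumably be carried through in the same way (an assumption, not covered by this test).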
+ +-- tsql sql: +SELECT col1, col2, LAG(col2, 1, 0) IGNORE NULLS OVER (ORDER BY col2 DESC) AS lag_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, LAG(col2, 1, 0) IGNORE NULLS OVER (ORDER BY col2 DESC) AS lag_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_last_value_1.sql b/tests/resources/functional/tsql/functions/test_last_value_1.sql new file mode 100644 index 0000000000..ab4c8813bc --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_last_value_1.sql @@ -0,0 +1,9 @@ +-- ## LAST_VALUE +-- +-- The LAST_VALUE function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, LAST_VALUE(col1) OVER (ORDER BY col2 DESC) AS last_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, LAST_VALUE(col1) OVER (ORDER BY col2 DESC) AS last_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_last_value_2.sql b/tests/resources/functional/tsql/functions/test_last_value_2.sql new file mode 100644 index 0000000000..8fc53695c0 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_last_value_2.sql @@ -0,0 +1,9 @@ +-- ## LAST_VALUE +-- +-- The LAST_VALUE function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, LAST_VALUE(col1) OVER (ORDER BY col2 ASC) AS last_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, LAST_VALUE(col1) OVER (ORDER BY col2 ASC) AS last_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_last_value_3.sql b/tests/resources/functional/tsql/functions/test_last_value_3.sql new file mode 100644 index 0000000000..cf08f47d31 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_last_value_3.sql @@ -0,0 +1,9 @@ +-- ## LAST_VALUE over PARTITIONS +-- +-- The LAST_VALUE function is identical in TSql and Databricks. + +-- tsql sql: +SELECT col1, col2, col3, LAST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 DESC) AS last_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, col3, LAST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 DESC) AS last_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_last_value_4.sql b/tests/resources/functional/tsql/functions/test_last_value_4.sql new file mode 100644 index 0000000000..a9a31c7269 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_last_value_4.sql @@ -0,0 +1,10 @@ +-- ## LAST_VALUE over PARTITIONS +-- +-- The LAST_VALUE function is identical in TSql and Databricks. +-- + +-- tsql sql: +SELECT col1, col2, col3, LAST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 ASC ROWS UNBOUNDED PRECEDING) AS last_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, col3, LAST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col2 ASC ROWS UNBOUNDED PRECEDING) AS last_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_lead_1.sql b/tests/resources/functional/tsql/functions/test_lead_1.sql new file mode 100644 index 0000000000..86efb0772f --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_lead_1.sql @@ -0,0 +1,9 @@ +-- ## LEAD +-- +-- The LEAD function is identical in TSql and Databricks. 
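+-- As with LAG above, the offset and default arguments (here 1 and 0) keep their positions in the translated call.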
+ +-- tsql sql: +SELECT col1, col2, LEAD(col2, 1, 0) OVER (ORDER BY col2 DESC) AS lead_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, LEAD(col2, 1, 0) OVER (ORDER BY col2 DESC) AS lead_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_lead_2.sql b/tests/resources/functional/tsql/functions/test_lead_2.sql new file mode 100644 index 0000000000..a2737b3a82 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_lead_2.sql @@ -0,0 +1,9 @@ +-- ## LEAD ignoring NULL values +-- +-- The LEAD function is identical in TSql and Databricks when IGNORING or RESPECTING NULLS (default). + +-- tsql sql: +SELECT col1, col2, LEAD(col2, 1, 0) IGNORE NULLS OVER (ORDER BY col2 DESC) AS lead_value FROM tabl; + +-- databricks sql: +SELECT col1, col2, LEAD(col2, 1, 0) IGNORE NULLS OVER (ORDER BY col2 DESC) AS lead_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_left_shift_1.sql b/tests/resources/functional/tsql/functions/test_left_shift_1.sql new file mode 100644 index 0000000000..7a6fbb55bc --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_left_shift_1.sql @@ -0,0 +1,9 @@ +-- ## LEFT_SHIFT +-- +-- The LEFT_SHIFT is identical in TSql and Databricks other than naming style. +-- +-- tsql sql: +SELECT LEFT_SHIFT(42, 7); + +-- databricks sql: +SELECT LEFTSHIFT(42, 7); diff --git a/tests/resources/functional/tsql/functions/test_max.2.sql b/tests/resources/functional/tsql/functions/test_max.2.sql new file mode 100644 index 0000000000..5a73496f9b --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_max.2.sql @@ -0,0 +1,10 @@ +-- ## MAX with DISTINCT +-- +-- The MAX function is identical in Databricks SQL and T-SQL. As DISTINCT is merely removing duplicates, +-- its presence or otherwise is irrelevant to the MAX function. + +-- tsql sql: +SELECT MAX(DISTINCT col1) FROM t1; + +-- databricks sql: +SELECT MAX(DISTINCT col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_max_1.sql b/tests/resources/functional/tsql/functions/test_max_1.sql new file mode 100644 index 0000000000..d87f3a3ad0 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_max_1.sql @@ -0,0 +1,9 @@ +-- ## MAX +-- +-- The MAX function is identical in Databricks SQL and T-SQL. + +-- tsql sql: +SELECT MAX(col1) FROM t1; + +-- databricks sql: +SELECT MAX(col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_min.1.sql b/tests/resources/functional/tsql/functions/test_min.1.sql new file mode 100644 index 0000000000..7e9f734f62 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_min.1.sql @@ -0,0 +1,9 @@ +-- ## MIN +-- +-- The MIN function is identical in Databricks SQL and T-SQL. + +-- tsql sql: +SELECT MIN(col1) FROM t1; + +-- databricks sql: +SELECT MIN(col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_min.2.sql b/tests/resources/functional/tsql/functions/test_min.2.sql new file mode 100644 index 0000000000..5a73496f9b --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_min.2.sql @@ -0,0 +1,10 @@ +-- ## MAX with DISTINCT +-- +-- The MAX function is identical in Databricks SQL and T-SQL. As DISTINCT is merely removing duplicates, +-- its presence or otherwise is irrelevant to the MAX function. 
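+-- (The maximum of a set equals the maximum of its distinct values, so the DISTINCT keyword can simply be +-- carried through unchanged.)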
+ +-- tsql sql: +SELECT MAX(DISTINCT col1) FROM t1; + +-- databricks sql: +SELECT MAX(DISTINCT col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_nestlevel_1.sql b/tests/resources/functional/tsql/functions/test_nestlevel_1.sql new file mode 100644 index 0000000000..eb04d0536f --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_nestlevel_1.sql @@ -0,0 +1,9 @@ +-- ## @@NESTLEVEL +-- +-- The @@NESTLEVEL function is unsupported in Databricks SQL, so it is left untranslated. +-- +-- tsql sql: +SELECT @@NESTLEVEL; + +-- databricks sql: +SELECT @@NESTLEVEL; diff --git a/tests/resources/functional/tsql/functions/test_percent_rank_1.sql b/tests/resources/functional/tsql/functions/test_percent_rank_1.sql new file mode 100644 index 0000000000..484f2900fa --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_percent_rank_1.sql @@ -0,0 +1,9 @@ +-- ## PERCENT_RANK +-- +-- The PERCENT_RANK function is identical in TSql and Databricks. +-- +-- tsql sql: +SELECT PERCENT_RANK() OVER (ORDER BY col2 DESC) AS lead_value FROM tabl; + +-- databricks sql: +SELECT PERCENT_RANK() OVER (ORDER BY col2 DESC) AS lead_value FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_percentile_cont_1.sql b/tests/resources/functional/tsql/functions/test_percentile_cont_1.sql new file mode 100644 index 0000000000..f82acb8752 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_percentile_cont_1.sql @@ -0,0 +1,12 @@ +-- ## PERCENTILE_CONT +-- +-- Note that TSQL uses a continuous distribution model and requires an ORDER BY clause. +-- Databricks uses an approximate distribution algorithm, and does not require an ORDER BY clause. +-- The results may differ slightly due to the difference, but as the result from both is an approximation, +-- the difference is unlikely to be significant. + +-- tsql sql: +SELECT PERCENTILE_CONT(col1) WITHIN GROUP (ORDER BY something) AS approx_percentile_col1 FROM tabl; + +-- databricks sql: +SELECT PERCENTILE(col1) AS approx_percentile_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_percentile_disc_1.sql b/tests/resources/functional/tsql/functions/test_percentile_disc_1.sql new file mode 100644 index 0000000000..c36967aef1 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_percentile_disc_1.sql @@ -0,0 +1,10 @@ +-- ## PERCENTILE_DISC +-- +-- This function has no direct equivalent in Databricks. The closest equivalent is the PERCENTILE function. +-- Approximations are generally faster than exact calculations, so performance may be something to explore. + +-- tsql sql: +SELECT PERCENTILE_DISC(0.5) WITHIN GROUP(ORDER BY col1) AS percent50 FROM tabl; + +-- databricks sql: +SELECT PERCENTILE(col1, 0.5) AS percent50 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_right_shift_1.sql b/tests/resources/functional/tsql/functions/test_right_shift_1.sql new file mode 100644 index 0000000000..0fc1fcf0c5 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_right_shift_1.sql @@ -0,0 +1,9 @@ +-- ## RIGHT_SHIFT +-- +-- The RIGHT_SHIFT function is identical in TSql and Databricks other than the naming style.
+-- +-- tsql sql: +SELECT RIGHT_SHIFT(42, 7); + +-- databricks sql: +SELECT RIGHTSHIFT(42, 7); diff --git a/tests/resources/functional/tsql/functions/test_set_bit_1.sql b/tests/resources/functional/tsql/functions/test_set_bit_1.sql new file mode 100644 index 0000000000..c832a05056 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_set_bit_1.sql @@ -0,0 +1,9 @@ +-- ## SET_BIT +-- +-- The SET_BIT function does not exist in Databricks SQL, so we must use bit functions +-- +-- tsql sql: +SELECT SET_BIT(42, 7); + +-- databricks sql: +SELECT 42 | SHIFTLEFT(1, 7); diff --git a/tests/resources/functional/tsql/functions/test_set_bit_2.sql b/tests/resources/functional/tsql/functions/test_set_bit_2.sql new file mode 100644 index 0000000000..8213e82fda --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_set_bit_2.sql @@ -0,0 +1,9 @@ +-- ## SET_BIT +-- +-- The SET_BIT function is identical in TSql and Databricks, save for a renaming of the function. +-- +-- tsql sql: +SELECT SET_BIT(42, 7, 0); + +-- databricks sql: +SELECT 42 & -1 ^ SHIFTLEFT(1, 7) | SHIFTRIGHT(0, 7); diff --git a/tests/resources/functional/tsql/functions/test_stdev_1.sql b/tests/resources/functional/tsql/functions/test_stdev_1.sql new file mode 100644 index 0000000000..b1ac41e277 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_stdev_1.sql @@ -0,0 +1,9 @@ +-- ## STDEV +-- +-- The STDEV function is identical in Databricks SQL and T-SQL. + +-- tsql sql: +SELECT STDEV(col1) FROM t1; + +-- databricks sql: +SELECT STDEV(col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_stdev_2.sql b/tests/resources/functional/tsql/functions/test_stdev_2.sql new file mode 100644 index 0000000000..392cf2aafd --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_stdev_2.sql @@ -0,0 +1,9 @@ +-- ## STDEV +-- +-- The STDEV function is identical in Databricks SQL and T-SQL. + +-- tsql sql: +SELECT STDEV(DISTINCT col1) FROM t1; + +-- databricks sql: +SELECT STDEV(DISTINCT col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_stdevp_1.sql b/tests/resources/functional/tsql/functions/test_stdevp_1.sql new file mode 100644 index 0000000000..96801dcd26 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_stdevp_1.sql @@ -0,0 +1,9 @@ +-- ## STDEVP +-- +-- The STDEVP function is identical in Databricks SQL and T-SQL. + +-- tsql sql: +SELECT STDEVP(col1) FROM t1; + +-- databricks sql: +SELECT STDEVP(col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_stdevp_2.sql b/tests/resources/functional/tsql/functions/test_stdevp_2.sql new file mode 100644 index 0000000000..01c41899b9 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_stdevp_2.sql @@ -0,0 +1,9 @@ +-- ## STDEVP +-- +-- The STDEVP function is identical in Databricks SQL and T-SQL. 
+ +-- tsql sql: +SELECT STDEVP(DISTINCT col1) FROM t1; + +-- databricks sql: +SELECT STDEVP(DISTINCT col1) FROM t1; diff --git a/tests/resources/functional/tsql/functions/test_sum_1.sql b/tests/resources/functional/tsql/functions/test_sum_1.sql new file mode 100644 index 0000000000..45be4e146a --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_sum_1.sql @@ -0,0 +1,9 @@ +-- ## SUM +-- +-- The SUM function is identical in TSQL and Databricks. + +-- tsql sql: +SELECT sum(col1) AS sum_col1 FROM tabl; + +-- databricks sql: +SELECT SUM(col1) AS sum_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_sum_2.sql b/tests/resources/functional/tsql/functions/test_sum_2.sql new file mode 100644 index 0000000000..3d7cb3f6e8 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_sum_2.sql @@ -0,0 +1,9 @@ +-- ## SUM with DISTINCT +-- +-- The SUM function is identical in TSQL and Databricks. + +-- tsql sql: +SELECT sum(DISTINCT col1) AS sum_col1 FROM tabl; + +-- databricks sql: +SELECT SUM(DISTINCT col1) AS sum_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_var_1.sql b/tests/resources/functional/tsql/functions/test_var_1.sql new file mode 100644 index 0000000000..dfc9de2d64 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_var_1.sql @@ -0,0 +1,9 @@ +-- ## VAR +-- +-- The VAR function is identical in TSQL and Databricks. + +-- tsql sql: +SELECT VAR(col1) AS sum_col1 FROM tabl; + +-- databricks sql: +SELECT VAR(col1) AS sum_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_var_2.sql b/tests/resources/functional/tsql/functions/test_var_2.sql new file mode 100644 index 0000000000..16880501ca --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_var_2.sql @@ -0,0 +1,10 @@ +-- ## VAR with DISTINCT +-- +-- The VAR function is identical in TSQL and Databricks. The DISTINCT keyword is simply carried +-- through unchanged, so VAR(DISTINCT col1) has the same meaning in both dialects. + +-- tsql sql: +SELECT VAR(DISTINCT col1) AS sum_col1 FROM tabl; + +-- databricks sql: +SELECT VAR(DISTINCT col1) AS sum_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_varp_1.sql b/tests/resources/functional/tsql/functions/test_varp_1.sql new file mode 100644 index 0000000000..cbcc7ced32 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_varp_1.sql @@ -0,0 +1,9 @@ +-- ## VARP +-- +-- The VARP function is identical in TSQL and Databricks. + +-- tsql sql: +SELECT VARP(col1) AS sum_col1 FROM tabl; + +-- databricks sql: +SELECT VARP(col1) AS sum_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/functions/test_varp_2.sql b/tests/resources/functional/tsql/functions/test_varp_2.sql new file mode 100644 index 0000000000..d8e278daa3 --- /dev/null +++ b/tests/resources/functional/tsql/functions/test_varp_2.sql @@ -0,0 +1,10 @@ +-- ## VARP with DISTINCT +-- +-- The VARP function is identical in TSQL and Databricks. The DISTINCT keyword is simply carried +-- through unchanged, so VARP(DISTINCT col1) has the same meaning in both dialects.
+ +-- tsql sql: +SELECT VARP(DISTINCT col1) AS sum_col1 FROM tabl; + +-- databricks sql: +SELECT VARP(DISTINCT col1) AS sum_col1 FROM tabl; diff --git a/tests/resources/functional/tsql/select/test_cte_1.sql b/tests/resources/functional/tsql/select/test_cte_1.sql new file mode 100644 index 0000000000..37df7bd4b3 --- /dev/null +++ b/tests/resources/functional/tsql/select/test_cte_1.sql @@ -0,0 +1,10 @@ +-- ## WITH cte SELECT +-- +-- The use of CTEs is generally the same in Databricks SQL as TSQL but there are some differences with +-- nesting CTE support. +-- +-- tsql sql: +WITH cte AS (SELECT * FROM t) SELECT * FROM cte + +-- databricks sql: +WITH cte AS (SELECT * FROM t) SELECT * FROM cte; diff --git a/tests/resources/functional/tsql/select/test_cte_2.sql b/tests/resources/functional/tsql/select/test_cte_2.sql new file mode 100644 index 0000000000..a4b0d7cfe3 --- /dev/null +++ b/tests/resources/functional/tsql/select/test_cte_2.sql @@ -0,0 +1,37 @@ +-- ## WITH cte SELECT +-- +-- The use of CTEs is generally the same in Databricks SQL as TSQL but there are some differences with +-- nesting CTE support. +-- +-- tsql sql: + +WITH cteTable1 (col1, col2, col3count) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table1 + ), + cteTable2 (colx, coly, colxcount) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table2 + ) +SELECT col2, col1, col3count, cteTable2.colx, cteTable2.coly, cteTable2.colxcount +FROM cteTable1 + +-- databricks sql: +WITH cteTable1 (col1, col2, col3count) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table1 + ), + cteTable2 (colx, coly, colxcount) + AS + ( + SELECT col1, fred, COUNT(OrderDate) AS counter + FROM Table2 + ) +SELECT col2, col1, col3count, cteTable2.colx, cteTable2.coly, cteTable2.colxcount +FROM cteTable1; diff --git a/tests/resources/functional/tsql/select/test_cte_xml.sql b/tests/resources/functional/tsql/select/test_cte_xml.sql new file mode 100644 index 0000000000..c453879824 --- /dev/null +++ b/tests/resources/functional/tsql/select/test_cte_xml.sql @@ -0,0 +1,15 @@ +-- ## WITH XMLWORKSPACES +-- +-- Databricks SQL does not currently support XML workspaces, so for now, we cover the syntax without recommending +-- a translation. +-- +-- tsql sql: +WITH XMLNAMESPACES ('somereference' as namespace) +SELECT col1 as 'namespace:col1', + col2 as 'namespace:col2' +FROM Table1 +WHERE col2 = 'xyz' +FOR XML RAW ('namespace:namespace'), ELEMENTS; + +-- databricks sql: +WITH XMLNAMESPACES ('somereference' as namespace) SELECT col1 as 'namespace:col1', col2 as 'namespace:col2' FROM Table1 WHERE col2 = 'xyz' FOR XML RAW ('namespace:namespace'), ELEMENTS; diff --git a/tests/resources/functional/tsql/set-operations/except.sql b/tests/resources/functional/tsql/set-operations/except.sql new file mode 100644 index 0000000000..3874a778ab --- /dev/null +++ b/tests/resources/functional/tsql/set-operations/except.sql @@ -0,0 +1,14 @@ +-- ## ... EXCEPT ... +-- +-- Verify simple EXCEPT handling. +-- +-- tsql sql: + +SELECT 1 +EXCEPT +SELECT 2; + +-- databricks sql: +(SELECT 1) +EXCEPT +(SELECT 2); diff --git a/tests/resources/functional/tsql/set-operations/intersect.sql b/tests/resources/functional/tsql/set-operations/intersect.sql new file mode 100644 index 0000000000..bdbf30b2d3 --- /dev/null +++ b/tests/resources/functional/tsql/set-operations/intersect.sql @@ -0,0 +1,14 @@ +-- ## ... INTERSECT ... +-- +-- Verify simple INTERSECT handling. 
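+-- As with the EXCEPT example, each side of the set operator is wrapped in parentheses in the Databricks output.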
+-- +-- tsql sql: + +SELECT 1 +INTERSECT +SELECT 2; + +-- databricks sql: +(SELECT 1) +INTERSECT +(SELECT 2); diff --git a/tests/resources/functional/tsql/set-operations/precedence.sql b/tests/resources/functional/tsql/set-operations/precedence.sql new file mode 100644 index 0000000000..6b26f89b6b --- /dev/null +++ b/tests/resources/functional/tsql/set-operations/precedence.sql @@ -0,0 +1,78 @@ +-- +-- Verify the precedence rules are being correctly handled. Order of evaluation when chaining is: +-- 1. Brackets. +-- 2. INTERSECT +-- 3. UNION and EXCEPT, evaluated left to right. +-- + +-- tsql sql: + +-- Verifies UNION/EXCEPT as left-to-right, with brackets. +(SELECT 1 + UNION + SELECT 2 + EXCEPT + (SELECT 3 + UNION + SELECT 4)) + +UNION ALL + +-- Verifies UNION/EXCEPT as left-to-right when the order is reversed. +(SELECT 5 + EXCEPT + SELECT 6 + UNION + SELECT 7) + +UNION ALL + +-- Verifies that INTERSECT has precedence over UNION/EXCEPT. +(SELECT 8 + UNION + SELECT 9 + EXCEPT + SELECT 10 + INTERSECT + SELECT 11) + +UNION ALL + +-- Verifies that INTERSECT is left-to-right, although brackets have precedence. +(SELECT 12 + INTERSECT + SELECT 13 + INTERSECT + (SELECT 14 + INTERSECT + SELECT 15)); + +-- databricks sql: + + ( + ( + ( + ((SELECT 1) UNION (SELECT 2)) + EXCEPT + ((SELECT 3) UNION (SELECT 4)) + ) + UNION ALL + ( + ((SELECT 5) EXCEPT (SELECT 6)) + UNION + (SELECT 7) + ) + ) + UNION ALL + ( + ((SELECT 8) UNION (SELECT 9)) + EXCEPT + ((SELECT 10) INTERSECT (SELECT 11)) + ) + ) +UNION ALL + ( + ((SELECT 12) INTERSECT (SELECT 13)) + INTERSECT + ((SELECT 14) INTERSECT (SELECT 15)) + ); diff --git a/tests/resources/functional/tsql/set-operations/union-all.sql b/tests/resources/functional/tsql/set-operations/union-all.sql new file mode 100644 index 0000000000..01b36a4940 --- /dev/null +++ b/tests/resources/functional/tsql/set-operations/union-all.sql @@ -0,0 +1,14 @@ +-- ## ... UNION ALL ... +-- +-- Verify simple UNION ALL handling. +-- +-- tsql sql: + +SELECT 1 +UNION ALL +SELECT 2; + +-- databricks sql: +(SELECT 1) +UNION ALL +(SELECT 2); diff --git a/tests/resources/functional/tsql/set-operations/union.sql b/tests/resources/functional/tsql/set-operations/union.sql new file mode 100644 index 0000000000..872747b1b8 --- /dev/null +++ b/tests/resources/functional/tsql/set-operations/union.sql @@ -0,0 +1,14 @@ +-- ## ... UNION ... +-- +-- Verify simple UNION handling. +-- +-- tsql sql: + +SELECT 1 +UNION +SELECT 2; + +-- databricks sql: +(SELECT 1) +UNION +(SELECT 2); diff --git a/tests/resources/functional/tsql/set-operations/union_all_left_grouped.sql b/tests/resources/functional/tsql/set-operations/union_all_left_grouped.sql new file mode 100644 index 0000000000..89f8210f13 --- /dev/null +++ b/tests/resources/functional/tsql/set-operations/union_all_left_grouped.sql @@ -0,0 +1,10 @@ +-- ## (SELECT …) UNION ALL SELECT … +-- +-- Verify UNION handling when the LHS of the union is explicitly wrapped in parentheses. +-- +-- tsql sql: + +(SELECT a, b from c) UNION ALL SELECT x, y from z; + +-- databricks sql: +(SELECT a, b FROM c) UNION ALL (SELECT x, y FROM z); diff --git a/tests/resources/functional/tsql/set-operations/union_left_grouped.sql b/tests/resources/functional/tsql/set-operations/union_left_grouped.sql new file mode 100644 index 0000000000..5678dd9c20 --- /dev/null +++ b/tests/resources/functional/tsql/set-operations/union_left_grouped.sql @@ -0,0 +1,10 @@ +-- ## (SELECT …) UNION SELECT … +-- +-- Verify UNION handling when the LHS of the union is explicitly wrapped in parentheses. 
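+-- The existing parentheses around the left-hand SELECT are preserved and the right-hand SELECT gains its own, +-- consistent with the other set-operation tests above.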
+-- +-- tsql sql: + +(SELECT a, b from c) UNION SELECT x, y from z; + +-- databricks sql: +(SELECT a, b FROM c) UNION (SELECT x, y FROM z); diff --git a/tests/resources/recon_conf_oracle.json b/tests/resources/recon_conf_oracle.json new file mode 100644 index 0000000000..c97fa15308 --- /dev/null +++ b/tests/resources/recon_conf_oracle.json @@ -0,0 +1,89 @@ +{ + "source_catalog": "", + "source_schema": "tpch", + "target_catalog": "tpch", + "target_schema": "1000gb", + "tables": [ + { + "source_name": "supplier", + "target_name": "supplier", + "jdbc_reader_options": { + "number_partitions": 10, + "partition_column": "s_suppkey", + "upper_bound": "10000000", + "lower_bound": "10" + }, + "join_columns": [ + { + "source_name": "s_suppkey" + } + ], + "column_mapping": [ + { + "source_name": "s_address", + "target_name": "s_address" + } + ], + "transformations": [ + { + "column_name": "s_address", + "source": "trim(s_address)", + "target": "trim(s_address)" + }, + { + "column_name": "s_comment", + "source": "trim(s_comment)", + "target": "trim(s_comment)" + }, + { + "column_name": "s_name", + "source": "trim(s_name)", + "target": "trim(s_name)" + }, + { + "column_name": "s_acctbal", + "source": "trim(to_char(s_acctbal, '9999999999.99'))", + "target": "cast(s_acctbal as decimal(38,2))" + } + ] + }, + { + "source_name": "friends", + "target_name": "friends", + "join_columns": [ + { + "source_name": "id", + "target_name": "id_no" + } + ], + "column_mapping": [ + { + "source_name": "name", + "target_name": "char_name" + } + ], + "transformations": [ + { + "column_name": "sal", + "source": "trim(to_char(sal, '9999999999.99'))", + "target": "cast(sal as decimal(38,2))" + }, + { + "column_name": "id", + "source": "cast(id as int)" + } + ], + "thresholds": [ + { + "column_name": "sal", + "lower_bound": "-5%", + "upper_bound": "5%", + "type": "integer" + } + ], + "filters" : { + "target" : "1=1" + } + } + ] +} diff --git a/tests/resources/table_deployment_test_query.sql b/tests/resources/table_deployment_test_query.sql new file mode 100644 index 0000000000..f27da317aa --- /dev/null +++ b/tests/resources/table_deployment_test_query.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS details ( + recon_table_id BIGINT NOT NULL, + recon_type STRING NOT NULL, + status BOOLEAN NOT NULL, + data ARRAY> NOT NULL, + inserted_ts TIMESTAMP NOT NULL +); diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000000..70a8411e46 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,414 @@ +import re +from pathlib import Path +from collections.abc import Sequence +from unittest.mock import create_autospec + +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.types import ( + ArrayType, + BooleanType, + IntegerType, + LongType, + MapType, + StringType, + StructField, + StructType, + TimestampType, +) +from sqlglot import ErrorLevel, UnsupportedError +from sqlglot.errors import SqlglotError, ParseError +from sqlglot import parse_one as sqlglot_parse_one +from sqlglot import transpile + +from databricks.labs.remorph.config import SQLGLOT_DIALECTS, TranspileConfig +from databricks.labs.remorph.reconcile.recon_config import ( + ColumnMapping, + Filters, + JdbcReaderOptions, + Schema, + Table, + ColumnThresholds, + Transformation, + TableThresholds, +) +from 
databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks +from databricks.labs.remorph.transpiler.sqlglot.parsers.snowflake import Snowflake +from databricks.sdk import WorkspaceClient +from databricks.sdk.core import Config +from databricks.sdk.service import iam + +from .transpiler.helpers.functional_test_cases import ( + FunctionalTestFile, + FunctionalTestFileWithExpectedException, + expected_exceptions, +) + + +@pytest.fixture(scope="session") +def mock_spark() -> SparkSession: + """ + Method helps to create spark session + :return: returns the spark session + """ + return SparkSession.builder.appName("Remorph Reconcile Test").remote("sc://localhost").getOrCreate() + + +@pytest.fixture(scope="session") +def mock_databricks_config(): + yield create_autospec(Config) + + +@pytest.fixture() +def mock_workspace_client(): + client = create_autospec(WorkspaceClient) + client.current_user.me = lambda: iam.User(user_name="remorph", groups=[iam.ComplexValue(display="admins")]) + yield client + + +@pytest.fixture() +def morph_config(): + yield TranspileConfig( + sdk_config={"cluster_id": "test_cluster"}, + source_dialect="snowflake", + input_source="input_sql", + output_folder="output_folder", + skip_validation=False, + catalog_name="catalog", + schema_name="schema", + mode="current", + ) + + +# TODO Add Standardized Sql Formatter to python functional tests. +def _normalize_string(value: str) -> str: + # Remove extra spaces and ensure consistent spacing around parentheses + value = re.sub(r'\s+', ' ', value) # Replace multiple spaces with a single space + value = re.sub(r'\s*\(\s*', ' ( ', value) # Ensure space around opening parenthesis + value = re.sub(r'\s*\)\s*', ' ) ', value) # Ensure space around closing parenthesis + value = value.strip() # Remove leading and trailing spaces + # Remove indentations, trailing spaces from each line, and convert to lowercase + lines = [line.rstrip() for line in value.splitlines()] + return " ".join(lines).lower().strip() + + +@pytest.fixture +def normalize_string(): + return _normalize_string + + +def get_dialect(input_dialect=None): + return SQLGLOT_DIALECTS.get(input_dialect) + + +def parse_one(sql): + dialect = Databricks + return sqlglot_parse_one(sql, read=dialect) + + +def validate_source_transpile(databricks_sql, *, source=None, pretty=False, experimental=False): + """ + Validate that: + 1. Everything in `source` transpiles to `databricks_sql` + + Args: + databricks_sql (str): Main SQL expression + source (dict): Mapping of dialect -> SQL + pretty (bool): prettify the output + experimental (bool): experimental flag False by default + """ + + for source_dialect, source_sql in (source or {}).items(): + write_dialect = get_dialect("experimental") if experimental else get_dialect("databricks") + + actual_sql = "; ".join( + transpile( + source_sql, + read=get_dialect(source_dialect), + write=write_dialect, + pretty=pretty, + error_level=None, + ) + ) + orig_sql = actual_sql + actual_sql = _normalize_string(actual_sql.rstrip(';')) + expected_sql = _normalize_string(databricks_sql.rstrip(';')) + + error_msg = f"""-> *target_sql* `{expected_sql}` is not matching with\ + \n-> *transpiled_sql* `{actual_sql}`\ + \n-> for *source_dialect* `{source_dialect}\ + ORIG: +{orig_sql} + """ + + assert expected_sql == actual_sql, error_msg + + +def validate_target_transpile(input_sql, *, target=None, pretty=False): + """ + Validate that: + 1. 
`target_sql` transpiles to `input_sql` using `target` dialect + + Args: + input_sql (str): Main SQL expression + target (dict): Mapping of dialect -> SQL + pretty (bool): prettify the output + """ + expression = parse_one(input_sql) if input_sql else None + for target_dialect, target_sql in (target or {}).items(): + if target_sql is UnsupportedError: + with pytest.raises(UnsupportedError): + if expression: + expression.sql(target_dialect, unsupported_level=ErrorLevel.RAISE) + else: + actual_sql = _normalize_string( + transpile( + target_sql, read=Snowflake, write=get_dialect(target_dialect), pretty=pretty, error_level=None + )[0] + ) + + expected_sql = _normalize_string(input_sql) + + error_msg = f"""-> *target_sql* `{expected_sql}` is not matching with\ + \n-> *transpiled_sql* `{actual_sql}`\ + \n-> for *target_dialect* `{target_dialect}\ + """ + + assert expected_sql == actual_sql, error_msg + + +@pytest.fixture(scope="session") +def dialect_context(): + yield validate_source_transpile, validate_target_transpile + + +_ANTLR_CORE_FOLDER = 'core_engine' + + +def parse_sql_files(input_dir: Path, source: str, target: str, is_expected_exception): + suite: list[FunctionalTestFile | FunctionalTestFileWithExpectedException] = [] + for filenames in input_dir.rglob("*.sql"): + # Skip files in the core directory + if _ANTLR_CORE_FOLDER in filenames.parts: + continue + with open(filenames, 'r', encoding="utf-8") as file_content: + content = file_content.read() + source_pattern = rf'--\s*{source} sql:\n(.*?)(?=\n--\s*{target} sql:|$)' + target_pattern = rf'--\s*{target} sql:\n(.*)' + + # Extract source and target queries + + source_match = re.search(source_pattern, content, re.DOTALL) + target_match = re.search(target_pattern, content, re.DOTALL) + + source_sql = source_match.group(1).strip().rstrip(";") if source_match else "" + target_sql = target_match.group(1).strip() if target_match else "" + + # when multiple sqls are present below target + test_name = filenames.name.replace(".sql", "") + if is_expected_exception: + exception_type = expected_exceptions.get(test_name, SqlglotError) + exception = SqlglotError(test_name) + if exception_type in {ParseError, UnsupportedError}: + exception = exception_type(test_name) + suite.append( + FunctionalTestFileWithExpectedException( + target_sql, + source_sql, + test_name, + exception, + target, + ) + ) + else: + suite.append(FunctionalTestFile(target_sql, source_sql, test_name, target)) + return suite + + +def get_functional_test_files_from_directory( + input_dir: Path, source: str, target: str, is_expected_exception=False +) -> Sequence[FunctionalTestFileWithExpectedException]: + """Get all functional tests in the input_dir.""" + suite = parse_sql_files(input_dir, source, target, is_expected_exception) + return suite + + +@pytest.fixture +def table_conf_mock(): + def _mock_table_conf(**kwargs): + return Table( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=kwargs.get('jdbc_reader_options', None), + join_columns=kwargs.get('join_columns', None), + select_columns=kwargs.get('select_columns', None), + drop_columns=kwargs.get('drop_columns', None), + column_mapping=kwargs.get('column_mapping', None), + transformations=kwargs.get('transformations', None), + column_thresholds=kwargs.get('thresholds', None), + filters=kwargs.get('filters', None), + ) + + return _mock_table_conf + + +@pytest.fixture +def table_conf_with_opts(column_mapping): + return Table( + source_name="supplier", + target_name="target_supplier", + 
jdbc_reader_options=JdbcReaderOptions( + number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100" + ), + join_columns=["s_suppkey", "s_nationkey"], + select_columns=["s_suppkey", "s_name", "s_address", "s_phone", "s_acctbal", "s_nationkey"], + drop_columns=["s_comment"], + column_mapping=column_mapping, + transformations=[ + Transformation(column_name="s_address", source="trim(s_address)", target="trim(s_address_t)"), + Transformation(column_name="s_phone", source="trim(s_phone)", target="trim(s_phone_t)"), + Transformation(column_name="s_name", source="trim(s_name)", target="trim(s_name)"), + ], + column_thresholds=[ + ColumnThresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int"), + ], + filters=Filters(source="s_name='t' and s_address='a'", target="s_name='t' and s_address_t='a'"), + table_thresholds=[ + TableThresholds(lower_bound="0", upper_bound="100", model="mismatch"), + ], + ) + + +@pytest.fixture +def column_mapping(): + return [ + ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"), + ColumnMapping(source_name="s_address", target_name="s_address_t"), + ColumnMapping(source_name="s_nationkey", target_name="s_nationkey_t"), + ColumnMapping(source_name="s_phone", target_name="s_phone_t"), + ColumnMapping(source_name="s_acctbal", target_name="s_acctbal_t"), + ColumnMapping(source_name="s_comment", target_name="s_comment_t"), + ] + + +@pytest.fixture +def table_schema(): + sch = [ + Schema("s_suppkey", "number"), + Schema("s_name", "varchar"), + Schema("s_address", "varchar"), + Schema("s_nationkey", "number"), + Schema("s_phone", "varchar"), + Schema("s_acctbal", "number"), + Schema("s_comment", "varchar"), + ] + + sch_with_alias = [ + Schema("s_suppkey_t", "number"), + Schema("s_name", "varchar"), + Schema("s_address_t", "varchar"), + Schema("s_nationkey_t", "number"), + Schema("s_phone_t", "varchar"), + Schema("s_acctbal_t", "number"), + Schema("s_comment_t", "varchar"), + ] + + return sch, sch_with_alias + + +@pytest.fixture +def expr(): + return parse_one("SELECT col1 FROM DUAL") + + +@pytest.fixture +def report_tables_schema(): + recon_schema = StructType( + [ + StructField("recon_table_id", LongType(), nullable=False), + StructField("recon_id", StringType(), nullable=False), + StructField("source_type", StringType(), nullable=False), + StructField( + "source_table", + StructType( + [ + StructField('catalog', StringType(), nullable=False), + StructField('schema', StringType(), nullable=False), + StructField('table_name', StringType(), nullable=False), + ] + ), + nullable=False, + ), + StructField( + "target_table", + StructType( + [ + StructField('catalog', StringType(), nullable=False), + StructField('schema', StringType(), nullable=False), + StructField('table_name', StringType(), nullable=False), + ] + ), + nullable=False, + ), + StructField("report_type", StringType(), nullable=False), + StructField("operation_name", StringType(), nullable=False), + StructField("start_ts", TimestampType()), + StructField("end_ts", TimestampType()), + ] + ) + + metrics_schema = StructType( + [ + StructField("recon_table_id", LongType(), nullable=False), + StructField( + "recon_metrics", + StructType( + [ + StructField( + "row_comparison", + StructType( + [ + StructField("missing_in_source", IntegerType()), + StructField("missing_in_target", IntegerType()), + ] + ), + ), + StructField( + "column_comparison", + StructType( + [ + StructField("absolute_mismatch", IntegerType()), + StructField("threshold_mismatch", 
IntegerType()), + StructField("mismatch_columns", StringType()), + ] + ), + ), + StructField("schema_comparison", BooleanType()), + ] + ), + ), + StructField( + "run_metrics", + StructType( + [ + StructField("status", BooleanType(), nullable=False), + StructField("run_by_user", StringType(), nullable=False), + StructField("exception_message", StringType()), + ] + ), + ), + StructField("inserted_ts", TimestampType(), nullable=False), + ] + ) + + details_schema = StructType( + [ + StructField("recon_table_id", LongType(), nullable=False), + StructField("recon_type", StringType(), nullable=False), + StructField("status", BooleanType(), nullable=False), + StructField("data", ArrayType(MapType(StringType(), StringType())), nullable=False), + StructField("inserted_ts", TimestampType(), nullable=False), + ] + ) + + return recon_schema, metrics_schema, details_schema diff --git a/tests/unit/contexts/__init__.py b/tests/unit/contexts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/contexts/test_application.py b/tests/unit/contexts/test_application.py new file mode 100644 index 0000000000..3dbf301f24 --- /dev/null +++ b/tests/unit/contexts/test_application.py @@ -0,0 +1,121 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import iam + +from databricks.labs.remorph.contexts.application import ApplicationContext + + +@pytest.fixture +def ws(): + w = create_autospec(WorkspaceClient) + w.current_user.me.side_effect = lambda: iam.User( + user_name="me@example.com", groups=[iam.ComplexValue(display="admins")] + ) + w.config.return_value = {"warehouse_id", "1234"} + return w + + +def test_workspace_context_attributes_not_none(ws): + ctx = ApplicationContext(ws) + assert hasattr(ctx, "workspace_client") + assert ctx.workspace_client is not None + assert hasattr(ctx, "current_user") + assert ctx.current_user.user_name == "me@example.com" + assert hasattr(ctx, "product_info") + assert ctx.product_info is not None + assert hasattr(ctx, "connect_config") + assert ctx.connect_config is not None + assert hasattr(ctx, "catalog_operations") + assert ctx.catalog_operations is not None + assert hasattr(ctx, "installation") + assert ctx.installation is not None + assert hasattr(ctx, "sql_backend") + assert ctx.sql_backend is not None + assert hasattr(ctx, "prompts") + assert ctx.prompts is not None + + ctx.replace( + installation=MockInstallation( + { + "config.yml": { + "source_dialect": "snowflake", + "catalog_name": "transpiler_test", + "input_sql": "sf_queries", + "output_folder": "out_dir", + "skip_validation": False, + "schema_name": "convertor_test", + "sdk_config": { + "warehouse_id": "abc", + }, + "version": 1, + }, + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "snowflake_sample_data", + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "report_type": "all", + "secret_scope": "remorph_snowflake", + "tables": { + "filter_type": "exclude", + "tables_list": ["ORDERS", "PART"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + "state.json": { + "resources": { + "jobs": {"Remorph_Reconciliation_Job": "12345"}, + "dashboards": {"Remorph-Reconciliation": "abcdef"}, + }, + "version": 1, + }, + } + ) + ) + assert hasattr(ctx, "transpile_config") + 
+    assert ctx.transpile_config is not None
+    assert hasattr(ctx, "recon_config")
+    assert ctx.recon_config is not None
+    assert hasattr(ctx, "remorph_config")
+    assert ctx.remorph_config is not None
+    assert ctx.remorph_config.transpile is not None
+    assert ctx.remorph_config.reconcile is not None
+    assert hasattr(ctx, "install_state")
+    assert ctx.install_state is not None
+
+    assert hasattr(ctx, "resource_configurator")
+    assert ctx.resource_configurator is not None
+    assert hasattr(ctx, "table_deployment")
+    assert ctx.table_deployment is not None
+    assert hasattr(ctx, "job_deployment")
+    assert ctx.job_deployment is not None
+    assert hasattr(ctx, "dashboard_deployment")
+    assert ctx.dashboard_deployment is not None
+    assert hasattr(ctx, "recon_deployment")
+    assert ctx.recon_deployment is not None
+    assert hasattr(ctx, "workspace_installation")
+    assert ctx.workspace_installation is not None
+
+
+def test_workspace_context_missing_configs(ws):
+    ctx = ApplicationContext(ws)
+    ctx.replace(installation=MockInstallation({}))
+    assert hasattr(ctx, "transpile_config")
+    assert ctx.transpile_config is None
+    assert hasattr(ctx, "recon_config")
+    assert ctx.recon_config is None
+    assert hasattr(ctx, "remorph_config")
+    assert ctx.remorph_config is not None
+    assert ctx.remorph_config.transpile is None
+    assert ctx.remorph_config.reconcile is None
diff --git a/tests/unit/coverage/__init__.py b/tests/unit/coverage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/coverage/conftest.py b/tests/unit/coverage/conftest.py
new file mode 100644
index 0000000000..f5ab137bad
--- /dev/null
+++ b/tests/unit/coverage/conftest.py
@@ -0,0 +1,13 @@
+import pytest
+
+
+@pytest.fixture()
+def io_dir_pair(tmp_path):
+    input_dir = tmp_path / "input"
+    input_dir.mkdir()
+    input_file = input_dir / "test.sql"
+    input_file.write_text("SELECT * FROM test")
+
+    output_dir = tmp_path / "output"
+
+    yield input_dir, output_dir
diff --git a/tests/unit/coverage/test_coverage_utils.py b/tests/unit/coverage/test_coverage_utils.py
new file mode 100644
index 0000000000..d7b44fad99
--- /dev/null
+++ b/tests/unit/coverage/test_coverage_utils.py
@@ -0,0 +1,260 @@
+# pylint: disable=all
+import json
+import os
+from datetime import datetime
+from unittest import mock
+from unittest.mock import patch
+
+import pytest
+import pytz
+
+from databricks.labs.remorph.coverage.commons import (
+    ReportEntry,
+    collect_transpilation_stats,
+    get_current_commit_hash,
+    get_current_time_utc,
+    get_env_var,
+    get_supported_sql_files,
+    write_json_line,
+)
+from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks
+from databricks.labs.remorph.transpiler.sqlglot.parsers.snowflake import Snowflake
+
+
+def test_get_supported_sql_files(tmp_path):
+    sub_dir = tmp_path / "test_dir"
+    sub_dir.mkdir()
+
+    files = [
+        tmp_path / "test1.sql",
+        tmp_path / "test2.sql",
+        tmp_path / "test3.ddl",
+        tmp_path / "test4.txt",
+        sub_dir / "test5.sql",
+    ]
+
+    for file in files:
+        file.touch()
+
+    files = list(get_supported_sql_files(tmp_path))
+    assert len(files) == 4
+    assert all(file.is_file() for file in files)
+
+
+def test_write_json_line(tmp_path):
+    report_entry = ReportEntry(
+        project="Remorph",
+        commit_hash="6b0e403",
+        version="",
+        timestamp="2022-01-01T00:00:00",
+        source_dialect="Snowflake",
+        target_dialect="Databricks",
+        file="test_file.sql",
+    )
+    report_file_path = tmp_path / "test_file.json"
+    with open(report_file_path, "w", encoding="utf8") as report_file:
write_json_line(report_file, report_entry) + + with open(report_file_path) as report_file: + retrieved_report_entry = ReportEntry(**json.loads(report_file.readline())) + assert retrieved_report_entry == report_entry + + +def test_get_env_var(): + os.environ["TEST_VAR"] = "test_value" + assert get_env_var("TEST_VAR") == "test_value" + with pytest.raises(ValueError): + get_env_var("NON_EXISTENT_VAR", required=True) + + +def test_get_current_commit_hash(): + with patch("subprocess.check_output", return_value="6b0e403".encode("ascii")): + assert get_current_commit_hash() == "6b0e403" + + with patch("subprocess.check_output", side_effect=FileNotFoundError()): + assert get_current_commit_hash() is None + + +@mock.patch("databricks.labs.remorph.coverage.commons.datetime") +def test_get_current_time_utc(mock_datetime): + fixed_timestamp = datetime(2022, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) + mock_datetime.now = mock.Mock(return_value=fixed_timestamp) + assert get_current_time_utc() == fixed_timestamp + + +def test_stats_collection_with_invalid_input(tmp_path): + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + + with pytest.raises(NotADirectoryError, match="The input path .* doesn't exist"): + collect_transpilation_stats( + project="Remorph", + commit_hash="6b0e403", + version="", + source_dialect=Snowflake, + target_dialect=Databricks, + input_dir=input_dir, + result_dir=output_dir, + ) + + +def test_stats_collection_with_invalid_output_dir(tmp_path): + input_dir = tmp_path / "input" + input_dir.mkdir() + output_dir = tmp_path / "output" + output_dir.touch() + + with pytest.raises(NotADirectoryError, match="The output path .* exists but is not a directory"): + collect_transpilation_stats( + project="Remorph", + commit_hash="6b0e403", + version="", + source_dialect=Snowflake, + target_dialect=Databricks, + input_dir=input_dir, + result_dir=output_dir, + ) + + +def test_stats_collection_with_valid_io_dir(tmp_path): + input_dir = tmp_path / "input" + input_dir.mkdir() + output_dir = tmp_path / "output" + output_dir.mkdir() + + try: + collect_transpilation_stats( + project="Remorph", + commit_hash="6b0e403", + version="", + source_dialect=Snowflake, + target_dialect=Databricks, + input_dir=input_dir, + result_dir=output_dir, + ) + except Exception as e: + pytest.fail(f"Transpilation stats collection raised an unexpected exception {e!s}") + + +def test_stats_collection_with_parse_error(io_dir_pair): + input_dir, output_dir = io_dir_pair + fixed_timestamp = datetime(2022, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) + + with ( + patch( + "databricks.labs.remorph.coverage.commons.parse_sql", + side_effect=Exception("Some parse error"), + ), + patch( + "databricks.labs.remorph.coverage.commons.get_current_time_utc", + return_value=fixed_timestamp, + ), + ): + collect_transpilation_stats( + project="Remorph", + commit_hash="6b0e403", + version="", + source_dialect=Snowflake, + target_dialect=Databricks, + input_dir=input_dir, + result_dir=output_dir, + ) + + report_files = list(output_dir.glob("*.json")) + assert len(report_files) == 1 + + expected_report_entry = ReportEntry( + project="Remorph", + commit_hash="6b0e403", + version="", + timestamp=fixed_timestamp.isoformat(), + source_dialect="Snowflake", + target_dialect="Databricks", + file="input/test.sql", + failures=[{'error_code': "Exception", 'error_message': "Exception('Some parse error')"}], + ) + retrieved_report_entry = ReportEntry(**json.loads(report_files[0].read_text())) + assert retrieved_report_entry == expected_report_entry + + +def 
test_stats_collection_with_transpile_error(io_dir_pair): + input_dir, output_dir = io_dir_pair + fixed_timestamp = datetime(2022, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) + + with ( + patch( + "databricks.labs.remorph.coverage.commons.generate_sql", + side_effect=Exception("Some transpilation error"), + ), + patch( + "databricks.labs.remorph.coverage.commons.get_current_time_utc", + return_value=fixed_timestamp, + ), + ): + collect_transpilation_stats( + project="Remorph", + commit_hash="6b0e403", + version="", + source_dialect=Snowflake, + target_dialect=Databricks, + input_dir=input_dir, + result_dir=output_dir, + ) + + report_files = list(output_dir.glob("*.json")) + assert len(report_files) == 1 + + expected_report_entry = ReportEntry( + project="Remorph", + commit_hash="6b0e403", + version="", + timestamp=fixed_timestamp.isoformat(), + source_dialect="Snowflake", + target_dialect="Databricks", + file="input/test.sql", + parsed=1, + statements=1, + failures=[{'error_code': "Exception", 'error_message': "Exception('Some transpilation error')"}], + ) + retrieved_report_entry = ReportEntry(**json.loads(report_files[0].read_text())) + assert retrieved_report_entry == expected_report_entry + + +def test_stats_collection_no_error(io_dir_pair): + input_dir, output_dir = io_dir_pair + fixed_timestamp = datetime(2022, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) + + with ( + patch( + "databricks.labs.remorph.coverage.commons.get_current_time_utc", + return_value=fixed_timestamp, + ), + ): + collect_transpilation_stats( + project="Remorph", + commit_hash="6b0e403", + version="", + source_dialect=Snowflake, + target_dialect=Databricks, + input_dir=input_dir, + result_dir=output_dir, + ) + + report_files = list(output_dir.glob("*.json")) + assert len(report_files) == 1 + + expected_report_entry = ReportEntry( + project="Remorph", + commit_hash="6b0e403", + version="", + timestamp=fixed_timestamp.isoformat(), + source_dialect="Snowflake", + target_dialect="Databricks", + file="input/test.sql", + parsed=1, + statements=1, + transpiled=1, + transpiled_statements=1, + ) + retrieved_report_entry = ReportEntry(**json.loads(report_files[0].read_text())) + assert retrieved_report_entry == expected_report_entry diff --git a/tests/unit/deployment/__init__.py b/tests/unit/deployment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/deployment/test_configurator.py b/tests/unit/deployment/test_configurator.py new file mode 100644 index 0000000000..da0598d75a --- /dev/null +++ b/tests/unit/deployment/test_configurator.py @@ -0,0 +1,316 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.labs.blueprint.tui import MockPrompts +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import iam +from databricks.sdk.service.catalog import ( + CatalogInfo, + SchemaInfo, + VolumeInfo, +) +from databricks.sdk.service.sql import EndpointInfo, EndpointInfoWarehouseType, GetWarehouseResponse, State + +from databricks.labs.remorph.deployment.configurator import ResourceConfigurator +from databricks.labs.remorph.helpers.metastore import CatalogOperations + + +@pytest.fixture +def ws(): + w = create_autospec(WorkspaceClient) + w.current_user.me.side_effect = lambda: iam.User( + user_name="me@example.com", groups=[iam.ComplexValue(display="admins")] + ) + return w + + +def test_prompt_for_catalog_setup_existing_catalog(ws): + prompts = MockPrompts( + { + r"Enter catalog name": "remorph", + } + ) + catalog_operations = create_autospec(CatalogOperations) + 
catalog_operations.get_catalog.return_value = CatalogInfo(name="remorph") + catalog_operations.has_catalog_access.return_value = True + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + assert configurator.prompt_for_catalog_setup() == "remorph" + + +def test_prompt_for_catalog_setup_existing_catalog_no_access_abort(ws): + prompts = MockPrompts( + { + r"Enter catalog name": "remorph", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_catalog.return_value = CatalogInfo(name="remorph") + catalog_operations.get_schema.return_value = SchemaInfo(catalog_name="remorph", name="reconcile") + catalog_operations.has_catalog_access.return_value = False + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_catalog_setup() + configurator.has_necessary_access("remorph", "reconcile", None) + + +def test_prompt_for_catalog_setup_existing_catalog_no_access_retry_exhaust_attempts(ws): + prompts = MockPrompts( + { + r"Enter catalog name": "remorph", + r"Catalog .* doesn't exist. Create it?": "no", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_catalog.return_value = None + catalog_operations.has_catalog_access.return_value = False + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_catalog_setup() + + +def test_prompt_for_catalog_setup_new_catalog(ws): + prompts = MockPrompts( + { + r"Enter catalog name": "remorph", + r"Catalog .* doesn't exist. Create it?": "yes", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_catalog.return_value = None + catalog_operations.create_catalog.return_value = CatalogInfo(name="remorph") + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + assert configurator.prompt_for_catalog_setup() == "remorph" + + +def test_prompt_for_catalog_setup_new_catalog_abort(ws): + prompts = MockPrompts( + { + r"Enter catalog name": "remorph", + r"Catalog .* doesn't exist. 
Create it?": "no", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_catalog.return_value = None + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_catalog_setup() + + +def test_prompt_for_schema_setup_existing_schema(ws): + prompts = MockPrompts( + { + r"Enter schema name": "reconcile", + } + ) + catalog_ops = create_autospec(CatalogOperations) + catalog_ops.get_schema.return_value = SchemaInfo( + catalog_name="remorph", + name="reconcile", + full_name="remorph.reconcile", + ) + catalog_ops.has_schema_access.return_value = True + configurator = ResourceConfigurator(ws, prompts, catalog_ops) + assert configurator.prompt_for_schema_setup("remorph", "reconcile") == "reconcile" + + +def test_prompt_for_schema_setup_existing_schema_no_access_abort(ws): + prompts = MockPrompts( + { + r"Enter schema name": "remorph", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_schema.return_value = SchemaInfo( + catalog_name="remorph", name="reconcile", full_name="remorph.reconcile" + ) + catalog_operations.has_catalog_access.return_value = True + catalog_operations.has_schema_access.return_value = False + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_schema_setup("remorph", "reconcile") + configurator.has_necessary_access("remorph", "reconcile", None) + + +def test_prompt_for_schema_setup_existing_schema_no_access_retry_exhaust_attempts(ws): + prompts = MockPrompts( + { + r"Enter schema name": "remorph", + r"Schema .* doesn't exist .* Create it?": "no", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_schema.return_value = None + catalog_operations.has_schema_access.return_value = False + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_schema_setup("remorph", "reconcile") + + +def test_prompt_for_schema_setup_new_schema(ws): + prompts = MockPrompts( + { + r"Enter schema name": "remorph", + r"Schema .* doesn't exist .* Create it?": "yes", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_schema.return_value = None + catalog_operations.create_schema.return_value = SchemaInfo(catalog_name="remorph", name="reconcile") + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + assert configurator.prompt_for_schema_setup("remorph", "reconcile") == "reconcile" + + +def test_prompt_for_schema_setup_new_schema_abort(ws): + prompts = MockPrompts( + { + r"Enter schema name": "remorph", + r"Schema .* doesn't exist .* Create it?": "no", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_schema.return_value = None + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_schema_setup("remorph", "reconcile") + + +def test_prompt_for_volume_setup_existing_volume(ws): + prompts = MockPrompts( + { + r"Enter volume name": "recon_volume", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_volume.return_value = VolumeInfo( + catalog_name="remorph", + schema_name="reconcile", + name="recon_volume", + ) + catalog_operations.has_volume_access.return_value = True + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + assert ( + 
configurator.prompt_for_volume_setup( + "remorph", + "reconcile", + "recon_volume", + ) + == "recon_volume" + ) + + +def test_prompt_for_volume_setup_existing_volume_no_access_abort(ws): + prompts = MockPrompts( + { + r"Enter volume name": "recon_volume", + r"Do you want to use another volume?": "no", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_volume.return_value = VolumeInfo( + catalog_name="remorph", + schema_name="reconcile", + name="recon_volume", + full_name="remorph.reconcile.recon_volume", + ) + catalog_operations.has_volume_access.return_value = False + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_volume_setup( + "remorph", + "reconcile", + "recon_volume", + ) + configurator.has_necessary_access("remorph", "reconcile", "recon_volume") + + +def test_prompt_for_volume_setup_existing_volume_no_access_retry_exhaust_attempts(ws): + prompts = MockPrompts( + { + r"Enter volume name": "recon_volume", + r"Volume .* doesn't exist .* Create it?": "no", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_volume.return_value = None + catalog_operations.has_volume_access.return_value = False + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_volume_setup( + "remorph", + "reconcile", + "recon_volume", + ) + + +def test_prompt_for_volume_setup_new_volume(ws): + prompts = MockPrompts( + { + r"Enter volume name": "recon_volume", + r"Volume .* doesn't exist .* Create it?": "yes", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_volume.return_value = None + catalog_operations.create_volume.return_value = VolumeInfo( + catalog_name="remorph", + schema_name="reconcile", + name="recon_volume", + ) + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + assert ( + configurator.prompt_for_volume_setup( + "remorph", + "reconcile", + "recon_volume", + ) + == "recon_volume" + ) + + +def test_prompt_for_volume_setup_new_volume_abort(ws): + prompts = MockPrompts( + { + r"Enter volume name": "recon_volume", + r"Volume .* doesn't exist .* Create it?": "no", + } + ) + catalog_operations = create_autospec(CatalogOperations) + catalog_operations.get_volume.return_value = None + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + with pytest.raises(SystemExit): + configurator.prompt_for_volume_setup( + "remorph", + "reconcile", + "recon_volume", + ) + + +def test_prompt_for_warehouse_setup_from_existing_warehouses(ws): + ws.warehouses.list.return_value = [ + EndpointInfo( + name="Test Warehouse", + id="w_id", + warehouse_type=EndpointInfoWarehouseType.PRO, + state=State.RUNNING, + ) + ] + prompts = MockPrompts({r"Select PRO or SERVERLESS SQL warehouse": "1"}) + catalog_operations = create_autospec(CatalogOperations) + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + assert configurator.prompt_for_warehouse_setup("Test") == "w_id" + + +def test_prompt_for_warehouse_setup_new(ws): + ws.warehouses.list.return_value = [ + EndpointInfo( + name="Test Warehouse", + id="w_id", + warehouse_type=EndpointInfoWarehouseType.PRO, + state=State.RUNNING, + ) + ] + ws.warehouses.create.return_value = GetWarehouseResponse(id="new_w_id") + prompts = MockPrompts({r"Select PRO or SERVERLESS SQL warehouse": "0"}) + catalog_operations = create_autospec(CatalogOperations) + configurator = 
ResourceConfigurator(ws, prompts, catalog_operations) + assert configurator.prompt_for_warehouse_setup("Test") == "new_w_id" diff --git a/tests/unit/deployment/test_dashboard.py b/tests/unit/deployment/test_dashboard.py new file mode 100644 index 0000000000..46c027d490 --- /dev/null +++ b/tests/unit/deployment/test_dashboard.py @@ -0,0 +1,138 @@ +import json +from pathlib import Path +from unittest.mock import create_autospec +import logging +import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.blueprint.installer import InstallState +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import InvalidParameterValue, NotFound +from databricks.sdk.service.dashboards import Dashboard +from databricks.sdk.service.dashboards import LifecycleState + +from databricks.labs.remorph.config import ReconcileMetadataConfig, ReconcileConfig, DatabaseConfig +from databricks.labs.remorph.deployment.dashboard import DashboardDeployment + + +def _get_dashboard_query(kwargs): + serialized_dashboard = json.loads(kwargs['serialized_dashboard']) + return serialized_dashboard['datasets'][0]['query'] + + +def test_deploy_dashboard(): + ws = create_autospec(WorkspaceClient) + expected_query = """SELECT + main.recon_id, + main.source_type, + main.report_type, + main.source_table.`catalog` AS source_catalog, + main.source_table.`schema` AS source_schema, + main.source_table.table_name AS source_table_name\nFROM remorph.reconcile.main AS main""".strip() + + dashboard_folder = Path(__file__).parent / Path("../../resources/dashboards") + dashboard = Dashboard( + dashboard_id="9c1fbf4ad3449be67d6cb64c8acc730b", + display_name="Remorph-Reconciliation", + ) + ws.lakeview.create.return_value = dashboard + installation = MockInstallation(is_global=False) + install_state = InstallState.from_installation(installation) + dashboard_publisher = DashboardDeployment(ws, installation, install_state) + reconcile_config = ReconcileConfig( + data_source="oracle", + report_type="all", + secret_scope="remorph_oracle69", + database_config=DatabaseConfig( + source_schema="tpch_sf100069", + target_catalog="tpch69", + target_schema="1000gb69", + ), + metadata_config=ReconcileMetadataConfig(), + ) + dashboard_publisher.deploy(dashboard_folder, reconcile_config) + _, kwargs = ws.lakeview.create.call_args + query = _get_dashboard_query(kwargs) + assert query == expected_query + assert install_state.dashboards["queries"] == dashboard.dashboard_id + + +@pytest.mark.parametrize("exception", [InvalidParameterValue, NotFound]) +def test_recovery_invalid_dashboard(caplog, exception): + dashboard_folder = Path(__file__).parent / Path("../../resources/dashboards") + + ws = create_autospec(WorkspaceClient) + dashboard_id = "9c1fbf4ad3449be67d6cb64c8acc730b" + dashboard = Dashboard( + dashboard_id=dashboard_id, + display_name="Remorph-Reconciliation", + ) + ws.lakeview.create.return_value = dashboard + ws.lakeview.get.side_effect = exception + # name = "Remorph-Reconciliation" + installation = MockInstallation( + { + "state.json": { + "resources": {"dashboards": {"queries": "8c1fbf4ad3449be67d6cb64c8acc730b"}}, + "version": 1, + }, + } + ) + install_state = InstallState.from_installation(installation) + dashboard_publisher = DashboardDeployment(ws, installation, install_state) + reconcile_config = ReconcileConfig( + data_source="oracle", + report_type="all", + secret_scope="remorph_oracle66", + database_config=DatabaseConfig( + source_schema="tpch_sf100066", + target_catalog="tpch66", 
+ target_schema="1000gb66", + ), + metadata_config=ReconcileMetadataConfig(), + ) + with caplog.at_level(logging.DEBUG, logger="databricks.labs.remorph.deployment.dashboard"): + dashboard_publisher.deploy(dashboard_folder, reconcile_config) + assert "Recovering invalid dashboard" in caplog.text + assert "Deleted dangling dashboard" in caplog.text + ws.workspace.delete.assert_called() + ws.lakeview.create.assert_called() + ws.lakeview.update.assert_not_called() + + +def test_recovery_trashed_dashboard(caplog): + dashboard_folder = Path(__file__).parent / Path("../../resources/dashboards") + + ws = create_autospec(WorkspaceClient) + dashboard_id = "9c1fbf4ad3449be67d6cb64c8acc730b" + dashboard = Dashboard( + dashboard_id=dashboard_id, + display_name="Remorph-Reconciliation", + ) + ws.lakeview.create.return_value = dashboard + ws.lakeview.get.return_value = Dashboard(lifecycle_state=LifecycleState.TRASHED) + installation = MockInstallation( + { + "state.json": { + "resources": {"dashboards": {"queries": "8c1fbf4ad3449be67d6cb64c8acc730b"}}, + "version": 1, + }, + } + ) + install_state = InstallState.from_installation(installation) + dashboard_publisher = DashboardDeployment(ws, installation, install_state) + reconcile_config = ReconcileConfig( + data_source="oracle", + report_type="all", + secret_scope="remorph_oracle77", + database_config=DatabaseConfig( + source_schema="tpch_sf100077", + target_catalog="tpch77", + target_schema="1000gb77", + ), + metadata_config=ReconcileMetadataConfig(), + ) + with caplog.at_level(logging.DEBUG, logger="databricks.labs.remorph.deployment.dashboard"): + dashboard_publisher.deploy(dashboard_folder, reconcile_config) + assert "Recreating trashed dashboard" in caplog.text + ws.lakeview.create.assert_called() + ws.lakeview.update.assert_not_called() diff --git a/tests/unit/deployment/test_installation.py b/tests/unit/deployment/test_installation.py new file mode 100644 index 0000000000..9b550497cc --- /dev/null +++ b/tests/unit/deployment/test_installation.py @@ -0,0 +1,220 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.labs.blueprint.installation import MockInstallation, Installation +from databricks.labs.blueprint.tui import MockPrompts +from databricks.labs.blueprint.wheels import WheelsV2, ProductInfo +from databricks.labs.blueprint.upgrades import Upgrades + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound +from databricks.sdk.service import iam + +from databricks.labs.remorph.config import ( + TranspileConfig, + RemorphConfigs, + ReconcileConfig, + DatabaseConfig, + ReconcileMetadataConfig, +) +from databricks.labs.remorph.deployment.installation import WorkspaceInstallation +from databricks.labs.remorph.deployment.recon import ReconDeployment + + +@pytest.fixture +def ws(): + w = create_autospec(WorkspaceClient) + w.current_user.me.side_effect = lambda: iam.User( + user_name="me@example.com", groups=[iam.ComplexValue(display="admins")] + ) + return w + + +def test_install_all(ws): + prompts = MockPrompts( + { + r"Enter catalog name": "remorph", + } + ) + recon_deployment = create_autospec(ReconDeployment) + installation = create_autospec(Installation) + product_info = create_autospec(ProductInfo) + upgrades = create_autospec(Upgrades) + + transpile_config = TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow6", + output_folder="/tmp/queries/databricks6", + skip_validation=True, + catalog_name="remorph6", + schema_name="transpiler6", + mode="current", + ) 
+ reconcile_config = ReconcileConfig( + data_source="oracle", + report_type="all", + secret_scope="remorph_oracle6", + database_config=DatabaseConfig( + source_schema="tpch_sf10006", + target_catalog="tpch6", + target_schema="1000gb6", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph6", + schema="reconcile6", + volume="reconcile_volume6", + ), + ) + config = RemorphConfigs(transpile=transpile_config, reconcile=reconcile_config) + installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + installation.install(config) + + +def test_no_recon_component_installation(ws): + prompts = MockPrompts({}) + recon_deployment = create_autospec(ReconDeployment) + installation = create_autospec(Installation) + product_info = create_autospec(ProductInfo) + upgrades = create_autospec(Upgrades) + + transpile_config = TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow7", + output_folder="/tmp/queries/databricks7", + skip_validation=True, + catalog_name="remorph7", + schema_name="transpiler7", + mode="current", + ) + config = RemorphConfigs(transpile=transpile_config) + installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + installation.install(config) + recon_deployment.install.assert_not_called() + + +def test_recon_component_installation(ws): + recon_deployment = create_autospec(ReconDeployment) + installation = create_autospec(Installation) + prompts = MockPrompts({}) + product_info = create_autospec(ProductInfo) + upgrades = create_autospec(Upgrades) + + reconcile_config = ReconcileConfig( + data_source="oracle", + report_type="all", + secret_scope="remorph_oracle8", + database_config=DatabaseConfig( + source_schema="tpch_sf10008", + target_catalog="tpch8", + target_schema="1000gb8", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph8", + schema="reconcile8", + volume="reconcile_volume8", + ), + ) + config = RemorphConfigs(reconcile=reconcile_config) + installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + installation.install(config) + recon_deployment.install.assert_called() + + +def test_negative_uninstall_confirmation(ws): + prompts = MockPrompts( + { + r"Do you want to uninstall Remorph .*": "no", + } + ) + installation = create_autospec(Installation) + recon_deployment = create_autospec(ReconDeployment) + wheels = create_autospec(WheelsV2) + upgrades = create_autospec(Upgrades) + + ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + config = RemorphConfigs() + ws_installation.uninstall(config) + installation.remove.assert_not_called() + + +def test_missing_installation(ws): + prompts = MockPrompts( + { + r"Do you want to uninstall Remorph .*": "yes", + } + ) + installation = create_autospec(Installation) + installation.files.side_effect = NotFound("Installation not found") + installation.install_folder.return_value = "~/mock" + recon_deployment = create_autospec(ReconDeployment) + wheels = create_autospec(WheelsV2) + upgrades = create_autospec(Upgrades) + + ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + config = RemorphConfigs() + ws_installation.uninstall(config) + installation.remove.assert_not_called() + + +def test_uninstall_configs_exist(ws): + prompts = MockPrompts( + { + r"Do you want to uninstall Remorph .*": "yes", + } + ) + + transpile_config = TranspileConfig( + 
source_dialect="snowflake", + input_source="sf_queries1", + output_folder="out_dir1", + skip_validation=True, + catalog_name="transpiler_test1", + schema_name="convertor_test1", + mode="current", + sdk_config={"warehouse_id": "abc"}, + ) + + reconcile_config = ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake1", + database_config=DatabaseConfig( + source_catalog="snowflake_sample_data1", + source_schema="tpch_sf10001", + target_catalog="tpch1", + target_schema="1000gb1", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph1", + schema="reconcile1", + volume="reconcile_volume1", + ), + ) + config = RemorphConfigs(transpile=transpile_config, reconcile=reconcile_config) + installation = MockInstallation({}) + recon_deployment = create_autospec(ReconDeployment) + wheels = create_autospec(WheelsV2) + upgrades = create_autospec(Upgrades) + + ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + ws_installation.uninstall(config) + recon_deployment.uninstall.assert_called() + installation.assert_removed() + + +def test_uninstall_configs_missing(ws): + prompts = MockPrompts( + { + r"Do you want to uninstall Remorph .*": "yes", + } + ) + installation = MockInstallation() + recon_deployment = create_autospec(ReconDeployment) + wheels = create_autospec(WheelsV2) + upgrades = create_autospec(Upgrades) + + ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + config = RemorphConfigs() + ws_installation.uninstall(config) + recon_deployment.uninstall.assert_not_called() + installation.assert_removed() diff --git a/tests/unit/deployment/test_job.py b/tests/unit/deployment/test_job.py new file mode 100644 index 0000000000..3872c14942 --- /dev/null +++ b/tests/unit/deployment/test_job.py @@ -0,0 +1,96 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import InvalidParameterValue +from databricks.sdk.service.jobs import Job + +from databricks.labs.remorph.config import RemorphConfigs, ReconcileConfig, DatabaseConfig, ReconcileMetadataConfig +from databricks.labs.remorph.deployment.job import JobDeployment + + +@pytest.fixture +def oracle_recon_config() -> ReconcileConfig: + return ReconcileConfig( + data_source="oracle", + report_type="all", + secret_scope="remorph_oracle9", + database_config=DatabaseConfig( + source_schema="tpch_sf10009", + target_catalog="tpch9", + target_schema="1000gb9", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph9", + schema="reconcile9", + volume="reconcile_volume9", + ), + ) + + +@pytest.fixture +def snowflake_recon_config() -> ReconcileConfig: + return ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake9", + database_config=DatabaseConfig( + source_schema="tpch_sf10009", + target_catalog="tpch9", + target_schema="1000gb9", + source_catalog="snowflake_sample_data9", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph9", + schema="reconcile9", + volume="reconcile_volume9", + ), + ) + + +def test_deploy_new_job(oracle_recon_config): + workspace_client = create_autospec(WorkspaceClient) + job = Job(job_id=1234) + workspace_client.jobs.create.return_value = job + installation 
= MockInstallation(is_global=False) + install_state = InstallState.from_installation(installation) + product_info = ProductInfo.from_class(RemorphConfigs) + name = "Recon Job" + job_deployer = JobDeployment(workspace_client, installation, install_state, product_info) + job_deployer.deploy_recon_job(name, oracle_recon_config, "remorph-x.y.z-py3-none-any.whl") + workspace_client.jobs.create.assert_called_once() + assert install_state.jobs[name] == str(job.job_id) + + +def test_deploy_existing_job(snowflake_recon_config): + workspace_client = create_autospec(WorkspaceClient) + workspace_client.config.is_gcp = True + job_id = 1234 + job = Job(job_id=job_id) + name = "Recon Job" + installation = MockInstallation({"state.json": {"resources": {"jobs": {name: job_id}}, "version": 1}}) + install_state = InstallState.from_installation(installation) + product_info = ProductInfo.for_testing(RemorphConfigs) + job_deployer = JobDeployment(workspace_client, installation, install_state, product_info) + job_deployer.deploy_recon_job(name, snowflake_recon_config, "remorph-x.y.z-py3-none-any.whl") + workspace_client.jobs.reset.assert_called_once() + assert install_state.jobs[name] == str(job.job_id) + + +def test_deploy_missing_job(snowflake_recon_config): + workspace_client = create_autospec(WorkspaceClient) + job_id = 1234 + job = Job(job_id=job_id) + workspace_client.jobs.create.return_value = job + workspace_client.jobs.reset.side_effect = InvalidParameterValue("Job not found") + name = "Recon Job" + installation = MockInstallation({"state.json": {"resources": {"jobs": {name: 5678}}, "version": 1}}) + install_state = InstallState.from_installation(installation) + product_info = ProductInfo.for_testing(RemorphConfigs) + job_deployer = JobDeployment(workspace_client, installation, install_state, product_info) + job_deployer.deploy_recon_job(name, snowflake_recon_config, "remorph-x.y.z-py3-none-any.whl") + workspace_client.jobs.create.assert_called_once() + assert install_state.jobs[name] == str(job.job_id) diff --git a/tests/unit/deployment/test_recon.py b/tests/unit/deployment/test_recon.py new file mode 100644 index 0000000000..4f8a457eb6 --- /dev/null +++ b/tests/unit/deployment/test_recon.py @@ -0,0 +1,211 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import InvalidParameterValue +from databricks.sdk.service import iam + +from databricks.labs.remorph.config import RemorphConfigs, ReconcileConfig, DatabaseConfig, ReconcileMetadataConfig +from databricks.labs.remorph.deployment.dashboard import DashboardDeployment +from databricks.labs.remorph.deployment.job import JobDeployment +from databricks.labs.remorph.deployment.recon import ReconDeployment +from databricks.labs.remorph.deployment.table import TableDeployment + + +@pytest.fixture +def ws(): + w = create_autospec(WorkspaceClient) + w.current_user.me.side_effect = lambda: iam.User( + user_name="me@example.com", groups=[iam.ComplexValue(display="admins")] + ) + return w + + +def test_install_missing_config(ws): + table_deployer = create_autospec(TableDeployment) + job_deployer = create_autospec(JobDeployment) + dashboard_deployer = create_autospec(DashboardDeployment) + installation = MockInstallation(is_global=False) + install_state = 
InstallState.from_installation(installation) + product_info = ProductInfo.for_testing(RemorphConfigs) + recon_deployer = ReconDeployment( + ws, + installation, + install_state, + product_info, + table_deployer, + job_deployer, + dashboard_deployer, + ) + remorph_config = None + recon_deployer.install(remorph_config, ["remorph-x.y.z-py3-none-any.whl"]) + table_deployer.deploy_table_from_ddl_file.assert_not_called() + job_deployer.deploy_recon_job.assert_not_called() + dashboard_deployer.deploy.assert_not_called() + + +def test_install(ws): + reconcile_config = ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake4", + database_config=DatabaseConfig( + source_catalog="snowflake_sample_data4", + source_schema="tpch_sf10004", + target_catalog="tpch4", + target_schema="1000gb4", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph4", + schema="reconcile4", + volume="reconcile_volume4", + ), + ) + installation = MockInstallation( + { + "state.json": { + "resources": { + "jobs": { + "Reconciliation Deprecated Job 1": "1", + "Reconciliation Deprecated Job 2": "2", + "Some other Job": "3", + }, + "dashboards": { + "Reconciliation Deprecated Dashboard 1": "d_id1", + "Reconciliation Deprecated Dashboard 2": "d_id2", + "Some other Dashboard": "d_id3", + }, + }, + "version": 1, + }, + } + ) + table_deployer = create_autospec(TableDeployment) + job_deployer = create_autospec(JobDeployment) + dashboard_deployer = create_autospec(DashboardDeployment) + install_state = InstallState.from_installation(installation) + product_info = ProductInfo.for_testing(RemorphConfigs) + recon_deployer = ReconDeployment( + ws, + installation, + install_state, + product_info, + table_deployer, + job_deployer, + dashboard_deployer, + ) + + def raise_invalid_parameter_err_for_dashboard(rid: str): + if rid == "d_id2": + raise InvalidParameterValue + + def raise_invalid_parameter_err_for_job(jid: str): + if jid == 2: + raise InvalidParameterValue + + ws.lakeview.trash.side_effect = raise_invalid_parameter_err_for_dashboard + ws.jobs.delete.side_effect = raise_invalid_parameter_err_for_job + recon_deployer.install(reconcile_config, ["remorph-x.y.z-py3-none-any.whl"]) + table_deployer.deploy_table_from_ddl_file.assert_called() + job_deployer.deploy_recon_job.assert_called() + dashboard_deployer.deploy.assert_called() + + assert "Reconciliation Deprecated Job 1" not in install_state.jobs + assert "Reconciliation Deprecated Job 2" not in install_state.jobs + assert "Some other Job" in install_state.jobs + + +def test_uninstall_missing_config(ws): + table_deployer = create_autospec(TableDeployment) + job_deployer = create_autospec(JobDeployment) + dashboard_deployer = create_autospec(DashboardDeployment) + installation = MockInstallation(is_global=False) + install_state = InstallState.from_installation(installation) + product_info = ProductInfo.for_testing(RemorphConfigs) + recon_deployer = ReconDeployment( + ws, + installation, + install_state, + product_info, + table_deployer, + job_deployer, + dashboard_deployer, + ) + remorph_config = None + recon_deployer.uninstall(remorph_config) + ws.lakeview.trash.assert_not_called() + ws.jobs.delete.assert_not_called() + + +def test_uninstall(ws): + recon_config = ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake5", + database_config=DatabaseConfig( + source_catalog="snowflake_sample_data5", + source_schema="tpch_sf10005", + target_catalog="tpch5", + target_schema="1000gb5", + ), + 
metadata_config=ReconcileMetadataConfig( + catalog="remorph5", + schema="reconcile5", + volume="reconcile_volume5", + ), + ) + installation = MockInstallation( + { + "state.json": { + "resources": { + "jobs": { + "Reconciliation Runner": "15", + "Reconciliation Another Job": "25", + "Some other Job": "35", + }, + "dashboards": { + "Reconciliation Metrics": "d_id15", + "Reconciliation Another Dashboard": "d_id25", + "Some other Dashboard": "d_id35", + }, + }, + "version": 1, + }, + } + ) + table_deployer = create_autospec(TableDeployment) + job_deployer = create_autospec(JobDeployment) + dashboard_deployer = create_autospec(DashboardDeployment) + install_state = InstallState.from_installation(installation) + product_info = ProductInfo.for_testing(RemorphConfigs) + recon_deployer = ReconDeployment( + ws, + installation, + install_state, + product_info, + table_deployer, + job_deployer, + dashboard_deployer, + ) + + def raise_invalid_parameter_err_for_dashboard(rid: str): + if rid == "d_id25": + raise InvalidParameterValue + + def raise_invalid_parameter_err_for_job(jid: str): + if jid == 25: + raise InvalidParameterValue + + ws.lakeview.trash.side_effect = raise_invalid_parameter_err_for_dashboard + ws.jobs.delete.side_effect = raise_invalid_parameter_err_for_job + + recon_deployer.uninstall(recon_config) + ws.lakeview.trash.assert_called() + ws.jobs.delete.assert_called() + + assert "Reconciliation Runner" not in install_state.jobs + assert "Some other Job" in install_state.jobs + assert len(install_state.dashboards.keys()) == 0 diff --git a/tests/unit/deployment/test_table.py b/tests/unit/deployment/test_table.py new file mode 100644 index 0000000000..60dde9dc4d --- /dev/null +++ b/tests/unit/deployment/test_table.py @@ -0,0 +1,14 @@ +from pathlib import Path + +from databricks.labs.lsql.backends import MockBackend + +from databricks.labs.remorph.deployment.table import TableDeployment + + +def test_deploy_table_from_ddl_file(): + sql_backend = MockBackend() + table_deployer = TableDeployment(sql_backend) + ddl_file = Path(__file__).parent / Path("../../resources/table_deployment_test_query.sql") + table_deployer.deploy_table_from_ddl_file("catalog", "schema", "table", ddl_file) + assert len(sql_backend.queries) == 1 + assert sql_backend.queries[0] == ddl_file.read_text() diff --git a/tests/unit/deployment/test_upgrade_common.py b/tests/unit/deployment/test_upgrade_common.py new file mode 100644 index 0000000000..0b979098f1 --- /dev/null +++ b/tests/unit/deployment/test_upgrade_common.py @@ -0,0 +1,208 @@ +from unittest.mock import patch +from databricks.labs.lsql.backends import MockBackend +from databricks.labs.blueprint.tui import MockPrompts +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.remorph.contexts.application import ApplicationContext + +from databricks.labs.remorph.deployment.upgrade_common import ( + replace_patterns, + extract_columns_with_datatype, + extract_column_name, + table_original_query, + current_table_columns, + installed_table_columns, + check_table_mismatch, + recreate_table_sql, +) + + +def test_replace_patterns_removes_struct_and_map(): + sql_text = "CREATE TABLE test (id INT, data STRUCT, map_data MAP)" + result = replace_patterns(sql_text) + assert result == "CREATE TABLE test (id INT, data , map_data )" + + +def test_extract_columns_with_datatype_parses_columns(): + sql_text = "CREATE TABLE test (id INT, name STRING NOT NULL, age INT)" + result = extract_columns_with_datatype(sql_text) + assert result == ["id INT", 
" name STRING NOT NULL", " age INT"] + + +def test_extract_column_name_parses_column_name(): + column_with_datatype = "id INT" + result = extract_column_name(column_with_datatype) + assert result == "id" + + +def test_table_original_query(): + table_name = "main" + full_table_name = "main_table" + result = table_original_query(table_name, full_table_name) + assert "CREATE OR REPLACE TABLE main_table" in result + + +def test_current_table_columns(): + table_name = "main" + full_table_name = "main_table" + result = current_table_columns(table_name, full_table_name) + assert result == [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + "end_ts", + ] + + +def test_installed_table_columns(mock_workspace_client): + table_identifier = "main_table" + with patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + return_value=MockBackend(), + ): + result = installed_table_columns(mock_workspace_client, table_identifier) + assert result == [] + + +def test_check_table_mismatch(): + main_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + "end_ts", + ] + installed_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + "end_ts", + ] + result = check_table_mismatch(main_columns, installed_columns) + assert result is False + + main_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + "end_ts", + ] + installed_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + ] + result = check_table_mismatch(main_columns, installed_columns) + assert result + + +def test_recreate_table_sql(mock_workspace_client): + ## Test 1 + main_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + ] + installed_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + "end_ts", + ] + table_identifier = "main" + prompts = MockPrompts( + { + rf"The `{table_identifier}` table columns are not as expected. Do you want to recreate the `{table_identifier}` table?": "yes" + } + ) + installation = MockInstallation() + ctx = ApplicationContext(mock_workspace_client) + ctx.replace( + prompts=prompts, + installation=installation, + ) + result = recreate_table_sql(table_identifier, main_columns, installed_columns, ctx.prompts) + assert "CREATE OR REPLACE TABLE main" in result + + ## Test 2 + prompts = MockPrompts( + { + rf"The `{table_identifier}` table columns are not as expected. 
Do you want to recreate the `{table_identifier}` table?": "no" + } + ) + installation = MockInstallation() + ctx = ApplicationContext(mock_workspace_client) + ctx.replace( + prompts=prompts, + installation=installation, + ) + result = recreate_table_sql(table_identifier, main_columns, installed_columns, ctx.prompts) + assert result is None + + ## Test 3 + main_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + "end_ts", + ] + installed_columns = [ + "recon_table_id", + "recon_id", + "source_type", + "source_table", + "target_table", + "report_type", + "operation_name", + "start_ts", + "end_ts", + ] + table_identifier = "main" + prompts = MockPrompts( + { + rf"The `{table_identifier}` table columns are not as expected. Do you want to recreate the `{table_identifier}` table?": "yes" + } + ) + result = recreate_table_sql(table_identifier, main_columns, installed_columns, prompts) + assert ( + result + == "CREATE OR REPLACE TABLE main AS SELECT recon_table_id,recon_id,source_type,source_table,target_table,report_type,operation_name,start_ts,end_ts FROM main" + ) diff --git a/tests/unit/helpers/__init__.py b/tests/unit/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/helpers/test_db_sql.py b/tests/unit/helpers/test_db_sql.py new file mode 100644 index 0000000000..9e98e8e49e --- /dev/null +++ b/tests/unit/helpers/test_db_sql.py @@ -0,0 +1,42 @@ +from unittest.mock import patch, create_autospec + +import pytest +from databricks.labs.remorph.helpers.db_sql import get_sql_backend +from databricks.sdk import WorkspaceClient +from databricks.labs.lsql.backends import StatementExecutionBackend + + +def test_get_sql_backend_with_warehouse_id_in_config(): + workspace_client = create_autospec(WorkspaceClient) + workspace_client.config.warehouse_id = "test_warehouse_id" + sql_backend = get_sql_backend(workspace_client) + assert isinstance(sql_backend, StatementExecutionBackend) + + +def test_get_sql_backend_with_warehouse_id_in_arg(): + workspace_client = create_autospec(WorkspaceClient) + sql_backend = get_sql_backend(workspace_client, warehouse_id="test_warehouse_id") + assert isinstance(sql_backend, StatementExecutionBackend) + + +@patch('databricks.labs.remorph.helpers.db_sql.DatabricksConnectBackend') +def test_get_sql_backend_without_warehouse_id(databricks_connect_backend): + workspace_client = create_autospec(WorkspaceClient) + workspace_client.config.warehouse_id = None + sql_backend = get_sql_backend(workspace_client) + databricks_connect_backend.assert_called_once_with(workspace_client) + assert isinstance(sql_backend, databricks_connect_backend.return_value.__class__) + + +@pytest.mark.usefixtures("monkeypatch") +@patch('databricks.labs.remorph.helpers.db_sql.RuntimeBackend') +def test_get_sql_backend_without_warehouse_id_in_notebook( + runtime_backend, + monkeypatch, +): + monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") + workspace_client = create_autospec(WorkspaceClient) + workspace_client.config.warehouse_id = None + sql_backend = get_sql_backend(workspace_client) + runtime_backend.assert_called_once() + assert isinstance(sql_backend, runtime_backend.return_value.__class__) diff --git a/tests/unit/helpers/test_file_utils.py b/tests/unit/helpers/test_file_utils.py new file mode 100644 index 0000000000..dec31f6e0b --- /dev/null +++ b/tests/unit/helpers/test_file_utils.py @@ -0,0 +1,149 @@ +import codecs +import os +import tempfile +from 
pathlib import Path
+
+import pytest
+
+from databricks.labs.remorph.helpers.file_utils import (
+    dir_walk,
+    is_sql_file,
+    make_dir,
+    refactor_hexadecimal_chars,
+    remove_bom,
+)
+
+
+@pytest.fixture(scope="module")
+def setup_module(tmp_path_factory):
+    test_dir = tmp_path_factory.mktemp("test_dir")
+    sql_file = test_dir / "test.sql"
+    sql_file.write_text("SELECT * FROM test;")
+    non_sql_file = test_dir / "test.txt"
+    non_sql_file.write_text("This is a test.")
+    return test_dir, sql_file, non_sql_file
+
+
+def test_remove_bom():
+    test_string = "test_string"
+
+    # Test no BOM
+    assert remove_bom(test_string) == test_string
+
+    # Test UTF-16 BOM
+    # "******** UTF-16 ********"
+    input_string = codecs.BOM_UTF16.decode("utf-16") + test_string
+    assert remove_bom(input_string) == test_string
+
+    # "******** UTF-16-BE ********"
+    input_string = codecs.BOM_UTF16_BE.decode("utf-16-be") + test_string
+    assert remove_bom(input_string) == test_string
+
+    # "******** UTF-16-LE ********"
+    input_string = codecs.BOM_UTF16_LE.decode("utf-16-le") + test_string
+    assert remove_bom(input_string) == test_string
+
+    # Test UTF-32 BOM
+    # "******** UTF-32 ********"
+    input_string = codecs.BOM_UTF32.decode("utf-32") + test_string
+    assert remove_bom(input_string) == test_string
+
+    # "******** UTF-32-BE ********"
+    input_string = codecs.BOM_UTF32_BE.decode("utf-32-be") + test_string
+    assert remove_bom(input_string) == test_string
+
+    # "******** UTF-32-LE ********"
+    input_string = codecs.BOM_UTF32_LE.decode("utf-32-le") + test_string
+    assert remove_bom(input_string) == test_string
+
+    # Test UTF8 BOM
+    # "******** UTF-8 ********"
+    input_string = codecs.BOM_UTF8.decode("utf-8") + test_string
+    assert remove_bom(input_string) == test_string
+
+
+def test_is_sql_file():
+    assert is_sql_file("test.sql") is True
+    assert is_sql_file("test.ddl") is True
+    assert is_sql_file("test.txt") is False
+    assert is_sql_file("test") is False
+
+
+def test_make_dir():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        new_dir_path = os.path.join(temp_dir, "new_dir")
+
+        # Ensure the directory does not exist
+        assert os.path.exists(new_dir_path) is False
+
+        # Call the function to create the directory
+        make_dir(new_dir_path)
+
+        # Check if the directory now exists
+        assert os.path.exists(new_dir_path) is True
+
+
+def safe_remove_file(file_path: Path):
+    if file_path.exists():
+        file_path.unlink()
+
+
+def safe_remove_dir(dir_path: Path):
+    if dir_path.exists():
+        dir_path.rmdir()
+
+
+def test_dir_walk_single_file():
+    path = Path("test_dir")
+    path.mkdir()
+    (path / "test_file.txt").touch()
+    result = list(dir_walk(path))
+    assert len(result) == 1
+    assert result[0][0] == path
+    assert len(result[0][1]) == 0
+    assert len(result[0][2]) == 1
+    safe_remove_file(path / "test_file.txt")
+    safe_remove_dir(path)
+
+
+def test_dir_walk_nested_files():
+    path = Path("test_dir")
+    path.mkdir()
+    (path / "test_file.txt").touch()
+    (path / "nested_dir").mkdir()
+    (path / "nested_dir" / "nested_file.txt").touch()
+    result = list(dir_walk(path))
+
+    assert len(result) == 2
+    assert result[0][0] == path
+    assert len(result[0][1]) == 1
+    assert len(result[0][2]) == 1
+    assert result[1][0] == path / "nested_dir"
+    assert len(result[1][1]) == 0
+    assert len(result[1][2]) == 1
+    safe_remove_file(path / "test_file.txt")
+    safe_remove_file(path / "nested_dir" / "nested_file.txt")
+    safe_remove_dir(path / "nested_dir")
+    safe_remove_dir(path)
+
+
+def test_dir_walk_empty_dir():
+    path = Path("empty_dir")
+    path.mkdir()
+    result =
list(dir_walk(path)) + + assert len(result) == 1 + assert result[0][0] == path + assert len(result[0][1]) == 0 + assert len(result[0][2]) == 0 + safe_remove_dir(path) + + +def test_refactor_hexadecimal_chars(): + input_string = "SELECT * FROM test \x1b[4mWHERE\x1b[0m" + output_string = "SELECT * FROM test --> WHERE <--" + assert refactor_hexadecimal_chars(input_string) == output_string + + input_string2 = "SELECT \x1b[4marray_agg(\x1b[0mafter_state order by timestamp asc) FROM dual" + output_string2 = "SELECT --> array_agg( <--after_state order by timestamp asc) FROM dual" + assert refactor_hexadecimal_chars(input_string2) == output_string2 diff --git a/tests/unit/helpers/test_metastore.py b/tests/unit/helpers/test_metastore.py new file mode 100644 index 0000000000..38b7f3d784 --- /dev/null +++ b/tests/unit/helpers/test_metastore.py @@ -0,0 +1,323 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.catalog import ( + CatalogInfo, + EffectivePermissionsList, + Privilege, + SchemaInfo, + VolumeInfo, + SecurableType, +) +from databricks.sdk.errors import NotFound +from databricks.labs.remorph.helpers.metastore import CatalogOperations + + +@pytest.fixture +def ws(): + return create_autospec(WorkspaceClient) + + +def test_get_existing_catalog(ws): + ws.catalogs.get.return_value = CatalogInfo(name="test") + catalog_operations = CatalogOperations(ws) + assert isinstance(catalog_operations.get_catalog("test"), CatalogInfo) + + +def test_get_non_existing_catalog(ws): + ws.catalogs.get.side_effect = NotFound() + catalog_operations = CatalogOperations(ws) + assert catalog_operations.get_catalog("test") is None + + +def test_get_existing_schema(ws): + ws.schemas.get.return_value = SchemaInfo(catalog_name="test_catalog", name="test_schema") + catalog_operations = CatalogOperations(ws) + assert isinstance(catalog_operations.get_schema("test_catalog", "test_schema"), SchemaInfo) + + +def test_get_non_existing_schema(ws): + ws.schemas.get.side_effect = NotFound() + catalog_operations = CatalogOperations(ws) + assert catalog_operations.get_schema("test_catalog", "test_schema") is None + + +def test_get_existing_volume(ws): + ws.volumes.read.return_value = VolumeInfo( + catalog_name="test_catalog", schema_name="test_schema", name="test_volume" + ) + catalog_operations = CatalogOperations(ws) + assert isinstance( + catalog_operations.get_volume("test_catalog", "test_schema", "test_volume"), + VolumeInfo, + ) + + +def test_get_non_existing_volume(ws): + ws.volumes.read.side_effect = NotFound() + catalog_operations = CatalogOperations(ws) + assert catalog_operations.get_volume("test_catalog", "test_schema", "test_volume") is None + + +def test_create_catalog(ws): + catalog = CatalogInfo(name="test") + ws.catalogs.create.return_value = catalog + catalog_operations = CatalogOperations(ws) + assert catalog == catalog_operations.create_catalog("test") + + +def test_create_schema(ws): + schema = SchemaInfo(catalog_name="test_catalog", name="test_schema") + ws.schemas.create.return_value = schema + catalog_operations = CatalogOperations(ws) + assert schema == catalog_operations.create_schema("test_schema", "test_catalog") + + +def test_create_volume(ws): + volume = VolumeInfo(catalog_name="test_catalog", schema_name="test_schema", name="test_volume") + ws.volumes.create.return_value = volume + catalog_operations = CatalogOperations(ws) + assert volume == catalog_operations.create_volume("test_catalog", "test_schema", "test_volume") 
+ + +def test_has_all_privileges(ws): + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [ + {"privilege": "USE_CATALOG"}, + {"privilege": "CREATE_SCHEMA"}, + ], + } + ] + } + ) + + catalog_ops = CatalogOperations(ws) + assert catalog_ops.has_privileges( + user='test_user', + securable_type=SecurableType.CATALOG, + full_name='test_catalog', + privileges={Privilege.USE_CATALOG, Privilege.CREATE_SCHEMA}, + ) + + +def test_has_no_privileges(ws): + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [], + } + ] + } + ) + + catalog_ops = CatalogOperations(ws) + assert not catalog_ops.has_privileges( + user='test_user', + securable_type=SecurableType.CATALOG, + full_name='test_catalog', + privileges={Privilege.USE_CATALOG, Privilege.CREATE_SCHEMA}, + ) + + +def test_has_none_permission_list(ws): + ws.grants.get_effective.return_value = None + catalog_ops = CatalogOperations(ws) + assert not catalog_ops.has_privileges( + user='test_user', + securable_type=SecurableType.CATALOG, + full_name='test_catalog', + privileges={Privilege.USE_CATALOG, Privilege.CREATE_SCHEMA}, + ) + + +def test_has_none_privilege_assignments(ws): + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict({"privilege_assignments": None}) + catalog_ops = CatalogOperations(ws) + assert not catalog_ops.has_privileges( + user='test_user', + securable_type=SecurableType.CATALOG, + full_name='test_catalog', + privileges={Privilege.USE_CATALOG, Privilege.CREATE_SCHEMA}, + ) + + +def test_has_some_privileges(ws): + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [{"privilege": "USE_CATALOG"}], + } + ] + } + ) + + catalog_ops = CatalogOperations(ws) + assert not catalog_ops.has_privileges( + user='test_user', + securable_type=SecurableType.CATALOG, + full_name='test_catalog', + privileges={Privilege.USE_CATALOG, Privilege.CREATE_SCHEMA}, + ) + + +def test_has_catalog_access_owner(ws): + catalog = CatalogInfo(name="test_catalog", owner="test_user@me.com") + catalog_ops = CatalogOperations(ws) + assert catalog_ops.has_catalog_access(catalog, "test_user@me.com", ({Privilege.ALL_PRIVILEGES},)) + + +def test_has_catalog_access_has_all_privileges(ws): + catalog = CatalogInfo(name="test_catalog") + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [ + {"privilege": "USE_CATALOG"}, + {"privilege": "CREATE_SCHEMA"}, + ], + } + ] + } + ) + catalog_ops = CatalogOperations(ws) + assert catalog_ops.has_catalog_access( + catalog, "test_user@me.com", ({Privilege.USE_CATALOG, Privilege.CREATE_SCHEMA},) + ) + + +def test_has_catalog_access_has_no_privileges(ws): + catalog = CatalogInfo(name="test_catalog") + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [], + } + ] + } + ) + catalog_ops = CatalogOperations(ws) + assert not catalog_ops.has_catalog_access( + catalog, "test_user@me.com", ({Privilege.USE_CATALOG, Privilege.CREATE_SCHEMA},) + ) + + +def test_has_schema_access_owner(ws): + schema = SchemaInfo(catalog_name="test_catalog", name="test_schema", owner="test_user@me.com") + catalog_ops = CatalogOperations(ws) + assert catalog_ops.has_schema_access( + schema, + "test_user@me.com", + ({Privilege.ALL_PRIVILEGES},), + ) + + +def 
test_has_schema_access_has_all_privileges(ws): + schema = SchemaInfo(catalog_name="test_catalog", name="test_schema", full_name="test_catalog.test_schema") + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [ + {"privilege": "USE_SCHEMA"}, + {"privilege": "CREATE_TABLE"}, + ], + } + ] + } + ) + catalog_ops = CatalogOperations(ws) + assert catalog_ops.has_schema_access( + schema, + "test_user@me.com", + ({Privilege.USE_SCHEMA, Privilege.CREATE_TABLE},), + ) + + +def test_schema_access_has_no_privileges(ws): + schema = SchemaInfo(catalog_name="test_catalog", name="test_schema", full_name="test_catalog.test_schema") + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [], + } + ] + } + ) + catalog_ops = CatalogOperations(ws) + assert not catalog_ops.has_schema_access( + schema, + "test_user@me.com", + ({Privilege.USE_SCHEMA, Privilege.CREATE_TABLE},), + ) + + +def test_has_volume_access_owner(ws): + volume = VolumeInfo( + catalog_name="test_catalog", schema_name="test_schema", name="test_volume", owner="test_user@me.com" + ) + catalog_ops = CatalogOperations(ws) + assert catalog_ops.has_volume_access( + volume, + "test_user@me.com", + ({Privilege.ALL_PRIVILEGES},), + ) + + +def test_has_volume_access_has_all_privileges(ws): + volume = VolumeInfo( + catalog_name="test_catalog", + schema_name="test_schema", + name="test_volume", + full_name="test_catalog.test_schema.test_volume", + ) + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [ + {"privilege": "READ_VOLUME"}, + {"privilege": "WRITE_VOLUME"}, + ], + } + ] + } + ) + catalog_ops = CatalogOperations(ws) + assert catalog_ops.has_volume_access( + volume, + "test_user@me.com", + ({Privilege.READ_VOLUME, Privilege.WRITE_VOLUME},), + ) + + +def test_volume_access_has_no_privileges(ws): + volume = VolumeInfo( + catalog_name="test_catalog", + schema_name="test_schema", + name="test_volume", + full_name="test_catalog.test_schema.test_volume", + ) + ws.grants.get_effective.return_value = EffectivePermissionsList.from_dict( + { + "privilege_assignments": [ + { + "privileges": [], + } + ] + } + ) + catalog_ops = CatalogOperations(ws) + assert not catalog_ops.has_volume_access( + volume, + "test_user@me.com", + ({Privilege.READ_VOLUME, Privilege.WRITE_VOLUME},), + ) diff --git a/tests/unit/helpers/test_recon_config_utils.py b/tests/unit/helpers/test_recon_config_utils.py new file mode 100644 index 0000000000..61ccf508c0 --- /dev/null +++ b/tests/unit/helpers/test_recon_config_utils.py @@ -0,0 +1,145 @@ +from unittest.mock import patch + +import pytest + +from databricks.labs.blueprint.tui import MockPrompts +from databricks.labs.remorph.helpers.recon_config_utils import ReconConfigPrompts +from databricks.sdk.errors.platform import ResourceDoesNotExist +from databricks.sdk.service.workspace import SecretScope + +SOURCE_DICT = {"databricks": "0", "oracle": "1", "snowflake": "2"} +SCOPE_NAME = "dummy_scope" + + +def test_configure_secrets_snowflake_overwrite(mock_workspace_client): + prompts = MockPrompts( + { + r"Select the source": SOURCE_DICT["snowflake"], + r"Enter Secret Scope name": SCOPE_NAME, + r"Enter Snowflake URL": "dummy", + r"Enter Account Name": "dummy", + r"Enter User": "dummy", + r"Enter Password": "dummy", + r"Enter Database": "dummy", + r"Enter Schema": "dummy", + r"Enter Snowflake Warehouse": "dummy", + 
r"Enter Role": "dummy", + r"Do you want to overwrite.*": "yes", + } + ) + mock_workspace_client.secrets.list_scopes.side_effect = [[SecretScope(name=SCOPE_NAME)]] + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + recon_conf.prompt_source() + + recon_conf.prompt_and_save_connection_details() + + +def test_configure_secrets_oracle_insert(mock_workspace_client): + # mock prompts for Oracle + prompts = MockPrompts( + { + r"Select the source": SOURCE_DICT["oracle"], + r"Enter Secret Scope name": SCOPE_NAME, + r"Do you want to create a new one?": "yes", + r"Enter User": "dummy", + r"Enter Password": "dummy", + r"Enter host": "dummy", + r"Enter port": "dummy", + r"Enter database/SID": "dummy", + } + ) + + mock_workspace_client.secrets.list_scopes.side_effect = [[SecretScope(name="scope_name")]] + + with patch( + "databricks.labs.remorph.helpers.recon_config_utils.ReconConfigPrompts._secret_key_exists", + return_value=False, + ): + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + recon_conf.prompt_source() + + recon_conf.prompt_and_save_connection_details() + + +def test_configure_secrets_invalid_source(mock_workspace_client): + prompts = MockPrompts( + { + r"Select the source": "3", + r"Enter Secret Scope name": SCOPE_NAME, + } + ) + + with patch( + "databricks.labs.remorph.helpers.recon_config_utils.ReconConfigPrompts._scope_exists", + return_value=True, + ): + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + with pytest.raises(ValueError, match="cannot get answer within 10 attempt"): + recon_conf.prompt_source() + + +def test_store_connection_secrets_exception(mock_workspace_client): + prompts = MockPrompts( + { + r"Do you want to overwrite `source_key`?": "no", + } + ) + + mock_workspace_client.secrets.get_secret.side_effect = ResourceDoesNotExist("Not Found") + mock_workspace_client.secrets.put_secret.side_effect = Exception("Timed out") + + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + + with pytest.raises(Exception, match="Timed out"): + recon_conf.store_connection_secrets("scope_name", ("source", {"key": "value"})) + + +def test_configure_secrets_no_scope(mock_workspace_client): + prompts = MockPrompts( + { + r"Select the source": SOURCE_DICT["snowflake"], + r"Enter Secret Scope name": SCOPE_NAME, + r"Do you want to create a new one?": "no", + } + ) + + mock_workspace_client.secrets.list_scopes.side_effect = [[SecretScope(name="scope_name")]] + + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + recon_conf.prompt_source() + + with pytest.raises(SystemExit, match="Scope is needed to store Secrets in Databricks Workspace"): + recon_conf.prompt_and_save_connection_details() + + +def test_configure_secrets_create_scope_exception(mock_workspace_client): + prompts = MockPrompts( + { + r"Select the source": SOURCE_DICT["snowflake"], + r"Enter Secret Scope name": SCOPE_NAME, + r"Do you want to create a new one?": "yes", + } + ) + + mock_workspace_client.secrets.list_scopes.side_effect = [[SecretScope(name="scope_name")]] + mock_workspace_client.secrets.create_scope.side_effect = Exception("Network Error") + + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + recon_conf.prompt_source() + + with pytest.raises(Exception, match="Network Error"): + recon_conf.prompt_and_save_connection_details() + + +def test_store_connection_secrets_overwrite(mock_workspace_client): + prompts = MockPrompts( + { + r"Do you want to overwrite `key`?": "no", + } + ) + + with patch( + 
"databricks.labs.remorph.helpers.recon_config_utils.ReconConfigPrompts._secret_key_exists", return_value=True + ): + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + recon_conf.store_connection_secrets("scope_name", ("source", {"key": "value"})) diff --git a/tests/unit/helpers/test_validation.py b/tests/unit/helpers/test_validation.py new file mode 100644 index 0000000000..5016bd290a --- /dev/null +++ b/tests/unit/helpers/test_validation.py @@ -0,0 +1,96 @@ +from databricks.labs.lsql.backends import MockBackend +from databricks.labs.lsql.core import Row +from databricks.labs.remorph.helpers.validation import Validator + + +def test_valid_query(morph_config): + query = "SELECT * FROM a_table" + sql_backend = MockBackend( + rows={ + "EXPLAIN SELECT": [Row(plan="== Physical Plan ==")], + } + ) + validator = Validator(sql_backend) + validation_result = validator.validate_format_result(morph_config, query) + assert query in validation_result.validated_sql + assert validation_result.exception_msg is None + + +def test_query_with_syntax_error(morph_config): + query = "SELECT * a_table" + sql_backend = MockBackend( + fails_on_first={ + f"EXPLAIN {query}": "[PARSE_SYNTAX_ERROR] Syntax error at", + } + ) + validator = Validator(sql_backend) + validation_result = validator.validate_format_result(morph_config, query) + assert "Exception Start" in validation_result.validated_sql + assert "Syntax error" in validation_result.exception_msg + + +def test_query_with_analysis_error(morph_config): + error_types = [ + ("[TABLE_OR_VIEW_NOT_FOUND]", True), + ("[TABLE_OR_VIEW_ALREADY_EXISTS]", True), + ("[UNRESOLVED_ROUTINE]", False), + ("Hive support is required to CREATE Hive TABLE (AS SELECT).;", True), + ("Some other analysis error", False), + ] + + for err, should_succeed in error_types: + query = "SELECT * FROM a_table" + sql_backend = MockBackend( + fails_on_first={ + f"EXPLAIN {query}": err, + } + ) + validator = Validator(sql_backend) + validation_result = validator.validate_format_result(morph_config, query) + if should_succeed: + assert query in validation_result.validated_sql + assert "[WARNING]:" in validation_result.exception_msg + else: + assert err in validation_result.exception_msg + + +def test_validate_format_result_with_valid_query(morph_config): + query = "SELECT current_timestamp()" + sql_backend = MockBackend( + rows={ + "EXPLAIN SELECT": [Row(plan="== Physical Plan ==")], + } + ) + validator = Validator(sql_backend) + validation_result = validator.validate_format_result(morph_config, query) + assert query in validation_result.validated_sql + assert validation_result.exception_msg is None + + +def test_validate_format_result_with_invalid_query(morph_config): + query = "SELECT fn() FROM tab" + sql_backend = MockBackend( + rows={ + "EXPLAIN SELECT": [ + Row(plan="Error occurred during query planning:"), + Row(plan="[UNRESOLVED_ROUTINE] Cannot resolve function"), + ], + } + ) + validator = Validator(sql_backend) + validation_result = validator.validate_format_result(morph_config, query) + assert "Exception Start" in validation_result.validated_sql + assert "[UNRESOLVED_ROUTINE]" in validation_result.exception_msg + + +def test_validate_with_no_rows_returned(morph_config): + query = "SELECT * FROM a_table" + sql_backend = MockBackend( + rows={ + "EXPLAIN SELECT": [], + } + ) + validator = Validator(sql_backend) + validation_result = validator.validate_format_result(morph_config, query) + assert "Exception Start" in validation_result.validated_sql + assert "No results returned" 
in validation_result.exception_msg diff --git a/tests/unit/intermediate/test_dag.py b/tests/unit/intermediate/test_dag.py new file mode 100644 index 0000000000..b5e4cbcfb3 --- /dev/null +++ b/tests/unit/intermediate/test_dag.py @@ -0,0 +1,37 @@ +import pytest + +from databricks.labs.remorph.intermediate.dag import DAG + + +@pytest.fixture(scope="module") +def dag(): + d = DAG() + d.add_edge("parent_node", "child_node") + return d + + +def test_add_node(dag): + dag.add_node("test_node") + assert "test_node" in dag.nodes + + +def test_add_edge(dag): + dag.add_edge("edge_node", "node") + assert "edge_node" in dag.nodes + assert "node" in dag.nodes + assert dag.nodes["node"].name in dag.nodes["edge_node"].children + + +def test_identify_immediate_parents(dag): + parents = dag.identify_immediate_parents("child_node") + assert parents == ["parent_node"] + + +def test_identify_immediate_children(dag): + children = dag.identify_immediate_children("parent_node") + assert children == ["child_node"] + + +def test_identify_root_tables(dag): + root_tables = dag.identify_root_tables(0) + assert "parent_node" in root_tables diff --git a/tests/unit/intermediate/test_root_tables.py b/tests/unit/intermediate/test_root_tables.py new file mode 100644 index 0000000000..d64c09a0f4 --- /dev/null +++ b/tests/unit/intermediate/test_root_tables.py @@ -0,0 +1,48 @@ +import pytest + +from databricks.labs.remorph.intermediate.root_tables import RootTableIdentifier + + +@pytest.fixture(autouse=True) +def setup_file(tmpdir): + file = tmpdir.join("test.sql") + file.write( + """create table table1 select * from table2 inner join + table3 on table2.id = table3.id where table2.id in (select id from table4); + create table table2 select * from table4; + create table table5 select * from table3 join table4 on table3.id = table4.id ; + """ + ) + return file + + +def test_generate_lineage(tmpdir): + root_table_identifier = RootTableIdentifier("snowflake", str(tmpdir)) + dag = root_table_identifier.generate_lineage() + roots = ["table2", "table3", "table4"] + + assert len(dag.nodes["table4"].parents) == 0 + assert len(dag.identify_immediate_children("table3")) == 2 + assert dag.identify_immediate_parents("table1") == roots + assert dag.identify_root_tables(0) == {"table3", "table4"} + assert dag.identify_root_tables(2) == {"table1"} + assert dag.identify_immediate_parents("none") == [] + + +def test_generate_lineage_sql_file(setup_file): + root_table_identifier = RootTableIdentifier("snowflake", str(setup_file)) + dag = root_table_identifier.generate_lineage(engine="sqlglot") + roots = ["table2", "table3", "table4"] + + assert len(dag.nodes["table4"].parents) == 0 + assert len(dag.identify_immediate_children("table3")) == 2 + assert dag.identify_immediate_parents("table1") == roots + assert dag.identify_root_tables(0) == {"table3", "table4"} + assert dag.identify_root_tables(2) == {"table1"} + assert dag.identify_immediate_children("none") == [] + + +def test_non_sqlglot_engine_raises_error(tmpdir): + root_table_identifier = RootTableIdentifier("snowflake", str(tmpdir)) + with pytest.raises(ValueError): + root_table_identifier.generate_lineage(engine="antlr") diff --git a/tests/unit/no_cheat.py b/tests/unit/no_cheat.py new file mode 100644 index 0000000000..e0979c19f1 --- /dev/null +++ b/tests/unit/no_cheat.py @@ -0,0 +1,37 @@ +import sys +from pathlib import Path + +DISABLE_TAG = '# pylint: disable=' + + +def no_cheat(diff_text: str) -> str: + lines = diff_text.split('\n') + removed: dict[str, int] = {} + added: dict[str, int] 
= {} + for line in lines: + if not (line.startswith("-") or line.startswith("+")): + continue + idx = line.find(DISABLE_TAG) + if idx < 0: + continue + codes = line[idx + len(DISABLE_TAG) :].split(',') + for code in codes: + code = code.strip().strip('\n').strip('"').strip("'") + if line.startswith("-"): + removed[code] = removed.get(code, 0) + 1 + continue + added[code] = added.get(code, 0) + 1 + results: list[str] = [] + for code, count in added.items(): + count -= removed.get(code, 0) + if count > 0: + results.append(f"Do not cheat the linter: found {count} additional {DISABLE_TAG}{code}") + return '\n'.join(results) + + +if __name__ == "__main__": + diff_data = sys.argv[1] + path = Path(diff_data) + if path.exists(): + diff_data = path.read_text("utf-8") + print(no_cheat(diff_data)) diff --git a/tests/unit/reconcile/__init__.py b/tests/unit/reconcile/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/reconcile/connectors/test_databricks.py b/tests/unit/reconcile/connectors/test_databricks.py new file mode 100644 index 0000000000..30094d1479 --- /dev/null +++ b/tests/unit/reconcile/connectors/test_databricks.py @@ -0,0 +1,116 @@ +import re +from unittest.mock import MagicMock, create_autospec + +import pytest + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource +from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException +from databricks.sdk import WorkspaceClient + + +def initial_setup(): + pyspark_sql_session = MagicMock() + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + # Define the source, workspace, and scope + engine = get_dialect("databricks") + ws = create_autospec(WorkspaceClient) + scope = "scope" + return engine, spark, ws, scope + + +def test_get_schema(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # catalog as catalog + dd = DatabricksDataSource(engine, spark, ws, scope) + dd.get_schema("catalog", "schema", "supplier") + spark.sql.assert_called_with( + re.sub( + r'\s+', + ' ', + """select lower(column_name) as col_name, full_data_type as data_type from + catalog.information_schema.columns where lower(table_catalog)='catalog' + and lower(table_schema)='schema' and lower(table_name) ='supplier' order by + col_name""", + ) + ) + spark.sql().where.assert_called_with("col_name not like '#%'") + + # hive_metastore as catalog + dd.get_schema("hive_metastore", "schema", "supplier") + spark.sql.assert_called_with(re.sub(r'\s+', ' ', """describe table hive_metastore.schema.supplier""")) + spark.sql().where.assert_called_with("col_name not like '#%'") + + # global_temp as schema with hive_metastore + dd.get_schema("hive_metastore", "global_temp", "supplier") + spark.sql.assert_called_with(re.sub(r'\s+', ' ', """describe table global_temp.supplier""")) + spark.sql().where.assert_called_with("col_name not like '#%'") + + +def test_read_data_from_uc(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for DatabricksDataSource + dd = DatabricksDataSource(engine, spark, ws, scope) + + # Test with query + dd.read_data("org", "data", "employee", "select id as id, name as name from :tbl", None) + spark.sql.assert_called_with("select id as id, name as name from org.data.employee") + + # global_temp as schema with UC catalog + dd.read_data("org", "global_temp", "employee", "select id as id, name as name from :tbl", None) + spark.sql.assert_called_with("select id as id, 
name as name from global_temp.employee") + + +def test_read_data_from_hive(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for DatabricksDataSource + dd = DatabricksDataSource(engine, spark, ws, scope) + + # Test with query + dd.read_data("hive_metastore", "data", "employee", "select id as id, name as name from :tbl", None) + spark.sql.assert_called_with("select id as id, name as name from hive_metastore.data.employee") + + # global_temp as schema with hive_metastore + dd.read_data("hive_metastore", "global_temp", "employee", "select id as id, name as name from :tbl", None) + spark.sql.assert_called_with("select id as id, name as name from global_temp.employee") + + +def test_read_data_exception_handling(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for DatabricksDataSource + dd = DatabricksDataSource(engine, spark, ws, scope) + spark.sql.side_effect = RuntimeError("Test Exception") + + with pytest.raises( + DataSourceRuntimeException, + match="Runtime exception occurred while fetching data using select id as id, ename as name from " + "org.data.employee : Test Exception", + ): + dd.read_data("org", "data", "employee", "select id as id, ename as name from :tbl", None) + + +def test_get_schema_exception_handling(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for DatabricksDataSource + dd = DatabricksDataSource(engine, spark, ws, scope) + spark.sql.side_effect = RuntimeError("Test Exception") + with pytest.raises(DataSourceRuntimeException) as exception: + dd.get_schema("org", "data", "employee") + + assert str(exception.value) == ( + "Runtime exception occurred while fetching schema using select lower(column_name) " + "as col_name, full_data_type as data_type from org.information_schema.columns " + "where lower(table_catalog)='org' and lower(table_schema)='data' and lower(" + "table_name) ='employee' order by col_name : Test Exception" + ) diff --git a/tests/unit/reconcile/connectors/test_mock_data_source.py b/tests/unit/reconcile/connectors/test_mock_data_source.py new file mode 100644 index 0000000000..0ca2888168 --- /dev/null +++ b/tests/unit/reconcile/connectors/test_mock_data_source.py @@ -0,0 +1,105 @@ +import pytest +from pyspark import Row +from pyspark.testing import assertDataFrameEqual + +from databricks.labs.remorph.reconcile.connectors.data_source import MockDataSource +from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException +from databricks.labs.remorph.reconcile.recon_config import Schema + +catalog = "org" +schema = "data" +table = "employee" + + +def test_mock_data_source_happy(mock_spark): + dataframe_repository = { + ( + "org", + "data", + "select * from employee", + ): mock_spark.createDataFrame( + [ + Row(emp_id="1", emp_name="name-1", sal=100), + Row(emp_id="2", emp_name="name-2", sal=200), + Row(emp_id="3", emp_name="name-3", sal=300), + ] + ) + } + schema_repository = { + (catalog, schema, table): [ + Schema(column_name="emp_id", data_type="int"), + Schema(column_name="emp_name", data_type="str"), + Schema(column_name="sal", data_type="int"), + ] + } + + data_source = MockDataSource(dataframe_repository, schema_repository) + + actual_data = data_source.read_data(catalog, schema, table, "select * from employee", None) + expected_data = mock_spark.createDataFrame( + [ + Row(emp_id="1", emp_name="name-1", sal=100), + Row(emp_id="2", emp_name="name-2", sal=200), + Row(emp_id="3", emp_name="name-3", sal=300), + ] + ) + + 
actual_schema = data_source.get_schema(catalog, schema, table) + assertDataFrameEqual(actual_data, expected_data) + assert actual_schema == [ + Schema(column_name="emp_id", data_type="int"), + Schema(column_name="emp_name", data_type="str"), + Schema(column_name="sal", data_type="int"), + ] + + +def test_mock_data_source_fail(mock_spark): + data_source = MockDataSource({}, {}, Exception("TABLE NOT FOUND")) + with pytest.raises( + DataSourceRuntimeException, + match="Runtime exception occurred while fetching data using \\(org, data, select \\* from test\\) : TABLE" + " NOT FOUND", + ): + data_source.read_data(catalog, schema, table, "select * from test", None) + + with pytest.raises( + DataSourceRuntimeException, + match="Runtime exception occurred while fetching schema using \\(org, data, unknown\\) : TABLE NOT FOUND", + ): + data_source.get_schema(catalog, schema, "unknown") + + +def test_mock_data_source_no_catalog(mock_spark): + dataframe_repository = { + ( + "", + "data", + "select * from employee", + ): mock_spark.createDataFrame( + [ + Row(emp_id="1", emp_name="name-1", sal=100), + Row(emp_id="2", emp_name="name-2", sal=200), + Row(emp_id="3", emp_name="name-3", sal=300), + ] + ) + } + schema_repository = { + (catalog, schema, table): [ + Schema(column_name="emp_id", data_type="int"), + Schema(column_name="emp_name", data_type="str"), + Schema(column_name="sal", data_type="int"), + ] + } + + data_source = MockDataSource(dataframe_repository, schema_repository) + + actual_data = data_source.read_data(None, schema, table, "select * from employee", None) + expected_data = mock_spark.createDataFrame( + [ + Row(emp_id="1", emp_name="name-1", sal=100), + Row(emp_id="2", emp_name="name-2", sal=200), + Row(emp_id="3", emp_name="name-3", sal=300), + ] + ) + + assertDataFrameEqual(actual_data, expected_data) diff --git a/tests/unit/reconcile/connectors/test_oracle.py b/tests/unit/reconcile/connectors/test_oracle.py new file mode 100644 index 0000000000..2714eb3b2f --- /dev/null +++ b/tests/unit/reconcile/connectors/test_oracle.py @@ -0,0 +1,177 @@ +import base64 +import re +from unittest.mock import MagicMock, create_autospec + +import pytest + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.connectors.oracle import OracleDataSource +from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Table +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.workspace import GetSecretResponse + + +def mock_secret(scope, key): + secret_mock = { + "scope": { + 'user': GetSecretResponse(key='user', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8')), + 'password': GetSecretResponse( + key='password', value=base64.b64encode(bytes('my_password', 'utf-8')).decode('utf-8') + ), + 'host': GetSecretResponse(key='host', value=base64.b64encode(bytes('my_host', 'utf-8')).decode('utf-8')), + 'port': GetSecretResponse(key='port', value=base64.b64encode(bytes('777', 'utf-8')).decode('utf-8')), + 'database': GetSecretResponse( + key='database', value=base64.b64encode(bytes('my_database', 'utf-8')).decode('utf-8') + ), + } + } + + return secret_mock[scope][key] + + +def initial_setup(): + pyspark_sql_session = MagicMock() + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + # Define the source, workspace, and scope + engine = get_dialect("oracle") + ws = create_autospec(WorkspaceClient) + scope = "scope" + 
ws.secrets.get_secret.side_effect = mock_secret + return engine, spark, ws, scope + + +def test_read_data_with_options(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for SnowflakeDataSource + ds = OracleDataSource(engine, spark, ws, scope) + # Create a Tables configuration object with JDBC reader options + table_conf = Table( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=JdbcReaderOptions( + number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100" + ), + join_columns=None, + select_columns=None, + drop_columns=None, + column_mapping=None, + transformations=None, + column_thresholds=None, + filters=None, + ) + + # Call the read_data method with the Tables configuration + ds.read_data(None, "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + + # spark assertions + spark.read.format.assert_called_with("jdbc") + spark.read.format().option.assert_called_with( + "url", + "jdbc:oracle:thin:my_user/my_password@//my_host:777/my_database", + ) + spark.read.format().option().option.assert_called_with("driver", "oracle.jdbc.driver.OracleDriver") + spark.read.format().option().option().option.assert_called_with("dbtable", "(select 1 from data.employee) tmp") + actual_args = spark.read.format().option().option().option().options.call_args.kwargs + expected_args = { + "numPartitions": 100, + "partitionColumn": "s_nationkey", + "lowerBound": '0', + "upperBound": "100", + "fetchsize": 100, + "oracle.jdbc.mapDateToTimestamp": "False", + "sessionInitStatement": r"BEGIN dbms_session.set_nls('nls_date_format', " + r"'''YYYY-MM-DD''');dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD " + r"HH24:MI:SS''');END;", + } + assert actual_args == expected_args + spark.read.format().option().option().option().options().load.assert_called_once() + + +def test_get_schema(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for SnowflakeDataSource + ds = OracleDataSource(engine, spark, ws, scope) + # call test method + ds.get_schema(None, "data", "employee") + # spark assertions + spark.read.format.assert_called_with("jdbc") + spark.read.format().option().option().option.assert_called_with( + "dbtable", + re.sub( + r'\s+', + ' ', + r"""(select column_name, case when (data_precision is not null + and data_scale <> 0) + then data_type || '(' || data_precision || ',' || data_scale || ')' + when (data_precision is not null and data_scale = 0) + then data_type || '(' || data_precision || ')' + when data_precision is null and (lower(data_type) in ('date') or + lower(data_type) like 'timestamp%') then data_type + when CHAR_LENGTH == 0 then data_type + else data_type || '(' || CHAR_LENGTH || ')' + end data_type + FROM ALL_TAB_COLUMNS + WHERE lower(TABLE_NAME) = 'employee' and lower(owner) = 'data') tmp""", + ), + ) + + +def test_read_data_exception_handling(): + # initial setup + engine, spark, ws, scope = initial_setup() + ds = OracleDataSource(engine, spark, ws, scope) + # Create a Tables configuration object + table_conf = Table( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=None, + join_columns=None, + select_columns=None, + drop_columns=None, + column_mapping=None, + transformations=None, + column_thresholds=None, + filters=None, + ) + + spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception") + + # Call the read_data method with the Tables configuration and assert that a 
PySparkException is raised + with pytest.raises( + DataSourceRuntimeException, + match="Runtime exception occurred while fetching data using select 1 from data.employee : Test Exception", + ): + ds.read_data(None, "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + + +def test_get_schema_exception_handling(): + # initial setup + engine, spark, ws, scope = initial_setup() + ds = OracleDataSource(engine, spark, ws, scope) + + spark.read.format().option().option().option().load.side_effect = RuntimeError("Test Exception") + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised + with pytest.raises( + DataSourceRuntimeException, + match=r"""select column_name, case when (data_precision is not null + and data_scale <> 0) + then data_type || '(' || data_precision || ',' || data_scale || ')' + when (data_precision is not null and data_scale = 0) + then data_type || '(' || data_precision || ')' + when data_precision is null and (lower(data_type) in ('date') or + lower(data_type) like 'timestamp%') then data_type + when CHAR_LENGTH == 0 then data_type + else data_type || '(' || CHAR_LENGTH || ')' + end data_type + FROM ALL_TAB_COLUMNS + WHERE lower(TABLE_NAME) = 'employee' and lower(owner) = 'data' """, + ): + ds.get_schema(None, "data", "employee") diff --git a/tests/unit/reconcile/connectors/test_secrets.py b/tests/unit/reconcile/connectors/test_secrets.py new file mode 100644 index 0000000000..7cb082ad8f --- /dev/null +++ b/tests/unit/reconcile/connectors/test_secrets.py @@ -0,0 +1,49 @@ +import base64 +from unittest.mock import create_autospec + +import pytest + +from databricks.labs.remorph.reconcile.connectors.secrets import SecretsMixin +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound +from databricks.sdk.service.workspace import GetSecretResponse + + +class Test(SecretsMixin): + def __init__(self, ws: WorkspaceClient, secret_scope: str): + self._ws = ws + self._secret_scope = secret_scope + + +def mock_secret(scope, key): + secret_mock = { + "scope": { + 'user_name': GetSecretResponse( + key='user_name', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8') + ), + 'password': GetSecretResponse( + key='password', value=base64.b64encode(bytes('my_password', 'utf-8')).decode('utf-8') + ), + } + } + + return secret_mock.get(scope).get(key) + + +def test_get_secrets_happy(): + ws = create_autospec(WorkspaceClient) + ws.secrets.get_secret.side_effect = mock_secret + + mock = Test(ws, "scope") + + assert mock._get_secret("user_name") == "my_user" + assert mock._get_secret("password") == "my_password" + + +def test_get_secrets_not_found_exception(): + ws = create_autospec(WorkspaceClient) + ws.secrets.get_secret.side_effect = NotFound("Test Exception") + mock = Test(ws, "scope") + + with pytest.raises(NotFound, match="Secret does not exist with scope: scope and key: unknown : Test Exception"): + mock._get_secret("unknown") diff --git a/tests/unit/reconcile/connectors/test_snowflake.py b/tests/unit/reconcile/connectors/test_snowflake.py new file mode 100644 index 0000000000..8b286e259d --- /dev/null +++ b/tests/unit/reconcile/connectors/test_snowflake.py @@ -0,0 +1,320 @@ +import base64 +import re +from unittest.mock import MagicMock, create_autospec + +import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives import serialization +from databricks.labs.remorph.config import get_dialect +from 
databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource +from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException, InvalidSnowflakePemPrivateKey +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Table +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.workspace import GetSecretResponse +from databricks.sdk.errors import NotFound + + +def mock_secret(scope, key): + secret_mock = { + "scope": { + 'sfAccount': GetSecretResponse( + key='sfAccount', value=base64.b64encode(bytes('my_account', 'utf-8')).decode('utf-8') + ), + 'sfUser': GetSecretResponse( + key='sfUser', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8') + ), + 'sfPassword': GetSecretResponse( + key='sfPassword', value=base64.b64encode(bytes('my_password', 'utf-8')).decode('utf-8') + ), + 'sfDatabase': GetSecretResponse( + key='sfDatabase', value=base64.b64encode(bytes('my_database', 'utf-8')).decode('utf-8') + ), + 'sfSchema': GetSecretResponse( + key='sfSchema', value=base64.b64encode(bytes('my_schema', 'utf-8')).decode('utf-8') + ), + 'sfWarehouse': GetSecretResponse( + key='sfWarehouse', value=base64.b64encode(bytes('my_warehouse', 'utf-8')).decode('utf-8') + ), + 'sfRole': GetSecretResponse( + key='sfRole', value=base64.b64encode(bytes('my_role', 'utf-8')).decode('utf-8') + ), + 'sfUrl': GetSecretResponse(key='sfUrl', value=base64.b64encode(bytes('my_url', 'utf-8')).decode('utf-8')), + } + } + + return secret_mock[scope][key] + + +def generate_pkcs8_pem_key(malformed=False): + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + pem_key = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode('utf-8') + return pem_key[:50] + "MALFORMED" + pem_key[60:] if malformed else pem_key + + +def mock_private_key_secret(scope, key): + if key == 'pem_private_key': + return GetSecretResponse(key=key, value=base64.b64encode(generate_pkcs8_pem_key().encode()).decode()) + return mock_secret(scope, key) + + +def mock_malformed_private_key_secret(scope, key): + if key == 'pem_private_key': + return GetSecretResponse(key=key, value=base64.b64encode(generate_pkcs8_pem_key(True).encode()).decode()) + return mock_secret(scope, key) + + +def mock_no_auth_key_secret(scope, key): + if key == 'pem_private_key' or key == 'sfPassword': + raise NotFound("Secret not found") + return mock_secret(scope, key) + + +def initial_setup(): + pyspark_sql_session = MagicMock() + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + # Define the source, workspace, and scope + engine = get_dialect("snowflake") + ws = create_autospec(WorkspaceClient) + scope = "scope" + ws.secrets.get_secret.side_effect = mock_secret + return engine, spark, ws, scope + + +def test_get_jdbc_url_happy(): + # initial setup + engine, spark, ws, scope = initial_setup() + # create object for SnowflakeDataSource + ds = SnowflakeDataSource(engine, spark, ws, scope) + url = ds.get_jdbc_url + # Assert that the URL is generated correctly + assert url == ( + "jdbc:snowflake://my_account.snowflakecomputing.com" + "/?user=my_user&password=my_password" + "&db=my_database&schema=my_schema" + "&warehouse=my_warehouse&role=my_role" + ) + + +def test_get_jdbc_url_fail(): + # initial setup + engine, spark, ws, scope = initial_setup() + ws.secrets.get_secret.side_effect = mock_secret + # create object for SnowflakeDataSource + ds = 
SnowflakeDataSource(engine, spark, ws, scope) + url = ds.get_jdbc_url + # Assert that the URL is generated correctly + assert url == ( + "jdbc:snowflake://my_account.snowflakecomputing.com" + "/?user=my_user&password=my_password" + "&db=my_database&schema=my_schema" + "&warehouse=my_warehouse&role=my_role" + ) + + +def test_read_data_with_out_options(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for SnowflakeDataSource + ds = SnowflakeDataSource(engine, spark, ws, scope) + # Create a Tables configuration object with no JDBC reader options + table_conf = Table( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=None, + join_columns=None, + select_columns=None, + drop_columns=None, + column_mapping=None, + transformations=None, + column_thresholds=None, + filters=None, + ) + + # Call the read_data method with the Tables configuration + ds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + + # spark assertions + spark.read.format.assert_called_with("snowflake") + spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp") + spark.read.format().option().options.assert_called_with( + sfUrl="my_url", + sfUser="my_user", + sfPassword="my_password", + sfDatabase="my_database", + sfSchema="my_schema", + sfWarehouse="my_warehouse", + sfRole="my_role", + ) + spark.read.format().option().options().load.assert_called_once() + + +def test_read_data_with_options(): + # initial setup + engine, spark, ws, scope = initial_setup() + + # create object for SnowflakeDataSource + ds = SnowflakeDataSource(engine, spark, ws, scope) + # Create a Tables configuration object with JDBC reader options + table_conf = Table( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=JdbcReaderOptions( + number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100" + ), + join_columns=None, + select_columns=None, + drop_columns=None, + column_mapping=None, + transformations=None, + column_thresholds=None, + filters=None, + ) + + # Call the read_data method with the Tables configuration + ds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + + # spark assertions + spark.read.format.assert_called_with("jdbc") + spark.read.format().option.assert_called_with( + "url", + "jdbc:snowflake://my_account.snowflakecomputing.com/?user=my_user&password=" + "my_password&db=my_database&schema=my_schema&warehouse=my_warehouse&role=my_role", + ) + spark.read.format().option().option.assert_called_with("driver", "net.snowflake.client.jdbc.SnowflakeDriver") + spark.read.format().option().option().option.assert_called_with("dbtable", "(select 1 from org.data.employee) tmp") + spark.read.format().option().option().option().options.assert_called_with( + numPartitions=100, partitionColumn='s_nationkey', lowerBound='0', upperBound='100', fetchsize=100 + ) + spark.read.format().option().option().option().options().load.assert_called_once() + + +def test_get_schema(): + # initial setup + engine, spark, ws, scope = initial_setup() + # Mocking get secret method to return the required values + # create object for SnowflakeDataSource + ds = SnowflakeDataSource(engine, spark, ws, scope) + # call test method + ds.get_schema("catalog", "schema", "supplier") + # spark assertions + spark.read.format.assert_called_with("snowflake") + spark.read.format().option.assert_called_with( + "dbtable", + re.sub( + r'\s+', + ' ', + 
"""(select column_name, case when numeric_precision is not null and numeric_scale is not null then + concat(data_type, '(', numeric_precision, ',' , numeric_scale, ')') when lower(data_type) = 'text' then + concat('varchar', '(', CHARACTER_MAXIMUM_LENGTH, ')') else data_type end as data_type from + catalog.INFORMATION_SCHEMA.COLUMNS where lower(table_name)='supplier' and table_schema = 'SCHEMA' + order by ordinal_position) as tmp""", + ), + ) + spark.read.format().option().options.assert_called_with( + sfUrl="my_url", + sfUser="my_user", + sfPassword="my_password", + sfDatabase="my_database", + sfSchema="my_schema", + sfWarehouse="my_warehouse", + sfRole="my_role", + ) + spark.read.format().option().options().load.assert_called_once() + + +def test_read_data_exception_handling(): + # initial setup + engine, spark, ws, scope = initial_setup() + ds = SnowflakeDataSource(engine, spark, ws, scope) + # Create a Tables configuration object + table_conf = Table( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=None, + join_columns=None, + select_columns=None, + drop_columns=None, + column_mapping=None, + transformations=None, + column_thresholds=None, + filters=None, + ) + + spark.read.format().option().options().load.side_effect = RuntimeError("Test Exception") + + # Call the read_data method with the Tables configuration and assert that a PySparkException is raised + with pytest.raises( + DataSourceRuntimeException, + match="Runtime exception occurred while fetching data using select 1 from org.data.employee : Test Exception", + ): + ds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + + +def test_get_schema_exception_handling(): + # initial setup + engine, spark, ws, scope = initial_setup() + + ds = SnowflakeDataSource(engine, spark, ws, scope) + + spark.read.format().option().options().load.side_effect = RuntimeError("Test Exception") + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised + with pytest.raises( + DataSourceRuntimeException, + match=r"Runtime exception occurred while fetching schema using select column_name, case when numeric_precision " + "is not null and numeric_scale is not null then concat\\(data_type, '\\(', numeric_precision, ',' , " + "numeric_scale, '\\)'\\) when lower\\(data_type\\) = 'text' then concat\\('varchar', '\\(', " + "CHARACTER_MAXIMUM_LENGTH, '\\)'\\) else data_type end as data_type from catalog.INFORMATION_SCHEMA.COLUMNS " + "where lower\\(table_name\\)='supplier' and table_schema = 'SCHEMA' order by ordinal_position : Test " + "Exception", + ): + ds.get_schema("catalog", "schema", "supplier") + + +def test_read_data_with_out_options_private_key(): + engine, spark, ws, scope = initial_setup() + ws.secrets.get_secret.side_effect = mock_private_key_secret + ds = SnowflakeDataSource(engine, spark, ws, scope) + table_conf = Table(source_name="supplier", target_name="supplier") + ds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + spark.read.format.assert_called_with("snowflake") + spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp") + expected_options = { + "sfUrl": "my_url", + "sfUser": "my_user", + "sfDatabase": "my_database", + "sfSchema": "my_schema", + "sfWarehouse": "my_warehouse", + "sfRole": "my_role", + } + actual_options = spark.read.format().option().options.call_args[1] + actual_options.pop("pem_private_key", None) + 
assert actual_options == expected_options + spark.read.format().option().options().load.assert_called_once() + + +def test_read_data_with_out_options_malformed_private_key(): + engine, spark, ws, scope = initial_setup() + ws.secrets.get_secret.side_effect = mock_malformed_private_key_secret + ds = SnowflakeDataSource(engine, spark, ws, scope) + table_conf = Table(source_name="supplier", target_name="supplier") + with pytest.raises(InvalidSnowflakePemPrivateKey, match="Failed to load or process the provided PEM private key."): + ds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + + +def test_read_data_with_out_any_auth(): + engine, spark, ws, scope = initial_setup() + ws.secrets.get_secret.side_effect = mock_no_auth_key_secret + ds = SnowflakeDataSource(engine, spark, ws, scope) + table_conf = Table(source_name="supplier", target_name="supplier") + with pytest.raises( + NotFound, match='sfPassword and pem_private_key not found. Either one is required for snowflake auth.' + ): + ds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) diff --git a/tests/unit/reconcile/query_builder/__init__.py b/tests/unit/reconcile/query_builder/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/reconcile/query_builder/test_count_query.py b/tests/unit/reconcile/query_builder/test_count_query.py new file mode 100644 index 0000000000..1b03b01fee --- /dev/null +++ b/tests/unit/reconcile/query_builder/test_count_query.py @@ -0,0 +1,13 @@ +from databricks.labs.remorph.reconcile.query_builder.count_query import CountQueryBuilder +from databricks.labs.remorph.config import get_dialect + + +def test_count_query(table_conf_with_opts): + source_query = CountQueryBuilder( + table_conf=table_conf_with_opts, layer="source", engine=get_dialect("oracle") + ).build_query() + target_query = CountQueryBuilder( + table_conf=table_conf_with_opts, layer="target", engine=get_dialect("databricks") + ).build_query() + assert source_query == "SELECT COUNT(1) AS count FROM :tbl WHERE s_name = 't' AND s_address = 'a'" + assert target_query == "SELECT COUNT(1) AS count FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'" diff --git a/tests/unit/reconcile/query_builder/test_expression_generator.py b/tests/unit/reconcile/query_builder/test_expression_generator.py new file mode 100644 index 0000000000..d068fe6935 --- /dev/null +++ b/tests/unit/reconcile/query_builder/test_expression_generator.py @@ -0,0 +1,229 @@ +import pytest +from sqlglot import expressions as exp +from sqlglot import parse_one +from sqlglot.expressions import Column + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.query_builder.expression_generator import ( + array_sort, + array_to_string, + build_between, + build_column, + build_from_clause, + build_if, + build_join_clause, + build_literal, + build_sub, + build_where_clause, + coalesce, + concat, + get_hash_transform, + json_format, + lower, + sha2, + sort_array, + to_char, + trim, +) + + +def test_coalesce(expr): + assert coalesce(expr, "NA", True).sql() == "SELECT COALESCE(col1, 'NA') FROM DUAL" + assert coalesce(expr, "0", False).sql() == "SELECT COALESCE(col1, 0) FROM DUAL" + assert coalesce(expr).sql() == "SELECT COALESCE(col1, 0) FROM DUAL" + + +def test_trim(expr): + assert trim(expr).sql() == "SELECT TRIM(col1) FROM DUAL" + + nested_expr = parse_one("select coalesce(col1,' ') FROM DUAL") + assert trim(nested_expr).sql() == "SELECT 
COALESCE(TRIM(col1), ' ') FROM DUAL" + + +def test_json_format(): + expr = parse_one("SELECT col1 FROM DUAL") + + assert json_format(expr).sql() == "SELECT JSON_FORMAT(col1) FROM DUAL" + assert json_format(expr).sql(dialect="databricks") == "SELECT TO_JSON(col1) FROM DUAL" + assert json_format(expr).sql(dialect="snowflake") == "SELECT JSON_FORMAT(col1) FROM DUAL" + + +def test_sort_array(expr): + assert sort_array(expr).sql() == "SELECT SORT_ARRAY(col1, TRUE) FROM DUAL" + assert sort_array(expr, asc=False).sql() == "SELECT SORT_ARRAY(col1, FALSE) FROM DUAL" + + +def test_to_char(expr): + assert to_char(expr).sql(dialect="oracle") == "SELECT TO_CHAR(col1) FROM DUAL" + assert to_char(expr, to_format='YYYY-MM-DD').sql(dialect="oracle") == "SELECT TO_CHAR(col1, 'YYYY-MM-DD') FROM DUAL" + + +def test_array_to_string(expr): + assert array_to_string(expr).sql() == "SELECT ARRAY_TO_STRING(col1, ',') FROM DUAL" + assert array_to_string(expr, null_replacement='NA').sql() == "SELECT ARRAY_TO_STRING(col1, ',', 'NA') FROM DUAL" + + +def test_array_sort(expr): + assert array_sort(expr).sql() == "SELECT ARRAY_SORT(col1, TRUE) FROM DUAL" + assert array_sort(expr, asc=False).sql() == "SELECT ARRAY_SORT(col1, FALSE) FROM DUAL" + + +def test_build_column(): + # test build_column without alias and column as str expr + assert build_column(this="col1") == exp.Column(this=exp.Identifier(this="col1", quoted=False), table="") + + # test build_column with alias and column as str expr + assert build_column(this="col1", alias="col1_aliased") == exp.Alias( + this=exp.Column(this="col1", table=""), alias=exp.Identifier(this="col1_aliased", quoted=False) + ) + + # test build_column with alias and column as exp.Column expr + assert build_column( + this=exp.Column(this=exp.Identifier(this="col1", quoted=False), table=""), alias="col1_aliased" + ) == exp.Alias( + this=exp.Column(this=exp.Identifier(this="col1", quoted=False), table=""), + alias=exp.Identifier(this="col1_aliased", quoted=False), + ) + + # with table name + result = build_column(this="test_column", alias="test_alias", table_name="test_table") + assert str(result) == "test_table.test_column AS test_alias" + + +def test_build_literal(): + actual = build_literal(this="abc") + expected = exp.Literal(this="abc", is_string=True) + + assert actual == expected + + +def test_sha2(expr): + assert sha2(expr, num_bits="256").sql() == "SELECT SHA2(col1, 256) FROM DUAL" + assert ( + sha2(Column(this="CONCAT(col1,col2,col3)"), num_bits="256", is_expr=True).sql() + == "SHA2(CONCAT(col1,col2,col3), 256)" + ) + + +def test_concat(): + exprs = [exp.Expression(this="col1"), exp.Expression(this="col2")] + assert concat(exprs) == exp.Concat( + expressions=[exp.Expression(this="col1"), exp.Expression(this="col2")], safe=True + ) + + +def test_lower(expr): + assert lower(expr).sql() == "SELECT LOWER(col1) FROM DUAL" + assert lower(Column(this="CONCAT(col1,col2,col3)"), is_expr=True).sql() == "LOWER(CONCAT(col1,col2,col3))" + + +def test_get_hash_transform(): + assert isinstance(get_hash_transform(get_dialect("snowflake"), "source"), list) is True + + with pytest.raises(ValueError): + get_hash_transform(get_dialect("trino"), "source") + + with pytest.raises(ValueError): + get_hash_transform(get_dialect("snowflake"), "sourc") + + +def test_build_from_clause(): + # with table alias + result = build_from_clause("test_table", "test_alias") + assert str(result) == "FROM test_table AS test_alias" + assert isinstance(result, exp.From) + assert result.this.this.this == "test_table" + assert 
result.this.alias == "test_alias" + + # without table alias + result = build_from_clause("test_table") + assert str(result) == "FROM test_table" + + +def test_build_join_clause(): + # with table alias + result = build_join_clause( + table_name="test_table", + join_columns=["test_column"], + source_table_alias="source", + target_table_alias="test_alias", + ) + assert str(result) == ( + "INNER JOIN test_table AS test_alias ON source.test_column IS NOT DISTINCT FROM test_alias.test_column" + ) + assert isinstance(result, exp.Join) + assert result.this.this.this == "test_table" + assert result.this.alias == "test_alias" + + # without table alias + result = build_join_clause("test_table", ["test_column"]) + assert str(result) == "INNER JOIN test_table ON test_column IS NOT DISTINCT FROM test_column" + + +def test_build_sub(): + # with table name + result = build_sub("left_column", "right_column", "left_table", "right_table") + assert str(result) == "left_table.left_column - right_table.right_column" + assert isinstance(result, exp.Sub) + assert result.this.this.this == "left_column" + assert result.this.table == "left_table" + assert result.expression.this.this == "right_column" + assert result.expression.table == "right_table" + + # without table name + result = build_sub("left_column", "right_column") + assert str(result) == "left_column - right_column" + + +def test_build_where_clause(): + # or condition + where_clause = [ + exp.EQ( + this=exp.Column(this="test_column", table="test_table"), expression=exp.Literal(this='1', is_string=False) + ) + ] + result = build_where_clause(where_clause) + assert str(result) == "(1 = 1 OR 1 = 1) OR test_table.test_column = 1" + assert isinstance(result, exp.Or) + + # and condition + where_clause = [ + exp.EQ( + this=exp.Column(this="test_column", table="test_table"), expression=exp.Literal(this='1', is_string=False) + ) + ] + result = build_where_clause(where_clause, "and") + assert str(result) == "(1 = 1 AND 1 = 1) AND test_table.test_column = 1" + assert isinstance(result, exp.And) + + +def test_build_if(): + # with true and false + result = build_if( + this=exp.EQ( + this=exp.Column(this="test_column", table="test_table"), expression=exp.Literal(this='1', is_string=False) + ), + true=exp.Literal(this='1', is_string=False), + false=exp.Literal(this='0', is_string=False), + ) + assert str(result) == "CASE WHEN test_table.test_column = 1 THEN 1 ELSE 0 END" + assert isinstance(result, exp.If) + + # without false + result = build_if( + this=exp.EQ( + this=exp.Column(this="test_column", table="test_table"), expression=exp.Literal(this='1', is_string=False) + ), + true=exp.Literal(this='1', is_string=False), + ) + assert str(result) == "CASE WHEN test_table.test_column = 1 THEN 1 END" + + +def test_build_between(): + result = build_between( + this=exp.Column(this="test_column", table="test_table"), + low=exp.Literal(this='1', is_string=False), + high=exp.Literal(this='2', is_string=False), + ) + assert str(result) == "test_table.test_column BETWEEN 1 AND 2" + assert isinstance(result, exp.Between) diff --git a/tests/unit/reconcile/query_builder/test_hash_query.py b/tests/unit/reconcile/query_builder/test_hash_query.py new file mode 100644 index 0000000000..dca12acaa5 --- /dev/null +++ b/tests/unit/reconcile/query_builder/test_hash_query.py @@ -0,0 +1,225 @@ +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.query_builder.hash_query import HashQueryBuilder +from databricks.labs.remorph.reconcile.recon_config import 
Filters, ColumnMapping, Transformation + + +def test_hash_query_builder_for_snowflake_src(table_conf_with_opts, table_schema): + sch, sch_with_alias = table_schema + src_actual = HashQueryBuilder(table_conf_with_opts, sch, "source", get_dialect("snowflake")).build_query( + report_type="data" + ) + src_expected = ( + "SELECT LOWER(SHA2(CONCAT(TRIM(s_address), TRIM(s_name), COALESCE(TRIM(s_nationkey), '_null_recon_'), " + "TRIM(s_phone), COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, s_nationkey AS " + "s_nationkey, " + "s_suppkey AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address = 'a'" + ) + + tgt_actual = HashQueryBuilder( + table_conf_with_opts, sch_with_alias, "target", get_dialect("databricks") + ).build_query(report_type="data") + tgt_expected = ( + "SELECT LOWER(SHA2(CONCAT(TRIM(s_address_t), TRIM(s_name), COALESCE(TRIM(s_nationkey_t), '_null_recon_'), " + "TRIM(s_phone_t), COALESCE(TRIM(s_suppkey_t), '_null_recon_')), 256)) AS hash_value_recon, s_nationkey_t AS " + "s_nationkey, " + "s_suppkey_t AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'" + ) + + assert src_actual == src_expected + assert tgt_actual == tgt_expected + + +def test_hash_query_builder_for_oracle_src(table_conf_mock, table_schema, column_mapping): + schema, _ = table_schema + table_conf = table_conf_mock( + join_columns=["s_suppkey", "s_nationkey"], + filters=Filters(source="s_nationkey=1"), + column_mapping=[ColumnMapping(source_name="s_nationkey", target_name="s_nationkey")], + ) + src_actual = HashQueryBuilder(table_conf, schema, "source", get_dialect("oracle")).build_query(report_type="all") + src_expected = ( + "SELECT LOWER(RAWTOHEX(STANDARD_HASH(CONCAT(COALESCE(TRIM(s_acctbal), '_null_recon_'), COALESCE(TRIM(" + "s_address), '_null_recon_'), " + "COALESCE(TRIM(s_comment), '_null_recon_'), COALESCE(TRIM(s_name), '_null_recon_'), COALESCE(TRIM(" + "s_nationkey), '_null_recon_'), COALESCE(TRIM(s_phone), '_null_recon_'), COALESCE(TRIM(s_suppkey), " + "'_null_recon_')), 'SHA256'))) AS hash_value_recon, s_nationkey AS s_nationkey, " + "s_suppkey AS s_suppkey FROM :tbl WHERE s_nationkey = 1" + ) + + tgt_actual = HashQueryBuilder(table_conf, schema, "target", get_dialect("databricks")).build_query( + report_type="all" + ) + tgt_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal), '_null_recon_'), COALESCE(TRIM(s_address), " + "'_null_recon_'), COALESCE(TRIM(" + "s_comment), '_null_recon_'), COALESCE(TRIM(s_name), '_null_recon_'), COALESCE(TRIM(s_nationkey), " + "'_null_recon_'), COALESCE(TRIM(s_phone), " + "'_null_recon_'), COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, s_nationkey AS " + "s_nationkey, s_suppkey " + "AS s_suppkey FROM :tbl" + ) + + assert src_actual == src_expected + assert tgt_actual == tgt_expected + + +def test_hash_query_builder_for_databricks_src(table_conf_mock, table_schema, column_mapping): + table_conf = table_conf_mock( + join_columns=["s_suppkey"], + column_mapping=column_mapping, + filters=Filters(target="s_nationkey_t=1"), + ) + sch, sch_with_alias = table_schema + src_actual = HashQueryBuilder(table_conf, sch, "source", get_dialect("databricks")).build_query(report_type="data") + src_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal), '_null_recon_'), COALESCE(TRIM(s_address), '_null_recon_'), " + "COALESCE(TRIM(s_comment), '_null_recon_'), COALESCE(TRIM(s_name), '_null_recon_'), COALESCE(TRIM(" + "s_nationkey), '_null_recon_'), COALESCE(TRIM(" + "s_phone), '_null_recon_'), 
COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, s_suppkey " + "AS s_suppkey FROM :tbl" + ) + + tgt_actual = HashQueryBuilder(table_conf, sch_with_alias, "target", get_dialect("databricks")).build_query( + report_type="data" + ) + tgt_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal_t), '_null_recon_'), COALESCE(TRIM(s_address_t), " + "'_null_recon_'), COALESCE(TRIM(" + "s_comment_t), '_null_recon_'), COALESCE(TRIM(s_name), '_null_recon_'), COALESCE(TRIM(s_nationkey_t), " + "'_null_recon_'), COALESCE(TRIM(s_phone_t), " + "'_null_recon_'), COALESCE(TRIM(s_suppkey_t), '_null_recon_')), 256)) AS hash_value_recon, s_suppkey_t AS " + "s_suppkey FROM :tbl WHERE " + "s_nationkey_t = 1" + ) + + assert src_actual == src_expected + assert tgt_actual == tgt_expected + + +def test_hash_query_builder_without_column_mapping(table_conf_mock, table_schema): + table_conf = table_conf_mock( + join_columns=["s_suppkey"], + filters=Filters(target="s_nationkey=1"), + ) + sch, _ = table_schema + src_actual = HashQueryBuilder(table_conf, sch, "source", get_dialect("databricks")).build_query(report_type="data") + src_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal), '_null_recon_'), COALESCE(TRIM(s_address), '_null_recon_')," + " COALESCE(TRIM(s_comment), '_null_recon_'), COALESCE(TRIM(s_name), '_null_recon_'), COALESCE(TRIM(" + "s_nationkey), '_null_recon_'), COALESCE(TRIM(" + "s_phone), '_null_recon_'), COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, s_suppkey " + "AS s_suppkey FROM :tbl" + ) + + tgt_actual = HashQueryBuilder(table_conf, sch, "target", get_dialect("databricks")).build_query(report_type="data") + tgt_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal), '_null_recon_'), COALESCE(TRIM(s_address), " + "'_null_recon_'), COALESCE(TRIM(" + "s_comment), '_null_recon_'), COALESCE(TRIM(s_name), '_null_recon_'), COALESCE(TRIM(s_nationkey), " + "'_null_recon_'), COALESCE(TRIM(s_phone), " + "'_null_recon_'), COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, s_suppkey AS " + "s_suppkey FROM :tbl WHERE " + "s_nationkey = 1" + ) + + assert src_actual == src_expected + assert tgt_actual == tgt_expected + + +def test_hash_query_builder_without_transformation(table_conf_mock, table_schema, column_mapping): + table_conf = table_conf_mock( + join_columns=["s_suppkey"], + transformations=[ + Transformation(column_name="s_address", source=None, target="trim(s_address_t)"), + Transformation(column_name="s_name", source="trim(s_name)", target=None), + Transformation(column_name="s_suppkey", source="trim(s_suppkey)", target=None), + ], + column_mapping=column_mapping, + filters=Filters(target="s_nationkey_t=1"), + ) + sch, tgt_sch = table_schema + src_actual = HashQueryBuilder(table_conf, sch, "source", get_dialect("databricks")).build_query(report_type="data") + src_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal), '_null_recon_'), s_address, " + "COALESCE(TRIM(s_comment), '_null_recon_'), TRIM(s_name), COALESCE(TRIM(s_nationkey), '_null_recon_'), " + "COALESCE(TRIM(" + "s_phone), '_null_recon_'), TRIM(s_suppkey)), 256)) AS hash_value_recon, TRIM(s_suppkey) AS s_suppkey FROM :tbl" + ) + + tgt_actual = HashQueryBuilder(table_conf, tgt_sch, "target", get_dialect("databricks")).build_query( + report_type="data" + ) + tgt_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal_t), '_null_recon_'), TRIM(s_address_t), COALESCE(TRIM(" + "s_comment_t), '_null_recon_'), s_name, 
COALESCE(TRIM(s_nationkey_t), '_null_recon_'), COALESCE(TRIM(" + "s_phone_t), " + "'_null_recon_'), s_suppkey_t), 256)) AS hash_value_recon, s_suppkey_t AS s_suppkey FROM :tbl WHERE " + "s_nationkey_t = 1" + ) + + assert src_actual == src_expected + assert tgt_actual == tgt_expected + + +def test_hash_query_builder_for_report_type_is_row(table_conf_with_opts, table_schema, column_mapping): + sch, sch_with_alias = table_schema + src_actual = HashQueryBuilder(table_conf_with_opts, sch, "source", get_dialect("databricks")).build_query( + report_type="row" + ) + src_expected = ( + "SELECT LOWER(SHA2(CONCAT(TRIM(s_address), TRIM(s_name), COALESCE(TRIM(s_nationkey), '_null_recon_'), " + "TRIM(s_phone), COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, TRIM(s_address) AS " + "s_address, TRIM(s_name) AS s_name, s_nationkey AS s_nationkey, TRIM(s_phone) " + "AS s_phone, s_suppkey AS s_suppkey FROM :tbl WHERE s_name = 't' AND " + "s_address = 'a'" + ) + + tgt_actual = HashQueryBuilder( + table_conf_with_opts, sch_with_alias, "target", get_dialect("databricks") + ).build_query(report_type="row") + tgt_expected = ( + "SELECT LOWER(SHA2(CONCAT(TRIM(s_address_t), TRIM(s_name), COALESCE(TRIM(s_nationkey_t), '_null_recon_'), " + "TRIM(s_phone_t), COALESCE(TRIM(s_suppkey_t), '_null_recon_')), 256)) AS hash_value_recon, TRIM(s_address_t) " + "AS s_address, TRIM(s_name) AS s_name, s_nationkey_t AS s_nationkey, " + "TRIM(s_phone_t) AS s_phone, s_suppkey_t AS s_suppkey FROM :tbl WHERE s_name " + "= 't' AND s_address_t = 'a'" + ) + + assert src_actual == src_expected + assert tgt_actual == tgt_expected + + +def test_config_case_sensitivity(table_conf_mock, table_schema, column_mapping): + table_conf = table_conf_mock( + select_columns=["S_SUPPKEY", "S_name", "S_ADDRESS", "S_NATIOnKEY", "S_PhONE", "S_acctbal"], + drop_columns=["s_Comment"], + join_columns=["S_SUPPKEY"], + transformations=[ + Transformation(column_name="S_ADDRESS", source=None, target="trim(s_address_t)"), + Transformation(column_name="S_NAME", source="trim(s_name)", target=None), + Transformation(column_name="s_suppKey", source="trim(s_suppkey)", target=None), + ], + column_mapping=column_mapping, + filters=Filters(target="s_nationkey_t=1"), + ) + sch, tgt_sch = table_schema + src_actual = HashQueryBuilder(table_conf, sch, "source", get_dialect("databricks")).build_query(report_type="data") + src_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal), '_null_recon_'), s_address, " + "TRIM(s_name), COALESCE(TRIM(s_nationkey), '_null_recon_'), COALESCE(TRIM(" + "s_phone), '_null_recon_'), TRIM(s_suppkey)), 256)) AS hash_value_recon, TRIM(s_suppkey) AS s_suppkey FROM :tbl" + ) + + tgt_actual = HashQueryBuilder(table_conf, tgt_sch, "target", get_dialect("databricks")).build_query( + report_type="data" + ) + tgt_expected = ( + "SELECT LOWER(SHA2(CONCAT(COALESCE(TRIM(s_acctbal_t), '_null_recon_'), TRIM(s_address_t), s_name, " + "COALESCE(TRIM(" + "s_nationkey_t), '_null_recon_'), COALESCE(TRIM(s_phone_t), '_null_recon_'), s_suppkey_t), " + "256)) AS hash_value_recon, s_suppkey_t AS " + "s_suppkey FROM :tbl WHERE s_nationkey_t = 1" + ) + + assert src_actual == src_expected + assert tgt_actual == tgt_expected diff --git a/tests/unit/reconcile/query_builder/test_sampling_query.py b/tests/unit/reconcile/query_builder/test_sampling_query.py new file mode 100644 index 0000000000..b8964454f0 --- /dev/null +++ b/tests/unit/reconcile/query_builder/test_sampling_query.py @@ -0,0 +1,326 @@ +from pyspark.sql.types import 
IntegerType, StringType, StructField, StructType + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.query_builder.sampling_query import ( + SamplingQueryBuilder, +) +from databricks.labs.remorph.reconcile.recon_config import ( + ColumnMapping, + Filters, + Schema, + Transformation, +) + + +def test_build_query_for_snowflake_src(mock_spark, table_conf_mock, table_schema): + spark = mock_spark + sch, sch_with_alias = table_schema + df_schema = StructType( + [ + StructField('s_suppkey', IntegerType()), + StructField('s_name', StringType()), + StructField('s_address', StringType()), + StructField('s_nationkey', IntegerType()), + StructField('s_phone', StringType()), + StructField('s_acctbal', StringType()), + StructField('s_comment', StringType()), + ] + ) + df = spark.createDataFrame( + [ + (1, 'name-1', 'add-1', 11, '1-1', 100, 'c-1'), + (2, 'name-2', 'add-2', 22, '2-2', 200, 'c-2'), + ], + schema=df_schema, + ) + + conf = table_conf_mock( + join_columns=["s_suppkey", "s_nationkey"], + column_mapping=[ + ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"), + ColumnMapping(source_name="s_nationkey", target_name='s_nationkey_t'), + ColumnMapping(source_name="s_address", target_name="s_address_t"), + ColumnMapping(source_name="s_phone", target_name="s_phone_t"), + ColumnMapping(source_name="s_acctbal", target_name="s_acctbal_t"), + ColumnMapping(source_name="s_comment", target_name="s_comment_t"), + ], + filters=Filters(source="s_nationkey=1"), + transformations=[Transformation(column_name="s_address", source="trim(s_address)", target="trim(s_address_t)")], + ) + + src_actual = SamplingQueryBuilder(conf, sch, "source", get_dialect("snowflake")).build_query(df) + src_expected = ( + 'WITH recon AS (SELECT 11 AS s_nationkey, 1 AS s_suppkey UNION SELECT 22 AS ' + "s_nationkey, 2 AS s_suppkey), src AS (SELECT COALESCE(TRIM(s_acctbal), '_null_recon_') " + "AS s_acctbal, TRIM(s_address) AS s_address, COALESCE(TRIM(s_comment), '_null_recon_') AS " + "s_comment, COALESCE(TRIM(s_name), '_null_recon_') AS s_name, COALESCE(TRIM(s_nationkey), " + "'_null_recon_') AS s_nationkey, COALESCE(TRIM(s_phone), '_null_recon_') AS s_phone, " + "COALESCE(TRIM(s_suppkey), '_null_recon_') AS s_suppkey FROM :tbl WHERE s_nationkey = 1) " + 'SELECT src.s_acctbal, src.s_address, src.s_comment, src.s_name, src.s_nationkey, src.s_phone, ' + "src.s_suppkey FROM src INNER JOIN recon AS recon ON " + "COALESCE(TRIM(src.s_nationkey), '_null_recon_') = COALESCE(TRIM(recon.s_nationkey), '_null_recon_') " + "AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + tgt_actual = SamplingQueryBuilder(conf, sch_with_alias, "target", get_dialect("databricks")).build_query(df) + tgt_expected = ( + 'WITH recon AS (SELECT 11 AS s_nationkey, 1 AS s_suppkey UNION SELECT 22 AS ' + "s_nationkey, 2 AS s_suppkey), src AS (SELECT COALESCE(TRIM(s_acctbal_t), '_null_recon_') " + 'AS s_acctbal, TRIM(s_address_t) AS s_address, COALESCE(TRIM(s_comment_t), ' + "'_null_recon_') AS s_comment, COALESCE(TRIM(s_name), '_null_recon_') AS s_name, " + "COALESCE(TRIM(s_nationkey_t), '_null_recon_') AS s_nationkey, COALESCE(TRIM(s_phone_t), " + "'_null_recon_') AS s_phone, COALESCE(TRIM(s_suppkey_t), '_null_recon_') AS s_suppkey FROM :tbl) " + 'SELECT src.s_acctbal, src.s_address, src.s_comment, src.s_name, src.s_nationkey, src.s_phone, ' + "src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = " + 
"COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + assert src_expected == src_actual + assert tgt_expected == tgt_actual + + +def test_build_query_for_oracle_src(mock_spark, table_conf_mock, table_schema, column_mapping): + spark = mock_spark + _, sch_with_alias = table_schema + df_schema = StructType( + [ + StructField('s_suppkey', IntegerType()), + StructField('s_name', StringType()), + StructField('s_address', StringType()), + StructField('s_nationkey', IntegerType()), + StructField('s_phone', StringType()), + StructField('s_acctbal', StringType()), + StructField('s_comment', StringType()), + ] + ) + df = spark.createDataFrame( + [ + (1, 'name-1', 'add-1', 11, '1-1', 100, 'c-1'), + (2, 'name-2', 'add-2', 22, '2-2', 200, 'c-2'), + (3, 'name-3', 'add-3', 33, '3-3', 300, 'c-3'), + ], + schema=df_schema, + ) + + conf = table_conf_mock( + join_columns=["s_suppkey", "s_nationkey"], + column_mapping=column_mapping, + filters=Filters(source="s_nationkey=1"), + ) + + sch = [ + Schema("s_suppkey", "number"), + Schema("s_name", "varchar"), + Schema("s_address", "varchar"), + Schema("s_nationkey", "number"), + Schema("s_phone", "nvarchar"), + Schema("s_acctbal", "number"), + Schema("s_comment", "nchar"), + ] + + src_actual = SamplingQueryBuilder(conf, sch, "source", get_dialect("oracle")).build_query(df) + src_expected = ( + 'WITH recon AS (SELECT 11 AS s_nationkey, 1 AS s_suppkey FROM dual UNION SELECT 22 AS ' + 's_nationkey, 2 AS s_suppkey FROM dual UNION SELECT 33 AS s_nationkey, 3 AS s_suppkey FROM dual), ' + "src AS (SELECT COALESCE(TRIM(s_acctbal), '_null_recon_') AS s_acctbal, " + "COALESCE(TRIM(s_address), '_null_recon_') AS s_address, " + "NVL(TRIM(TO_CHAR(s_comment)),'_null_recon_') AS s_comment, " + "COALESCE(TRIM(s_name), '_null_recon_') AS s_name, COALESCE(TRIM(s_nationkey), '_null_recon_') AS " + "s_nationkey, NVL(TRIM(TO_CHAR(s_phone)),'_null_recon_') AS s_phone, " + "COALESCE(TRIM(s_suppkey), '_null_recon_') AS s_suppkey FROM :tbl WHERE s_nationkey = 1) " + 'SELECT src.s_acctbal, src.s_address, src.s_comment, src.s_name, src.s_nationkey, src.s_phone, ' + "src.s_suppkey FROM src INNER JOIN recon recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + tgt_actual = SamplingQueryBuilder(conf, sch_with_alias, "target", get_dialect("databricks")).build_query(df) + tgt_expected = ( + 'WITH recon AS (SELECT 11 AS s_nationkey, 1 AS s_suppkey UNION SELECT 22 AS ' + 's_nationkey, 2 AS s_suppkey UNION SELECT 33 AS s_nationkey, 3 AS s_suppkey), ' + "src AS (SELECT COALESCE(TRIM(s_acctbal_t), '_null_recon_') AS s_acctbal, " + "COALESCE(TRIM(s_address_t), '_null_recon_') AS s_address, COALESCE(TRIM(s_comment_t), " + "'_null_recon_') AS s_comment, COALESCE(TRIM(s_name), '_null_recon_') AS s_name, " + "COALESCE(TRIM(s_nationkey_t), '_null_recon_') AS s_nationkey, COALESCE(TRIM(s_phone_t), " + "'_null_recon_') AS s_phone, COALESCE(TRIM(s_suppkey_t), '_null_recon_') AS s_suppkey FROM :tbl) " + 'SELECT src.s_acctbal, src.s_address, src.s_comment, src.s_name, src.s_nationkey, src.s_phone, ' + "src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = " + 
"COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + assert src_expected == src_actual + assert tgt_expected == tgt_actual + + +def test_build_query_for_databricks_src(mock_spark, table_conf_mock): + spark = mock_spark + df_schema = StructType( + [ + StructField('s_suppkey', IntegerType()), + StructField('s_name', StringType()), + StructField('s_address', StringType()), + StructField('s_nationkey', IntegerType()), + StructField('s_phone', StringType()), + StructField('s_acctbal', StringType()), + StructField('s_comment', StringType()), + ] + ) + df = spark.createDataFrame([(1, 'name-1', 'add-1', 11, '1-1', 100, 'c-1')], schema=df_schema) + + schema = [ + Schema("s_suppkey", "bigint"), + Schema("s_name", "string"), + Schema("s_address", "string"), + Schema("s_nationkey", "bigint"), + Schema("s_phone", "string"), + Schema("s_acctbal", "bigint"), + Schema("s_comment", "string"), + ] + + conf = table_conf_mock(join_columns=["s_suppkey", "s_nationkey"]) + + src_actual = SamplingQueryBuilder(conf, schema, "source", get_dialect("databricks")).build_query(df) + src_expected = ( + 'WITH recon AS (SELECT 11 AS s_nationkey, 1 AS s_suppkey), src AS (SELECT ' + "COALESCE(TRIM(s_acctbal), '_null_recon_') AS s_acctbal, COALESCE(TRIM(s_address), '_null_recon_') AS " + "s_address, COALESCE(TRIM(s_comment), '_null_recon_') AS s_comment, " + "COALESCE(TRIM(s_name), '_null_recon_') AS s_name, COALESCE(TRIM(s_nationkey), '_null_recon_') AS " + "s_nationkey, COALESCE(TRIM(s_phone), '_null_recon_') AS s_phone, " + "COALESCE(TRIM(s_suppkey), '_null_recon_') AS s_suppkey FROM :tbl) SELECT src.s_acctbal, " + 'src.s_address, src.s_comment, src.s_name, src.s_nationkey, src.s_phone, src.s_suppkey FROM src INNER ' + "JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = COALESCE(TRIM(recon.s_nationkey), " + "'_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = COALESCE(TRIM(recon.s_suppkey), " + "'_null_recon_')" + ) + assert src_expected == src_actual + + +def test_build_query_for_snowflake_without_transformations(mock_spark, table_conf_mock, table_schema): + spark = mock_spark + sch, sch_with_alias = table_schema + df_schema = StructType( + [ + StructField('s_suppkey', IntegerType()), + StructField('s_name', StringType()), + StructField('s_address', StringType()), + StructField('s_nationkey', IntegerType()), + StructField('s_phone', StringType()), + StructField('s_acctbal', StringType()), + StructField('s_comment', StringType()), + ] + ) + df = spark.createDataFrame( + [ + (1, 'name-1', 'add-1', 11, '1-1', 100, 'c-1'), + (2, 'name-2', 'add-2', 22, '2-2', 200, 'c-2'), + ], + schema=df_schema, + ) + + conf = table_conf_mock( + join_columns=["s_suppkey", "s_nationkey"], + column_mapping=[ + ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"), + ColumnMapping(source_name="s_nationkey", target_name='s_nationkey_t'), + ColumnMapping(source_name="s_address", target_name="s_address_t"), + ColumnMapping(source_name="s_phone", target_name="s_phone_t"), + ColumnMapping(source_name="s_acctbal", target_name="s_acctbal_t"), + ColumnMapping(source_name="s_comment", target_name="s_comment_t"), + ], + filters=Filters(source="s_nationkey=1"), + transformations=[ + Transformation(column_name="s_address", source=None, target="trim(s_address_t)"), + Transformation(column_name="s_name", source="trim(s_name)", target=None), + Transformation(column_name="s_suppkey", source="trim(s_suppkey)", target=None), + ], + ) + + src_actual = SamplingQueryBuilder(conf, sch, "source", 
get_dialect("snowflake")).build_query(df) + src_expected = ( + 'WITH recon AS (SELECT 11 AS s_nationkey, 1 AS s_suppkey UNION SELECT 22 AS ' + "s_nationkey, 2 AS s_suppkey), src AS (SELECT COALESCE(TRIM(s_acctbal), '_null_recon_') " + "AS s_acctbal, s_address AS s_address, COALESCE(TRIM(s_comment), '_null_recon_') AS " + "s_comment, TRIM(s_name) AS s_name, COALESCE(TRIM(s_nationkey), " + "'_null_recon_') AS s_nationkey, COALESCE(TRIM(s_phone), '_null_recon_') AS s_phone, " + "TRIM(s_suppkey) AS s_suppkey FROM :tbl WHERE s_nationkey = 1) " + 'SELECT src.s_acctbal, src.s_address, src.s_comment, src.s_name, src.s_nationkey, src.s_phone, ' + "src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + tgt_actual = SamplingQueryBuilder(conf, sch_with_alias, "target", get_dialect("databricks")).build_query(df) + tgt_expected = ( + 'WITH recon AS (SELECT 11 AS s_nationkey, 1 AS s_suppkey UNION SELECT 22 AS ' + "s_nationkey, 2 AS s_suppkey), src AS (SELECT COALESCE(TRIM(s_acctbal_t), '_null_recon_') " + 'AS s_acctbal, TRIM(s_address_t) AS s_address, COALESCE(TRIM(s_comment_t), ' + "'_null_recon_') AS s_comment, s_name AS s_name, " + "COALESCE(TRIM(s_nationkey_t), '_null_recon_') AS s_nationkey, COALESCE(TRIM(s_phone_t), " + "'_null_recon_') AS s_phone, s_suppkey_t AS s_suppkey FROM :tbl) " + 'SELECT src.s_acctbal, src.s_address, src.s_comment, src.s_name, src.s_nationkey, src.s_phone, ' + "src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + assert src_expected == src_actual + assert tgt_expected == tgt_actual + + +def test_build_query_for_snowflake_src_for_non_integer_primary_keys(mock_spark, table_conf_mock): + spark = mock_spark + sch = [Schema("s_suppkey", "varchar"), Schema("s_name", "varchar"), Schema("s_nationkey", "number")] + + sch_with_alias = [Schema("s_suppkey_t", "varchar"), Schema("s_name", "varchar"), Schema("s_nationkey_t", "number")] + df_schema = StructType( + [ + StructField('s_suppkey', StringType()), + StructField('s_name', StringType()), + StructField('s_nationkey', IntegerType()), + ] + ) + df = spark.createDataFrame( + [ + ('a', 'name-1', 11), + ('b', 'name-2', 22), + ], + schema=df_schema, + ) + + conf = table_conf_mock( + join_columns=["s_suppkey", "s_nationkey"], + column_mapping=[ + ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"), + ColumnMapping(source_name="s_nationkey", target_name='s_nationkey_t'), + ], + transformations=[Transformation(column_name="s_address", source="trim(s_address)", target="trim(s_address_t)")], + ) + + src_actual = SamplingQueryBuilder(conf, sch, "source", get_dialect("snowflake")).build_query(df) + src_expected = ( + "WITH recon AS (SELECT 11 AS s_nationkey, 'a' AS s_suppkey UNION SELECT 22 AS " + "s_nationkey, 'b' AS s_suppkey), src AS (SELECT COALESCE(TRIM(s_name), '_null_recon_') AS s_name, " + "COALESCE(TRIM(" + "s_nationkey), '_null_recon_') AS s_nationkey, COALESCE(TRIM(s_suppkey), '_null_recon_') AS s_suppkey FROM " + ":tbl) " + "SELECT src.s_name, src.s_nationkey, src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(" + "src.s_nationkey), '_null_recon_') = 
COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(" + "src.s_suppkey), '_null_recon_') = COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + tgt_actual = SamplingQueryBuilder(conf, sch_with_alias, "target", get_dialect("databricks")).build_query(df) + tgt_expected = ( + "WITH recon AS (SELECT 11 AS s_nationkey, 'a' AS s_suppkey UNION SELECT 22 AS " + "s_nationkey, 'b' AS s_suppkey), src AS (SELECT COALESCE(TRIM(s_name), '_null_recon_') AS s_name, " + "COALESCE(TRIM(s_nationkey_t), '_null_recon_') AS s_nationkey, COALESCE(TRIM(s_suppkey_t), '_null_recon_') AS " + "s_suppkey FROM :tbl) " + "SELECT src.s_name, src.s_nationkey, " + "src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = " + "COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + ) + + assert src_expected == src_actual + assert tgt_expected == tgt_actual diff --git a/tests/unit/reconcile/query_builder/test_threshold_query.py b/tests/unit/reconcile/query_builder/test_threshold_query.py new file mode 100644 index 0000000000..f34cdb52ed --- /dev/null +++ b/tests/unit/reconcile/query_builder/test_threshold_query.py @@ -0,0 +1,148 @@ +import re + +import pytest + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.exception import InvalidInputException +from databricks.labs.remorph.reconcile.query_builder.threshold_query import ( + ThresholdQueryBuilder, +) +from databricks.labs.remorph.reconcile.recon_config import ( + JdbcReaderOptions, + Schema, + ColumnThresholds, + Transformation, +) + + +def test_threshold_comparison_query_with_one_threshold(table_conf_with_opts, table_schema): + # table conf + table_conf = table_conf_with_opts + # schema + table_schema, _ = table_schema + table_schema.append(Schema("s_suppdate", "timestamp")) + comparison_query = ThresholdQueryBuilder( + table_conf, table_schema, "source", get_dialect("oracle") + ).build_comparison_query() + assert re.sub(r'\s+', ' ', comparison_query.strip().lower()) == re.sub( + r'\s+', + ' ', + """select coalesce(source.s_acctbal, 0) as s_acctbal_source, coalesce(databricks.s_acctbal, + 0) as s_acctbal_databricks, case when (coalesce(source.s_acctbal, 0) - coalesce(databricks.s_acctbal, + 0)) = 0 then 'match' when (coalesce(source.s_acctbal, 0) - coalesce(databricks.s_acctbal, 0)) between 0 and + 100 then 'warning' else 'failed' end as s_acctbal_match, source.s_nationkey as s_nationkey_source, + source.s_suppkey as s_suppkey_source from source_supplier_df_threshold_vw as source inner join + target_target_supplier_df_threshold_vw as databricks on source.s_nationkey <=> databricks.s_nationkey and + source.s_suppkey <=> databricks.s_suppkey where (1 = 1 or 1 = 1) or + (coalesce(source.s_acctbal, 0) - coalesce(databricks.s_acctbal, 0)) <> 0""".strip().lower(), + ) + + +def test_threshold_comparison_query_with_dual_threshold(table_conf_with_opts, table_schema): + # table conf + table_conf = table_conf_with_opts + table_conf.join_columns = ["s_suppkey", "s_suppdate"] + table_conf.column_thresholds = [ + ColumnThresholds(column_name="s_acctbal", lower_bound="5%", upper_bound="-5%", type="float"), + ColumnThresholds(column_name="s_suppdate", lower_bound="-86400", upper_bound="86400", type="timestamp"), + ] + + # schema + table_schema, _ = table_schema + table_schema.append(Schema("s_suppdate", "timestamp")) + + comparison_query = ThresholdQueryBuilder( + table_conf, 
table_schema, "target", get_dialect("databricks") + ).build_comparison_query() + assert re.sub(r'\s+', ' ', comparison_query.strip().lower()) == re.sub( + r'\s+', + ' ', + """select coalesce(source.s_acctbal, 0) as s_acctbal_source, coalesce(databricks.s_acctbal, + 0) as s_acctbal_databricks, case when (coalesce(source.s_acctbal, 0) - coalesce(databricks.s_acctbal, + 0)) = 0 then 'match' when ((coalesce(source.s_acctbal, 0) - coalesce(databricks.s_acctbal, + 0)) / if(databricks.s_acctbal = 0 or databricks.s_acctbal is null, 1, databricks.s_acctbal)) * 100 between 5 + and -5 then 'warning' else 'failed' end as s_acctbal_match, coalesce(unix_timestamp(source.s_suppdate), + 0) as s_suppdate_source, coalesce(unix_timestamp(databricks.s_suppdate), 0) as s_suppdate_databricks, + case when (coalesce(unix_timestamp(source.s_suppdate), 0) - coalesce(unix_timestamp(databricks.s_suppdate), + 0)) = 0 then 'match' when (coalesce(unix_timestamp(source.s_suppdate), 0) - + coalesce(unix_timestamp(databricks.s_suppdate), 0)) between -86400 and 86400 then + 'warning' else 'failed' end as s_suppdate_match, source.s_suppdate as s_suppdate_source, + source.s_suppkey as s_suppkey_source from source_supplier_df_threshold_vw as + source inner join target_target_supplier_df_threshold_vw as databricks on source.s_suppdate <=> databricks.s_suppdate and + source.s_suppkey <=> databricks.s_suppkey where (1 = 1 or 1 = 1) or (coalesce(source.s_acctbal, 0) - + coalesce(databricks.s_acctbal, 0)) <> 0 or (coalesce(unix_timestamp(source.s_suppdate), 0) - + coalesce(unix_timestamp(databricks.s_suppdate), 0)) <> 0""".strip().lower(), + ) + + +def test_build_threshold_query_with_single_threshold(table_conf_with_opts, table_schema): + table_conf = table_conf_with_opts + table_conf.jdbc_reader_options = None + table_conf.transformations = [ + Transformation(column_name="s_acctbal", source="cast(s_acctbal as number)", target="cast(s_acctbal_t as int)") + ] + src_schema, tgt_schema = table_schema + src_query = ThresholdQueryBuilder(table_conf, src_schema, "source", get_dialect("oracle")).build_threshold_query() + target_query = ThresholdQueryBuilder( + table_conf, tgt_schema, "target", get_dialect("databricks") + ).build_threshold_query() + assert src_query == ( + "SELECT s_nationkey AS s_nationkey, s_suppkey AS s_suppkey, " + "CAST(s_acctbal AS NUMBER) AS s_acctbal FROM :tbl WHERE s_name = 't' AND s_address = 'a'" + ) + assert target_query == ( + "SELECT s_nationkey_t AS s_nationkey, s_suppkey_t AS s_suppkey, " + "CAST(s_acctbal_t AS INT) AS s_acctbal FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'" + ) + + +def test_build_threshold_query_with_multiple_threshold(table_conf_with_opts, table_schema): + table_conf = table_conf_with_opts + table_conf.jdbc_reader_options = JdbcReaderOptions( + number_partitions=100, partition_column="s_phone", lower_bound="0", upper_bound="100" + ) + table_conf.column_thresholds = [ + ColumnThresholds(column_name="s_acctbal", lower_bound="5%", upper_bound="-5%", type="float"), + ColumnThresholds(column_name="s_suppdate", lower_bound="-86400", upper_bound="86400", type="timestamp"), + ] + table_conf.filters = None + src_schema, tgt_schema = table_schema + src_schema.append(Schema("s_suppdate", "timestamp")) + tgt_schema.append(Schema("s_suppdate", "timestamp")) + src_query = ThresholdQueryBuilder(table_conf, src_schema, "source", get_dialect("oracle")).build_threshold_query() + target_query = ThresholdQueryBuilder( + table_conf, tgt_schema, "target", get_dialect("databricks") + 
).build_threshold_query() + assert src_query == ( + "SELECT s_nationkey AS s_nationkey, TRIM(s_phone) AS s_phone, s_suppkey " + "AS s_suppkey, s_acctbal AS s_acctbal, s_suppdate AS s_suppdate FROM :tbl" + ) + assert target_query == ( + "SELECT s_nationkey_t AS s_nationkey, s_suppkey_t AS s_suppkey, " + "s_acctbal_t AS s_acctbal, s_suppdate AS s_suppdate FROM :tbl" + ) + + +def test_build_expression_type_raises_value_error(table_conf_with_opts, table_schema): + table_conf = table_conf_with_opts + table_conf.column_thresholds = [ + ColumnThresholds(column_name="s_acctbal", lower_bound="5%", upper_bound="-5%", type="unknown"), + ] + table_conf.filters = None + src_schema, tgt_schema = table_schema + src_schema.append(Schema("s_suppdate", "timestamp")) + tgt_schema.append(Schema("s_suppdate", "timestamp")) + + with pytest.raises(ValueError): + ThresholdQueryBuilder(table_conf, src_schema, "source", get_dialect("oracle")).build_comparison_query() + + +def test_test_no_join_columns_raise_exception(table_conf_with_opts, table_schema): + table_conf = table_conf_with_opts + table_conf.join_columns = None + src_schema, tgt_schema = table_schema + src_schema.append(Schema("s_suppdate", "timestamp")) + tgt_schema.append(Schema("s_suppdate", "timestamp")) + + with pytest.raises(InvalidInputException): + ThresholdQueryBuilder(table_conf, src_schema, "source", get_dialect("oracle")).build_comparison_query() diff --git a/tests/unit/reconcile/test_aggregates_recon_capture.py b/tests/unit/reconcile/test_aggregates_recon_capture.py new file mode 100644 index 0000000000..884afdcbc9 --- /dev/null +++ b/tests/unit/reconcile/test_aggregates_recon_capture.py @@ -0,0 +1,138 @@ +import datetime +from pathlib import Path + +from pyspark.sql import Row, SparkSession + +from databricks.labs.remorph.config import DatabaseConfig, get_dialect, ReconcileMetadataConfig +from databricks.labs.remorph.reconcile.recon_capture import ( + ReconCapture, +) +from databricks.labs.remorph.reconcile.recon_config import ( + ReconcileProcessDuration, + Table, + AggregateQueryOutput, +) +from databricks.labs.remorph.reconcile.recon_capture import generate_final_reconcile_aggregate_output +from .test_aggregates_reconcile import expected_reconcile_output_dict, expected_rule_output + + +def remove_directory_recursively(directory_path): + path = Path(directory_path) + if path.is_dir(): + for item in path.iterdir(): + if item.is_dir(): + remove_directory_recursively(item) + else: + item.unlink() + path.rmdir() + + +def agg_data_prep(spark: SparkSession): + table_conf = Table(source_name="supplier", target_name="target_supplier") + reconcile_process_duration = ReconcileProcessDuration( + start_ts=str(datetime.datetime.now()), end_ts=str(datetime.datetime.now()) + ) + + # Prepare output dataclasses + agg_reconcile_output = [ + AggregateQueryOutput( + rule=expected_rule_output()["count"], reconcile_output=expected_reconcile_output_dict(spark)["count"] + ), + AggregateQueryOutput( + reconcile_output=expected_reconcile_output_dict(spark)["sum"], rule=expected_rule_output()["sum"] + ), + ] + + # Drop old data + spark.sql("DROP TABLE IF EXISTS DEFAULT.main") + spark.sql("DROP TABLE IF EXISTS DEFAULT.aggregate_rules") + spark.sql("DROP TABLE IF EXISTS DEFAULT.aggregate_metrics") + spark.sql("DROP TABLE IF EXISTS DEFAULT.aggregate_details") + + # Get the warehouse location + warehouse_location = spark.conf.get("spark.sql.warehouse.dir") + + if warehouse_location and Path(warehouse_location.lstrip('file:')).exists(): + tables = ["main", 
"aggregate_rules", "aggregate_metrics", "aggregate_details"] + for table in tables: + remove_directory_recursively(f"{warehouse_location.lstrip('file:')}/{table}") + + return agg_reconcile_output, table_conf, reconcile_process_duration + + +def test_aggregates_reconcile_store_aggregate_metrics(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + + source_type = get_dialect("snowflake") + spark = mock_spark + agg_reconcile_output, table_conf, reconcile_process_duration = agg_data_prep(mock_spark) + + recon_id = "999fygdrs-dbb7-489f-bad1-6a7e8f4821b1" + + recon_capture = ReconCapture( + database_config, + recon_id, + "", + source_type, + mock_workspace_client, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + recon_capture.store_aggregates_metrics(table_conf, reconcile_process_duration, agg_reconcile_output) + + # Check if the tables are created + + # assert main table data + remorph_reconcile_df = spark.sql("select * from DEFAULT.main") + + assert remorph_reconcile_df.count() == 1 + if remorph_reconcile_df.first(): + main = remorph_reconcile_df.first().asDict() + assert main.get("recon_id") == recon_id + assert main.get("source_type") == "Snowflake" + assert not main.get("report_type") + assert main.get("operation_name") == "aggregates-reconcile" + + # assert rules data + agg_reconcile_rules_df = spark.sql("select * from DEFAULT.aggregate_rules") + + assert agg_reconcile_rules_df.count() == 2 + assert agg_reconcile_rules_df.select("rule_type").distinct().count() == 1 + if agg_reconcile_rules_df.first(): + rule = agg_reconcile_rules_df.first().asDict() + assert rule.get("rule_type") == "AGGREGATE" + assert isinstance(rule.get("rule_info"), dict) + assert rule["rule_info"].keys() == {"agg_type", "agg_column", "group_by_columns"} + + # assert metrics + agg_reconcile_metrics_df = spark.sql("select * from DEFAULT.aggregate_metrics") + + assert agg_reconcile_metrics_df.count() == 2 + if agg_reconcile_metrics_df.first(): + metric = agg_reconcile_metrics_df.first().asDict() + assert isinstance(metric.get("recon_metrics"), Row) + assert metric.get("recon_metrics").asDict().keys() == {"mismatch", "missing_in_source", "missing_in_target"} + + # assert details + agg_reconcile_details_df = spark.sql("select * from DEFAULT.aggregate_details") + + assert agg_reconcile_details_df.count() == 6 + assert agg_reconcile_details_df.select("recon_type").distinct().count() == 3 + recon_type_values = { + row["recon_type"] for row in agg_reconcile_details_df.select("recon_type").distinct().collect() + } + + assert recon_type_values == {"mismatch", "missing_in_source", "missing_in_target"} + + reconcile_output = generate_final_reconcile_aggregate_output( + recon_id=recon_id, + spark=mock_spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + assert len(reconcile_output.results) == 1 + assert not reconcile_output.results[0].exception_message + assert reconcile_output.results[0].status.aggregate is False diff --git a/tests/unit/reconcile/test_aggregates_reconcile.py b/tests/unit/reconcile/test_aggregates_reconcile.py new file mode 100644 index 0000000000..8fd8cd45ab --- /dev/null +++ b/tests/unit/reconcile/test_aggregates_reconcile.py @@ -0,0 +1,399 @@ +import sys + +from dataclasses import dataclass +from pathlib import Path + +from unittest.mock import patch + +import pytest +from pyspark.testing import 
assertDataFrameEqual +from pyspark.sql import Row + +from databricks.labs.remorph.config import DatabaseConfig, ReconcileMetadataConfig, get_dialect +from databricks.labs.remorph.reconcile.connectors.data_source import MockDataSource +from databricks.labs.remorph.reconcile.execute import Reconciliation, main +from databricks.labs.remorph.reconcile.recon_config import ( + Aggregate, + AggregateQueryOutput, + DataReconcileOutput, + MismatchOutput, + AggregateRule, +) +from databricks.labs.remorph.reconcile.schema_compare import SchemaCompare + +CATALOG = "org" +SCHEMA = "data" +SRC_TABLE = "supplier" +TGT_TABLE = "target_supplier" + + +@dataclass +class AggregateQueries: + source_agg_query: str + target_agg_query: str + source_group_agg_query: str + target_group_agg_query: str + + +@dataclass +class AggregateQueryStore: + agg_queries: AggregateQueries + + +@pytest.fixture +def query_store(mock_spark): + agg_queries = AggregateQueries( + source_agg_query="SELECT min(s_acctbal) AS source_min_s_acctbal FROM :tbl WHERE s_name = 't' AND s_address = 'a'", + target_agg_query="SELECT min(s_acctbal_t) AS target_min_s_acctbal FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'", + source_group_agg_query="SELECT sum(s_acctbal) AS source_sum_s_acctbal, count(TRIM(s_name)) AS source_count_s_name, s_nationkey AS source_group_by_s_nationkey FROM :tbl WHERE s_name = 't' AND s_address = 'a' GROUP BY s_nationkey", + target_group_agg_query="SELECT sum(s_acctbal_t) AS target_sum_s_acctbal, count(TRIM(s_name)) AS target_count_s_name, s_nationkey_t AS target_group_by_s_nationkey FROM :tbl WHERE s_name = 't' AND s_address_t = 'a' GROUP BY s_nationkey_t", + ) + + return AggregateQueryStore( + agg_queries=agg_queries, + ) + + +def test_reconcile_aggregate_data_missing_records( + mock_spark, + table_conf_with_opts, + table_schema, + query_store, + tmp_path: Path, +): + src_schema, tgt_schema = table_schema + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + table_conf_with_opts.aggregates = [Aggregate(type="MIN", agg_columns=["s_acctbal"])] + + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.agg_queries.source_agg_query, + ): mock_spark.createDataFrame( + [ + Row(source_min_s_acctbal=11), + ] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.agg_queries.target_agg_query, + ): mock_spark.createDataFrame( + [ + Row(target_min_s_acctbal=10), + ] + ) + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + database_config = DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ) + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + with patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)): + actual: list[AggregateQueryOutput] = Reconciliation( + source, + target, + database_config, + "", + SchemaCompare(mock_spark), + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_aggregates(table_conf_with_opts, src_schema, tgt_schema) + + assert len(actual) == 1 + + assert actual[0].rule, "Rule must be generated" + + assert actual[0].rule.agg_type == "min" + assert actual[0].rule.agg_column == "s_acctbal" + assert actual[0].rule.group_by_columns is None + assert 
actual[0].rule.group_by_columns_as_str == "NA" + assert actual[0].rule.group_by_columns_as_table_column == "NULL" + assert actual[0].rule.column_from_rule == "min_s_acctbal_NA" + assert actual[0].rule.rule_type == "AGGREGATE" + + assert actual[0].reconcile_output.mismatch.mismatch_df, "Mismatch dataframe must be present" + assert not actual[0].reconcile_output.mismatch.mismatch_df.isEmpty() + + expected = DataReconcileOutput( + mismatch_count=1, + missing_in_src_count=0, + missing_in_tgt_count=0, + mismatch=MismatchOutput( + mismatch_columns=None, + mismatch_df=mock_spark.createDataFrame( + [ + Row( + source_min_s_acctbal=11, + target_min_s_acctbal=10, + match_min_s_acctbal=False, + agg_data_match=False, + ) + ] + ), + ), + ) + + assert actual[0].reconcile_output.mismatch_count == expected.mismatch_count + assert actual[0].reconcile_output.missing_in_src_count == expected.missing_in_src_count + assert actual[0].reconcile_output.missing_in_tgt_count == expected.missing_in_tgt_count + assertDataFrameEqual(actual[0].reconcile_output.mismatch.mismatch_df, expected.mismatch.mismatch_df) + + +def expected_rule_output(): + count_rule_output = AggregateRule( + agg_type="count", + agg_column="s_name", + group_by_columns=["s_nationkey"], + group_by_columns_as_str="s_nationkey", + ) + + sum_rule_output = AggregateRule( + agg_type="sum", + agg_column="s_acctbal", + group_by_columns=["s_nationkey"], + group_by_columns_as_str="s_nationkey", + ) + + return {"count": count_rule_output, "sum": sum_rule_output} + + +def expected_reconcile_output_dict(spark): + count_reconcile_output = DataReconcileOutput( + mismatch_count=1, + missing_in_src_count=1, + missing_in_tgt_count=1, + mismatch=MismatchOutput( + mismatch_columns=None, + mismatch_df=spark.createDataFrame( + [ + Row( + source_count_s_name=11, + target_count_s_name=9, + source_group_by_s_nationkey=12, + target_group_by_s_nationkey=12, + match_count_s_name=False, + match_group_by_s_nationkey=True, + agg_data_match=False, + ) + ] + ), + ), + missing_in_src=spark.createDataFrame([Row(target_count_s_name=76, target_group_by_s_nationkey=14)]), + missing_in_tgt=spark.createDataFrame([Row(source_count_s_name=21, source_group_by_s_nationkey=13)]), + ) + + sum_reconcile_output = DataReconcileOutput( + mismatch_count=1, + missing_in_src_count=1, + missing_in_tgt_count=1, + mismatch=MismatchOutput( + mismatch_columns=None, + mismatch_df=spark.createDataFrame( + [ + Row( + source_sum_s_acctbal=23, + target_sum_s_acctbal=43, + source_group_by_s_nationkey=12, + target_group_by_s_nationkey=12, + match_sum_s_acctbal=False, + match_group_by_s_nationkey=True, + agg_data_match=False, + ) + ] + ), + ), + missing_in_src=spark.createDataFrame([Row(target_sum_s_acctbal=348, target_group_by_s_nationkey=14)]), + missing_in_tgt=spark.createDataFrame([Row(source_sum_s_acctbal=112, source_group_by_s_nationkey=13)]), + ) + + return {"count": count_reconcile_output, "sum": sum_reconcile_output} + + +def _compare_reconcile_output(actual_reconcile_output: DataReconcileOutput, expected_reconcile: DataReconcileOutput): + # Reconcile Output validations + if actual_reconcile_output and expected_reconcile: + assert actual_reconcile_output.mismatch.mismatch_df, "Mismatch dataframe must be present" + assert actual_reconcile_output.missing_in_src, "Missing in source one record must be present" + assert actual_reconcile_output.missing_in_tgt, "Missing in target one record must be present" + + assert actual_reconcile_output.mismatch_count == expected_reconcile.mismatch_count + assert 
actual_reconcile_output.missing_in_src_count == expected_reconcile.missing_in_src_count + assert actual_reconcile_output.missing_in_tgt_count == expected_reconcile.missing_in_tgt_count + + if actual_reconcile_output.mismatch.mismatch_df and expected_reconcile.mismatch.mismatch_df: + mismatch_df_columns = actual_reconcile_output.mismatch.mismatch_df.columns + assertDataFrameEqual( + actual_reconcile_output.mismatch.mismatch_df.select(*mismatch_df_columns), + expected_reconcile.mismatch.mismatch_df.select(*mismatch_df_columns), + ) + + if actual_reconcile_output.missing_in_src and expected_reconcile.missing_in_src: + missing_in_src_columns = actual_reconcile_output.missing_in_src.columns + assertDataFrameEqual( + actual_reconcile_output.missing_in_src.select(*missing_in_src_columns), + expected_reconcile.missing_in_src.select(*missing_in_src_columns), + ) + + if actual_reconcile_output.missing_in_tgt and expected_reconcile.missing_in_tgt: + missing_in_tgt_columns = actual_reconcile_output.missing_in_tgt.columns + assert ( + actual_reconcile_output.missing_in_tgt.select(*missing_in_tgt_columns).first() + == expected_reconcile.missing_in_tgt.select(*missing_in_tgt_columns).first() + ) + + +def test_reconcile_aggregate_data_mismatch_and_missing_records( + mock_spark, + table_conf_with_opts, + table_schema, + query_store, + tmp_path: Path, +): + src_schema, tgt_schema = table_schema + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + table_conf_with_opts.aggregates = [ + Aggregate(type="SUM", agg_columns=["s_acctbal"], group_by_columns=["s_nationkey"]), + Aggregate(type="COUNT", agg_columns=["s_name"], group_by_columns=["s_nationkey"]), + ] + + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.agg_queries.source_group_agg_query, + ): mock_spark.createDataFrame( + [ + Row(source_sum_s_acctbal=101, source_count_s_name=13, source_group_by_s_nationkey=11), + Row(source_sum_s_acctbal=23, source_count_s_name=11, source_group_by_s_nationkey=12), + Row(source_sum_s_acctbal=112, source_count_s_name=21, source_group_by_s_nationkey=13), + ] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.agg_queries.target_group_agg_query, + ): mock_spark.createDataFrame( + [ + Row(target_sum_s_acctbal=101, target_count_s_name=13, target_group_by_s_nationkey=11), + Row(target_sum_s_acctbal=43, target_count_s_name=9, target_group_by_s_nationkey=12), + Row(target_sum_s_acctbal=348, target_count_s_name=76, target_group_by_s_nationkey=14), + ] + ) + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + db_config = DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ) + source = MockDataSource(source_dataframe_repository, source_schema_repository) + with patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)): + actual_list: list[AggregateQueryOutput] = Reconciliation( + source, + MockDataSource(target_dataframe_repository, target_schema_repository), + db_config, + "", + SchemaCompare(mock_spark), + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_aggregates(table_conf_with_opts, src_schema, tgt_schema) + + assert len(actual_list) == 2 + + for actual in actual_list: + assert actual.rule, "Rule must be generated" + expected_rule = expected_rule_output().get(actual.rule.agg_type) + assert 
expected_rule, "Rule must be defined in expected" + + # Rule validations + assert actual.rule.agg_type == expected_rule.agg_type + assert actual.rule.agg_column == expected_rule.agg_column + assert actual.rule.group_by_columns == expected_rule.group_by_columns + assert actual.rule.group_by_columns_as_str == expected_rule.group_by_columns_as_str + assert actual.rule.group_by_columns_as_table_column == expected_rule.group_by_columns_as_table_column + assert ( + actual.rule.column_from_rule + == f"{expected_rule.agg_type}_{expected_rule.agg_column}_{expected_rule.group_by_columns_as_str}" + ) + assert actual.rule.rule_type == "AGGREGATE" + + # Reconcile Output validations + _compare_reconcile_output( + actual.reconcile_output, expected_reconcile_output_dict(mock_spark).get(actual.rule.agg_type) + ) + + +def test_run_with_invalid_operation_name(monkeypatch): + test_args = ["databricks_labs_remorph", "invalid-operation"] + monkeypatch.setattr(sys, 'argv', test_args) + with pytest.raises(AssertionError, match="Invalid option:"): + main() + + +def test_aggregates_reconcile_invalid_aggregates(): + invalid_agg_type_message = "Invalid aggregate type: std, only .* are supported." + with pytest.raises(AssertionError, match=invalid_agg_type_message): + Aggregate(agg_columns=["discount"], group_by_columns=["p_id"], type="STD") + + +def test_aggregates_reconcile_aggregate_columns(): + agg = Aggregate(agg_columns=["discount", "price"], group_by_columns=["p_dept_id", "p_sub_dept"], type="STDDEV") + + assert agg.get_agg_type() == "stddev" + assert agg.group_by_columns_as_str == "p_dept_id+__+p_sub_dept" + assert agg.agg_columns_as_str == "discount+__+price" + + agg1 = Aggregate(agg_columns=["discount"], type="MAX") + assert agg1.get_agg_type() == "max" + assert agg1.group_by_columns_as_str == "NA" + assert agg1.agg_columns_as_str == "discount" + + +def test_aggregates_reconcile_aggregate_rule(): + agg_rule = AggregateRule( + agg_column="discount", + group_by_columns=["p_dept_id", "p_sub_dept"], + group_by_columns_as_str="p_dept_id+__+p_sub_dept", + agg_type="stddev", + ) + + assert agg_rule.column_from_rule == "stddev_discount_p_dept_id+__+p_sub_dept" + assert agg_rule.group_by_columns_as_table_column == "\"p_dept_id, p_sub_dept\"" + expected_rule_query = """ SELECT 1234 as rule_id, 'AGGREGATE' as rule_type, map( 'agg_type', 'stddev', + 'agg_column', 'discount', + 'group_by_columns', "p_dept_id, p_sub_dept" + ) + as rule_info """ + assert agg_rule.get_rule_query(1234) == expected_rule_query + + +agg_rule1 = AggregateRule(agg_column="discount", group_by_columns=None, group_by_columns_as_str="NA", agg_type="max") +assert agg_rule1.column_from_rule == "max_discount_NA" +assert agg_rule1.group_by_columns_as_table_column == "NULL" +EXPECTED_RULE1_QUERY = """ SELECT 1234 as rule_id, 'AGGREGATE' as rule_type, map( 'agg_type', 'max', + 'agg_column', 'discount', + 'group_by_columns', NULL + ) + as rule_info """ +assert agg_rule1.get_rule_query(1234) == EXPECTED_RULE1_QUERY diff --git a/tests/unit/reconcile/test_compare.py b/tests/unit/reconcile/test_compare.py new file mode 100644 index 0000000000..369c234861 --- /dev/null +++ b/tests/unit/reconcile/test_compare.py @@ -0,0 +1,208 @@ +from pathlib import Path +import pytest +from pyspark import Row +from pyspark.testing import assertDataFrameEqual + +from databricks.labs.remorph.reconcile.compare import ( + alias_column_str, + capture_mismatch_data_and_columns, + reconcile_data, +) +from databricks.labs.remorph.reconcile.exception import ColumnMismatchException +from 
databricks.labs.remorph.reconcile.recon_config import ( + DataReconcileOutput, + MismatchOutput, +) + + +def test_compare_data_for_report_all( + mock_spark, + tmp_path: Path, +): + source = mock_spark.createDataFrame( + [ + Row(s_suppkey=1, s_nationkey=11, hash_value_recon='1a1'), + Row(s_suppkey=2, s_nationkey=22, hash_value_recon='2b2'), + Row(s_suppkey=3, s_nationkey=33, hash_value_recon='3c3'), + Row(s_suppkey=5, s_nationkey=55, hash_value_recon='5e5'), + ] + ) + target = mock_spark.createDataFrame( + [ + Row(s_suppkey=1, s_nationkey=11, hash_value_recon='1a1'), + Row(s_suppkey=2, s_nationkey=22, hash_value_recon='2b4'), + Row(s_suppkey=4, s_nationkey=44, hash_value_recon='4d4'), + Row(s_suppkey=5, s_nationkey=56, hash_value_recon='5e6'), + ] + ) + + mismatch = MismatchOutput(mismatch_df=mock_spark.createDataFrame([Row(s_suppkey=2, s_nationkey=22)])) + missing_in_src = mock_spark.createDataFrame([Row(s_suppkey=4, s_nationkey=44), Row(s_suppkey=5, s_nationkey=56)]) + missing_in_tgt = mock_spark.createDataFrame([Row(s_suppkey=3, s_nationkey=33), Row(s_suppkey=5, s_nationkey=55)]) + + actual = reconcile_data( + source=source, + target=target, + key_columns=["s_suppkey", "s_nationkey"], + report_type="all", + spark=mock_spark, + path=str(tmp_path), + ) + expected = DataReconcileOutput( + mismatch_count=1, + missing_in_src_count=1, + missing_in_tgt_count=1, + missing_in_src=missing_in_src, + missing_in_tgt=missing_in_tgt, + mismatch=mismatch, + ) + + assertDataFrameEqual(actual.mismatch.mismatch_df, expected.mismatch.mismatch_df) + assertDataFrameEqual(actual.missing_in_src, expected.missing_in_src) + assertDataFrameEqual(actual.missing_in_tgt, expected.missing_in_tgt) + + +def test_compare_data_for_report_hash(mock_spark, tmp_path: Path): + source = mock_spark.createDataFrame( + [ + Row(s_suppkey=1, s_nationkey=11, hash_value_recon='1a1'), + Row(s_suppkey=2, s_nationkey=22, hash_value_recon='2b2'), + Row(s_suppkey=3, s_nationkey=33, hash_value_recon='3c3'), + Row(s_suppkey=5, s_nationkey=55, hash_value_recon='5e5'), + ] + ) + target = mock_spark.createDataFrame( + [ + Row(s_suppkey=1, s_nationkey=11, hash_value_recon='1a1'), + Row(s_suppkey=2, s_nationkey=22, hash_value_recon='2b4'), + Row(s_suppkey=4, s_nationkey=44, hash_value_recon='4d4'), + Row(s_suppkey=5, s_nationkey=56, hash_value_recon='5e6'), + ] + ) + + missing_in_src = mock_spark.createDataFrame( + [Row(s_suppkey=2, s_nationkey=22), Row(s_suppkey=4, s_nationkey=44), Row(s_suppkey=5, s_nationkey=56)] + ) + missing_in_tgt = mock_spark.createDataFrame( + [Row(s_suppkey=2, s_nationkey=22), Row(s_suppkey=3, s_nationkey=33), Row(s_suppkey=5, s_nationkey=55)] + ) + + actual = reconcile_data( + source=source, + target=target, + key_columns=["s_suppkey", "s_nationkey"], + report_type="hash", + spark=mock_spark, + path=str(tmp_path), + ) + expected = DataReconcileOutput( + missing_in_src=missing_in_src, + missing_in_tgt=missing_in_tgt, + mismatch=MismatchOutput(), + mismatch_count=0, + missing_in_src_count=1, + missing_in_tgt_count=1, + ) + + assert actual.mismatch.mismatch_df is None + assert not actual.mismatch.mismatch_columns + assertDataFrameEqual(actual.missing_in_src, expected.missing_in_src) + assertDataFrameEqual(actual.missing_in_tgt, expected.missing_in_tgt) + + +def test_capture_mismatch_data_and_cols(mock_spark): + source = mock_spark.createDataFrame( + [ + Row(s_suppkey=1, s_nationkey=11, s_name='supp-1', s_address='a-1', s_phone='ph-1', s_acctbal=100), + Row(s_suppkey=2, s_nationkey=22, s_name='supp-22', s_address='a-2', 
s_phone='ph-2', s_acctbal=200), + Row(s_suppkey=3, s_nationkey=33, s_name='supp-3', s_address='a-3', s_phone='ph-3', s_acctbal=300), + Row(s_suppkey=5, s_nationkey=55, s_name='supp-5', s_address='a-5', s_phone='ph-5', s_acctbal=400), + ] + ) + target = mock_spark.createDataFrame( + [ + Row(s_suppkey=1, s_nationkey=11, s_name='supp-1', s_address='a-1', s_phone='ph-1', s_acctbal=100), + Row(s_suppkey=2, s_nationkey=22, s_name='supp-2', s_address='a-2', s_phone='ph-2', s_acctbal=2000), + Row(s_suppkey=3, s_nationkey=33, s_name='supp-33', s_address='a-3', s_phone='ph-3', s_acctbal=300), + Row(s_suppkey=4, s_nationkey=44, s_name='supp-4', s_address='a-4', s_phone='ph-4', s_acctbal=400), + ] + ) + + actual = capture_mismatch_data_and_columns(source=source, target=target, key_columns=["s_suppkey", "s_nationkey"]) + + expected_df = mock_spark.createDataFrame( + [ + Row( + s_suppkey=2, + s_nationkey=22, + s_acctbal_base=200, + s_acctbal_compare=2000, + s_acctbal_match=False, + s_address_base='a-2', + s_address_compare='a-2', + s_address_match=True, + s_name_base='supp-22', + s_name_compare='supp-2', + s_name_match=False, + s_phone_base='ph-2', + s_phone_compare='ph-2', + s_phone_match=True, + ), + Row( + s_suppkey=3, + s_nationkey=33, + s_acctbal_base=300, + s_acctbal_compare=300, + s_acctbal_match=True, + s_address_base='a-3', + s_address_compare='a-3', + s_address_match=True, + s_name_base='supp-3', + s_name_compare='supp-33', + s_name_match=False, + s_phone_base='ph-3', + s_phone_compare='ph-3', + s_phone_match=True, + ), + ] + ) + + assertDataFrameEqual(actual.mismatch_df, expected_df) + assert sorted(actual.mismatch_columns) == ['s_acctbal', 's_name'] + + +def test_capture_mismatch_data_and_cols_fail(mock_spark): + source = mock_spark.createDataFrame( + [ + Row(s_suppkey=1, s_nationkey=11, s_name='supp-1', s_address='a-1', s_phone='ph-1', s_acctbal=100), + Row(s_suppkey=2, s_nationkey=22, s_name='supp-22', s_address='a-2', s_phone='ph-2', s_acctbal=200), + Row(s_suppkey=3, s_nationkey=33, s_name='supp-3', s_address='a-3', s_phone='ph-3', s_acctbal=300), + Row(s_suppkey=5, s_nationkey=55, s_name='supp-5', s_address='a-5', s_phone='ph-5', s_acctbal=400), + ] + ) + target = mock_spark.createDataFrame( + [ + Row(s_suppkey=1), + Row(s_suppkey=2), + Row(s_suppkey=3), + Row(s_suppkey=4), + ] + ) + + with pytest.raises(ColumnMismatchException) as exception: + capture_mismatch_data_and_columns(source=source, target=target, key_columns=["s_suppkey"]) + + assert str(exception.value) == ( + "source and target should have same columns for capturing the mismatch data\n" + "columns missing in source: None\n" + "columns missing in target: s_nationkey,s_name,s_address,s_phone,s_acctbal\n" + ) + + +def test_alias_column_str(): + column_list = ['col1', 'col2', 'col3'] + alias = 'source' + actual = alias_column_str(alias=alias, columns=column_list) + expected = ['source.col1', 'source.col2', 'source.col3'] + + assert actual == expected diff --git a/tests/unit/reconcile/test_execute.py b/tests/unit/reconcile/test_execute.py new file mode 100644 index 0000000000..2a328e7727 --- /dev/null +++ b/tests/unit/reconcile/test_execute.py @@ -0,0 +1,1976 @@ +from pathlib import Path +from dataclasses import dataclass +from datetime import datetime +from unittest.mock import patch, MagicMock + +import pytest +from pyspark import Row +from pyspark.errors import PySparkException +from pyspark.testing import assertDataFrameEqual + +from databricks.labs.remorph.config import ( + DatabaseConfig, + TableRecon, + get_dialect, + 
ReconcileMetadataConfig, + ReconcileConfig, +) +from databricks.labs.remorph.reconcile.connectors.data_source import MockDataSource +from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource +from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource +from databricks.labs.remorph.reconcile.exception import ( + DataSourceRuntimeException, + InvalidInputException, + ReconciliationException, +) +from databricks.labs.remorph.reconcile.execute import ( + Reconciliation, + initialise_data_source, + recon, + generate_volume_path, +) +from databricks.labs.remorph.reconcile.recon_config import ( + DataReconcileOutput, + MismatchOutput, + ThresholdOutput, + ReconcileOutput, + ReconcileTableOutput, + StatusOutput, +) +from databricks.labs.remorph.reconcile.schema_compare import SchemaCompare + +CATALOG = "org" +SCHEMA = "data" +SRC_TABLE = "supplier" +TGT_TABLE = "target_supplier" + + +@dataclass +class HashQueries: + source_hash_query: str + target_hash_query: str + + +@dataclass +class MismatchQueries: + source_mismatch_query: str + target_mismatch_query: str + + +@dataclass +class MissingQueries: + source_missing_query: str + target_missing_query: str + + +@dataclass +class ThresholdQueries: + source_threshold_query: str + target_threshold_query: str + threshold_comparison_query: str + + +@dataclass +class RowQueries: + source_row_query: str + target_row_query: str + + +@dataclass +class RecordCountQueries: + source_record_count_query: str + target_record_count_query: str + + +@dataclass +class QueryStore: + hash_queries: HashQueries + mismatch_queries: MismatchQueries + missing_queries: MissingQueries + threshold_queries: ThresholdQueries + row_queries: RowQueries + record_count_queries: RecordCountQueries + + +@pytest.fixture +def setup_metadata_table(mock_spark, report_tables_schema): + recon_schema, metrics_schema, details_schema = report_tables_schema + mode = "overwrite" + mock_spark.createDataFrame(data=[], schema=recon_schema).write.mode(mode).saveAsTable("DEFAULT.MAIN") + mock_spark.createDataFrame(data=[], schema=metrics_schema).write.mode(mode).saveAsTable("DEFAULT.METRICS") + mock_spark.createDataFrame(data=[], schema=details_schema).write.mode(mode).saveAsTable("DEFAULT.DETAILS") + + +@pytest.fixture +def query_store(mock_spark): + source_hash_query = "SELECT LOWER(SHA2(CONCAT(TRIM(s_address), TRIM(s_name), COALESCE(TRIM(s_nationkey), '_null_recon_'), TRIM(s_phone), COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, s_nationkey AS s_nationkey, s_suppkey AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address = 'a'" + target_hash_query = "SELECT LOWER(SHA2(CONCAT(TRIM(s_address_t), TRIM(s_name), COALESCE(TRIM(s_nationkey_t), '_null_recon_'), TRIM(s_phone_t), COALESCE(TRIM(s_suppkey_t), '_null_recon_')), 256)) AS hash_value_recon, s_nationkey_t AS s_nationkey, s_suppkey_t AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'" + source_mismatch_query = "WITH recon AS (SELECT 22 AS s_nationkey, 2 AS s_suppkey), src AS (SELECT TRIM(s_address) AS s_address, TRIM(s_name) AS s_name, COALESCE(TRIM(s_nationkey), '_null_recon_') AS s_nationkey, TRIM(s_phone) AS s_phone, COALESCE(TRIM(s_suppkey), '_null_recon_') AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address = 'a') SELECT src.s_address, src.s_name, src.s_nationkey, src.s_phone, src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND 
COALESCE(TRIM(src.s_suppkey), '_null_recon_') = COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + target_mismatch_query = "WITH recon AS (SELECT 22 AS s_nationkey, 2 AS s_suppkey), src AS (SELECT TRIM(s_address_t) AS s_address, TRIM(s_name) AS s_name, COALESCE(TRIM(s_nationkey_t), '_null_recon_') AS s_nationkey, TRIM(s_phone_t) AS s_phone, COALESCE(TRIM(s_suppkey_t), '_null_recon_') AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address_t = 'a') SELECT src.s_address, src.s_name, src.s_nationkey, src.s_phone, src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + source_missing_query = "WITH recon AS (SELECT 44 AS s_nationkey, 4 AS s_suppkey), src AS (SELECT TRIM(s_address_t) AS s_address, TRIM(s_name) AS s_name, COALESCE(TRIM(s_nationkey_t), '_null_recon_') AS s_nationkey, TRIM(s_phone_t) AS s_phone, COALESCE(TRIM(s_suppkey_t), '_null_recon_') AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address_t = 'a') SELECT src.s_address, src.s_name, src.s_nationkey, src.s_phone, src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + target_missing_query = "WITH recon AS (SELECT 33 AS s_nationkey, 3 AS s_suppkey), src AS (SELECT TRIM(s_address) AS s_address, TRIM(s_name) AS s_name, COALESCE(TRIM(s_nationkey), '_null_recon_') AS s_nationkey, TRIM(s_phone) AS s_phone, COALESCE(TRIM(s_suppkey), '_null_recon_') AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address = 'a') SELECT src.s_address, src.s_name, src.s_nationkey, src.s_phone, src.s_suppkey FROM src INNER JOIN recon AS recon ON COALESCE(TRIM(src.s_nationkey), '_null_recon_') = COALESCE(TRIM(recon.s_nationkey), '_null_recon_') AND COALESCE(TRIM(src.s_suppkey), '_null_recon_') = COALESCE(TRIM(recon.s_suppkey), '_null_recon_')" + source_threshold_query = "SELECT s_nationkey AS s_nationkey, s_suppkey AS s_suppkey, s_acctbal AS s_acctbal FROM :tbl WHERE s_name = 't' AND s_address = 'a'" + target_threshold_query = "SELECT s_nationkey_t AS s_nationkey, s_suppkey_t AS s_suppkey, s_acctbal_t AS s_acctbal FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'" + threshold_comparison_query = "SELECT COALESCE(source.s_acctbal, 0) AS s_acctbal_source, COALESCE(databricks.s_acctbal, 0) AS s_acctbal_databricks, CASE WHEN (COALESCE(source.s_acctbal, 0) - COALESCE(databricks.s_acctbal, 0)) = 0 THEN 'Match' WHEN (COALESCE(source.s_acctbal, 0) - COALESCE(databricks.s_acctbal, 0)) BETWEEN 0 AND 100 THEN 'Warning' ELSE 'Failed' END AS s_acctbal_match, source.s_nationkey AS s_nationkey_source, source.s_suppkey AS s_suppkey_source FROM source_supplier_df_threshold_vw AS source INNER JOIN target_target_supplier_df_threshold_vw AS databricks ON source.s_nationkey <=> databricks.s_nationkey AND source.s_suppkey <=> databricks.s_suppkey WHERE (1 = 1 OR 1 = 1) OR (COALESCE(source.s_acctbal, 0) - COALESCE(databricks.s_acctbal, 0)) <> 0" + source_row_query = "SELECT LOWER(SHA2(CONCAT(TRIM(s_address), TRIM(s_name), COALESCE(TRIM(s_nationkey), '_null_recon_'), TRIM(s_phone), COALESCE(TRIM(s_suppkey), '_null_recon_')), 256)) AS hash_value_recon, TRIM(s_address) AS s_address, TRIM(s_name) AS s_name, s_nationkey AS s_nationkey, TRIM(s_phone) AS s_phone, s_suppkey AS s_suppkey FROM :tbl 
WHERE s_name = 't' AND s_address = 'a'" + target_row_query = "SELECT LOWER(SHA2(CONCAT(TRIM(s_address_t), TRIM(s_name), COALESCE(TRIM(s_nationkey_t), '_null_recon_'), TRIM(s_phone_t), COALESCE(TRIM(s_suppkey_t), '_null_recon_')), 256)) AS hash_value_recon, TRIM(s_address_t) AS s_address, TRIM(s_name) AS s_name, s_nationkey_t AS s_nationkey, TRIM(s_phone_t) AS s_phone, s_suppkey_t AS s_suppkey FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'" + + hash_queries = HashQueries( + source_hash_query=source_hash_query, + target_hash_query=target_hash_query, + ) + mismatch_queries = MismatchQueries( + source_mismatch_query=source_mismatch_query, + target_mismatch_query=target_mismatch_query, + ) + missing_queries = MissingQueries( + source_missing_query=source_missing_query, + target_missing_query=target_missing_query, + ) + threshold_queries = ThresholdQueries( + source_threshold_query=source_threshold_query, + target_threshold_query=target_threshold_query, + threshold_comparison_query=threshold_comparison_query, + ) + row_queries = RowQueries( + source_row_query=source_row_query, + target_row_query=target_row_query, + ) + record_count_queries = RecordCountQueries( + source_record_count_query="SELECT COUNT(1) AS count FROM :tbl WHERE s_name = 't' AND s_address = 'a'", + target_record_count_query="SELECT COUNT(1) AS count FROM :tbl WHERE s_name = 't' AND s_address_t = 'a'", + ) + + return QueryStore( + hash_queries=hash_queries, + mismatch_queries=mismatch_queries, + missing_queries=missing_queries, + threshold_queries=threshold_queries, + row_queries=row_queries, + record_count_queries=record_count_queries, + ) + + +def test_reconcile_data_with_mismatches_and_missing( + mock_spark, + table_conf_with_opts, + table_schema, + query_store, + tmp_path: Path, +): + src_schema, tgt_schema = table_schema + + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.source_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="e3g", s_nationkey=33, s_suppkey=3), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.source_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-2", s_name="name-2", s_nationkey=22, s_phone="222-2", s_suppkey=2)] + ), + (CATALOG, SCHEMA, query_store.missing_queries.target_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-3", s_name="name-3", s_nationkey=33, s_phone="333", s_suppkey=3)] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.source_threshold_query): mock_spark.createDataFrame( + [Row(s_nationkey=11, s_suppkey=1, s_acctbal=100)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.target_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2de", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="k4l", s_nationkey=44, s_suppkey=4), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.target_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-22", s_name="name-2", s_nationkey=22, s_phone="222", s_suppkey=2)] + ), + (CATALOG, SCHEMA, query_store.missing_queries.source_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-4", s_name="name-4", s_nationkey=44, s_phone="444", s_suppkey=4)] + ), + (CATALOG, SCHEMA, 
query_store.threshold_queries.target_threshold_query): mock_spark.createDataFrame( + [Row(s_nationkey=11, s_suppkey=1, s_acctbal=210)] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.threshold_comparison_query): mock_spark.createDataFrame( + [ + Row( + s_acctbal_source=100, + s_acctbal_databricks=210, + s_acctbal_match="Failed", + s_nationkey_source=11, + s_suppkey_source=1, + ) + ] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + database_config = DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ) + schema_comparator = SchemaCompare(mock_spark) + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + with patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)): + actual_data_reconcile = Reconciliation( + source, + target, + database_config, + "data", + schema_comparator, + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_data(table_conf_with_opts, src_schema, tgt_schema) + expected_data_reconcile = DataReconcileOutput( + mismatch_count=1, + missing_in_src_count=1, + missing_in_tgt_count=1, + missing_in_src=mock_spark.createDataFrame( + [Row(s_address="address-4", s_name="name-4", s_nationkey=44, s_phone="444", s_suppkey=4)] + ), + missing_in_tgt=mock_spark.createDataFrame( + [Row(s_address="address-3", s_name="name-3", s_nationkey=33, s_phone="333", s_suppkey=3)] + ), + mismatch=MismatchOutput( + mismatch_df=mock_spark.createDataFrame( + [ + Row( + s_suppkey=2, + s_nationkey=22, + s_address_base="address-2", + s_address_compare="address-22", + s_address_match=False, + s_name_base="name-2", + s_name_compare="name-2", + s_name_match=True, + s_phone_base="222-2", + s_phone_compare="222", + s_phone_match=False, + ) + ] + ), + mismatch_columns=["s_address", "s_phone"], + ), + threshold_output=ThresholdOutput( + threshold_df=mock_spark.createDataFrame( + [ + Row( + s_acctbal_source=100, + s_acctbal_databricks=210, + s_acctbal_match="Failed", + s_nationkey_source=11, + s_suppkey_source=1, + ) + ] + ), + threshold_mismatch_count=1, + ), + ) + + assert actual_data_reconcile.mismatch_count == expected_data_reconcile.mismatch_count + assert actual_data_reconcile.missing_in_src_count == expected_data_reconcile.missing_in_src_count + assert actual_data_reconcile.missing_in_tgt_count == expected_data_reconcile.missing_in_tgt_count + assert actual_data_reconcile.mismatch.mismatch_columns == expected_data_reconcile.mismatch.mismatch_columns + + assertDataFrameEqual(actual_data_reconcile.mismatch.mismatch_df, expected_data_reconcile.mismatch.mismatch_df) + assertDataFrameEqual(actual_data_reconcile.missing_in_src, expected_data_reconcile.missing_in_src) + assertDataFrameEqual(actual_data_reconcile.missing_in_tgt, expected_data_reconcile.missing_in_tgt) + + actual_schema_reconcile = Reconciliation( + source, + target, + database_config, + "data", + schema_comparator, + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_schema(src_schema, tgt_schema, table_conf_with_opts) + expected_schema_reconcile = mock_spark.createDataFrame( + [ + Row( + source_column="s_suppkey", + source_datatype="number", + databricks_column="s_suppkey_t", + databricks_datatype="number", + is_valid=True, + ), + Row( + source_column="s_name", + source_datatype="varchar", + databricks_column="s_name", + 
databricks_datatype="varchar", + is_valid=True, + ), + Row( + source_column="s_address", + source_datatype="varchar", + databricks_column="s_address_t", + databricks_datatype="varchar", + is_valid=True, + ), + Row( + source_column="s_nationkey", + source_datatype="number", + databricks_column="s_nationkey_t", + databricks_datatype="number", + is_valid=True, + ), + Row( + source_column="s_phone", + source_datatype="varchar", + databricks_column="s_phone_t", + databricks_datatype="varchar", + is_valid=True, + ), + Row( + source_column="s_acctbal", + source_datatype="number", + databricks_column="s_acctbal_t", + databricks_datatype="number", + is_valid=True, + ), + ] + ) + assertDataFrameEqual(actual_schema_reconcile.compare_df, expected_schema_reconcile) + assert actual_schema_reconcile.is_valid is True + + assertDataFrameEqual( + actual_data_reconcile.threshold_output.threshold_df, + mock_spark.createDataFrame( + [ + Row( + s_acctbal_source=100, + s_acctbal_databricks=210, + s_acctbal_match="Failed", + s_nationkey_source=11, + s_suppkey_source=1, + ) + ] + ), + ) + assert actual_data_reconcile.threshold_output.threshold_mismatch_count == 1 + + +def test_reconcile_data_without_mismatches_and_missing( + mock_spark, + table_conf_with_opts, + table_schema, + query_store, + tmp_path: Path, +): + src_schema, tgt_schema = table_schema + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.source_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + ] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.source_threshold_query): mock_spark.createDataFrame( + [Row(s_nationkey=11, s_suppkey=1, s_acctbal=100)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.target_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + ] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.target_threshold_query): mock_spark.createDataFrame( + [Row(s_nationkey=11, s_suppkey=1, s_acctbal=110)] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.threshold_comparison_query): mock_spark.createDataFrame( + [ + Row( + s_acctbal_source=100, + s_acctbal_databricks=110, + s_acctbal_match="Warning", + s_nationkey_source=11, + s_suppkey_source=1, + ) + ] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + database_config = DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ) + schema_comparator = SchemaCompare(mock_spark) + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + with patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)): + actual = Reconciliation( + source, + target, + database_config, + "data", + schema_comparator, + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_data(table_conf_with_opts, src_schema, tgt_schema) + + assert actual.mismatch_count == 0 + assert actual.missing_in_src_count == 0 + assert actual.missing_in_tgt_count == 0 + assert actual.mismatch is None + assert actual.missing_in_src is None + assert actual.missing_in_tgt is None + 
assert actual.threshold_output.threshold_df is None + assert actual.threshold_output.threshold_mismatch_count == 0 + + +def test_reconcile_data_with_mismatch_and_no_missing( + mock_spark, table_conf_with_opts, table_schema, query_store, tmp_path: Path +): + src_schema, tgt_schema = table_schema + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.source_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.source_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-2", s_name="name-2", s_nationkey=22, s_phone="222-2", s_suppkey=2)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.target_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2de", s_nationkey=22, s_suppkey=2), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.target_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-22", s_name="name-2", s_nationkey=22, s_phone="222", s_suppkey=2)] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + database_config = DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ) + schema_comparator = SchemaCompare(mock_spark) + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + with patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)): + actual = Reconciliation( + source, + target, + database_config, + "data", + schema_comparator, + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_data(table_conf_with_opts, src_schema, tgt_schema) + expected = DataReconcileOutput( + mismatch_count=1, + missing_in_src_count=0, + missing_in_tgt_count=0, + missing_in_src=None, + missing_in_tgt=None, + mismatch=MismatchOutput( + mismatch_df=mock_spark.createDataFrame( + [ + Row( + s_suppkey=2, + s_nationkey=22, + s_address_base="address-2", + s_address_compare="address-22", + s_address_match=False, + s_name_base="name-2", + s_name_compare="name-2", + s_name_match=True, + s_phone_base="222-2", + s_phone_compare="222", + s_phone_match=False, + ) + ] + ), + mismatch_columns=["s_address", "s_phone"], + ), + ) + + assert actual.mismatch_count == expected.mismatch_count + assert actual.missing_in_src_count == expected.missing_in_src_count + assert actual.missing_in_tgt_count == expected.missing_in_tgt_count + assert actual.mismatch.mismatch_columns == expected.mismatch.mismatch_columns + assert actual.missing_in_src is None + assert actual.missing_in_tgt is None + + assertDataFrameEqual(actual.mismatch.mismatch_df, expected.mismatch.mismatch_df) + + +def test_reconcile_data_missing_and_no_mismatch( + mock_spark, + table_conf_with_opts, + table_schema, + query_store, + tmp_path: Path, +): + src_schema, tgt_schema = table_schema + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + 
query_store.hash_queries.source_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="e3g", s_nationkey=33, s_suppkey=3), + ] + ), + (CATALOG, SCHEMA, query_store.missing_queries.target_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-3", s_name="name-3", s_nationkey=33, s_phone="333", s_suppkey=3)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.target_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="k4l", s_nationkey=44, s_suppkey=4), + ] + ), + (CATALOG, SCHEMA, query_store.missing_queries.source_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-4", s_name="name-4", s_nationkey=44, s_phone="444", s_suppkey=4)] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + database_config = DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ) + schema_comparator = SchemaCompare(mock_spark) + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + with patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)): + actual = Reconciliation( + source, + target, + database_config, + "data", + schema_comparator, + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_data(table_conf_with_opts, src_schema, tgt_schema) + expected = DataReconcileOutput( + mismatch_count=0, + missing_in_src_count=1, + missing_in_tgt_count=1, + missing_in_src=mock_spark.createDataFrame( + [Row(s_address="address-4", s_name="name-4", s_nationkey=44, s_phone="444", s_suppkey=4)] + ), + missing_in_tgt=mock_spark.createDataFrame( + [Row(s_address="address-3", s_name="name-3", s_nationkey=33, s_phone="333", s_suppkey=3)] + ), + mismatch=MismatchOutput(), + ) + + assert actual.mismatch_count == expected.mismatch_count + assert actual.missing_in_src_count == expected.missing_in_src_count + assert actual.missing_in_tgt_count == expected.missing_in_tgt_count + assert actual.mismatch is None + + assertDataFrameEqual(actual.missing_in_src, expected.missing_in_src) + assertDataFrameEqual(actual.missing_in_tgt, expected.missing_in_tgt) + + +@pytest.fixture +def mock_for_report_type_data( + table_conf_with_opts, + table_schema, + query_store, + setup_metadata_table, + mock_spark, +): + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + table_recon = TableRecon( + source_catalog="org", + source_schema="data", + target_catalog="org", + target_schema="data", + tables=[table_conf_with_opts], + ) + src_schema, tgt_schema = table_schema + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.source_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="e3g", s_nationkey=33, s_suppkey=3), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.source_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-2", 
s_name="name-2", s_nationkey=22, s_phone="222-2", s_suppkey=2)] + ), + (CATALOG, SCHEMA, query_store.missing_queries.target_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-3", s_name="name-3", s_nationkey=33, s_phone="333", s_suppkey=3)] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.source_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.target_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2de", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="k4l", s_nationkey=44, s_suppkey=4), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.target_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-22", s_name="name-2", s_nationkey=22, s_phone="222", s_suppkey=2)] + ), + (CATALOG, SCHEMA, query_store.missing_queries.source_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-4", s_name="name-4", s_nationkey=44, s_phone="444", s_suppkey=4)] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.target_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + + reconcile_config_data = ReconcileConfig( + data_source="databricks", + report_type="data", + secret_scope="remorph_databricks", + database_config=DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ), + metadata_config=ReconcileMetadataConfig(schema="default"), + ) + + return table_recon, source, target, reconcile_config_data + + +def test_recon_for_report_type_is_data( + mock_workspace_client, + mock_spark, + report_tables_schema, + mock_for_report_type_data, + tmp_path: Path, +): + recon_schema, metrics_schema, details_schema = report_tables_schema + table_recon, source, target, reconcile_config_data = mock_for_report_type_data + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=11111 + ), + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + ): + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + with pytest.raises(ReconciliationException) as exc_info: + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config_data, local_test_run=True) + if exc_info.value.reconcile_output is not None: + assert exc_info.value.reconcile_output.recon_id == "00112233-4455-6677-8899-aabbccddeeff" + + expected_remorph_recon = mock_spark.createDataFrame( + data=[ + ( + 11111, + "00112233-4455-6677-8899-aabbccddeeff", + "Databricks", + ("org", "data", "supplier"), + ("org", "data", 
"target_supplier"), + "data", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[ + ( + 11111, + ((1, 1), (1, 0, "s_address,s_phone"), None), + (False, "remorph", ""), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame( + data=[ + ( + 11111, + "mismatch", + False, + [ + { + "s_suppkey": "2", + "s_nationkey": "22", + "s_address_base": "address-2", + "s_address_compare": "address-22", + "s_address_match": "false", + "s_name_base": "name-2", + "s_name_compare": "name-2", + "s_name_match": "true", + "s_phone_base": "222-2", + "s_phone_compare": "222", + "s_phone_match": "false", + } + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ( + 11111, + "missing_in_source", + False, + [ + { + "s_address": "address-4", + "s_name": "name-4", + "s_nationkey": "44", + "s_phone": "444", + "s_suppkey": "4", + } + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ( + 11111, + "missing_in_target", + False, + [ + { + "s_address": "address-3", + "s_name": "name-3", + "s_nationkey": "33", + "s_phone": "333", + "s_suppkey": "3", + } + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ], + schema=details_schema, + ) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + +@pytest.fixture +def mock_for_report_type_schema(table_conf_with_opts, table_schema, query_store, mock_spark, setup_metadata_table): + table_recon = TableRecon( + source_catalog="org", + source_schema="data", + target_catalog="org", + target_schema="data", + tables=[table_conf_with_opts], + ) + src_schema, tgt_schema = table_schema + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.source_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="e3g", s_nationkey=33, s_suppkey=3), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.source_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-2", s_name="name-2", s_nationkey=22, s_phone="222-2", s_suppkey=2)] + ), + (CATALOG, SCHEMA, query_store.missing_queries.target_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-3", s_name="name-3", s_nationkey=33, s_phone="333", s_suppkey=3)] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.source_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.target_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2de", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="k4l", s_nationkey=44, s_suppkey=4), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.target_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-22", s_name="name-2", s_nationkey=22, s_phone="222", s_suppkey=2)] + 
), + (CATALOG, SCHEMA, query_store.missing_queries.source_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-4", s_name="name-4", s_nationkey=44, s_phone="444", s_suppkey=4)] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.target_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + + reconcile_config_schema = ReconcileConfig( + data_source="databricks", + report_type="schema", + secret_scope="remorph_databricks", + database_config=DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ), + metadata_config=ReconcileMetadataConfig(schema="default"), + ) + + return table_recon, source, target, reconcile_config_schema + + +def test_recon_for_report_type_schema( + mock_workspace_client, + mock_spark, + report_tables_schema, + mock_for_report_type_schema, + tmp_path: Path, +): + recon_schema, metrics_schema, details_schema = report_tables_schema + table_recon, source, target, reconcile_config_schema = mock_for_report_type_schema + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=22222 + ), + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + ): + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + final_reconcile_output = recon( + mock_workspace_client, mock_spark, table_recon, reconcile_config_schema, local_test_run=True + ) + + expected_remorph_recon = mock_spark.createDataFrame( + data=[ + ( + 22222, + "00112233-4455-6677-8899-aabbccddeeff", + "Databricks", + ("org", "data", "supplier"), + ("org", "data", "target_supplier"), + "schema", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[(22222, (None, None, True), (True, "remorph", ""), datetime(2024, 5, 23, 9, 21, 25, 122185))], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame( + data=[ + ( + 22222, + "schema", + True, + [ + { + "source_column": "s_suppkey", + "source_datatype": "number", + "databricks_column": "s_suppkey_t", + "databricks_datatype": "number", + "is_valid": "true", + }, + { + "source_column": "s_name", + "source_datatype": "varchar", + "databricks_column": "s_name", + "databricks_datatype": "varchar", + "is_valid": "true", + }, + { + "source_column": "s_address", + "source_datatype": "varchar", + "databricks_column": "s_address_t", + "databricks_datatype": "varchar", + "is_valid": "true", + }, + { + "source_column": "s_nationkey", + "source_datatype": "number", + "databricks_column": "s_nationkey_t", + "databricks_datatype": "number", + "is_valid": "true", + }, + { + "source_column": 
"s_phone", + "source_datatype": "varchar", + "databricks_column": "s_phone_t", + "databricks_datatype": "varchar", + "is_valid": "true", + }, + { + "source_column": "s_acctbal", + "source_datatype": "number", + "databricks_column": "s_acctbal_t", + "databricks_datatype": "number", + "is_valid": "true", + }, + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=details_schema, + ) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + assert final_reconcile_output.recon_id == "00112233-4455-6677-8899-aabbccddeeff" + + +@pytest.fixture +def mock_for_report_type_all( + mock_workspace_client, + table_conf_with_opts, + table_schema, + mock_spark, + query_store, + setup_metadata_table, +): + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + table_recon = TableRecon( + source_catalog="org", + source_schema="data", + target_catalog="org", + target_schema="data", + tables=[table_conf_with_opts], + ) + src_schema, tgt_schema = table_schema + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.source_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="e3g", s_nationkey=33, s_suppkey=3), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.source_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-2", s_name="name-2", s_nationkey=22, s_phone="222-2", s_suppkey=2)] + ), + (CATALOG, SCHEMA, query_store.missing_queries.target_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-3", s_name="name-3", s_nationkey=33, s_phone="333", s_suppkey=3)] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.source_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.hash_queries.target_hash_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2de", s_nationkey=22, s_suppkey=2), + Row(hash_value_recon="k4l", s_nationkey=44, s_suppkey=4), + ] + ), + (CATALOG, SCHEMA, query_store.mismatch_queries.target_mismatch_query): mock_spark.createDataFrame( + [Row(s_address="address-22", s_name="name-2", s_nationkey=22, s_phone="222", s_suppkey=2)] + ), + (CATALOG, SCHEMA, query_store.missing_queries.source_missing_query): mock_spark.createDataFrame( + [Row(s_address="address-4", s_name="name-4", s_nationkey=44, s_phone="444", s_suppkey=4)] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.target_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + reconcile_config_all = ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake", + database_config=DatabaseConfig( + source_catalog=CATALOG, 
+ source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ), + metadata_config=ReconcileMetadataConfig(), + ) + return table_recon, source, target, reconcile_config_all + + +def test_recon_for_report_type_all( + mock_workspace_client, + mock_spark, + report_tables_schema, + mock_for_report_type_all, + tmp_path: Path, +): + recon_schema, metrics_schema, details_schema = report_tables_schema + table_recon, source, target, reconcile_config_all = mock_for_report_type_all + + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=33333 + ), + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + ): + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + with pytest.raises(ReconciliationException) as exc_info: + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config_all, local_test_run=True) + if exc_info.value.reconcile_output is not None: + assert exc_info.value.reconcile_output.recon_id == "00112233-4455-6677-8899-aabbccddeeff" + + expected_remorph_recon = mock_spark.createDataFrame( + data=[ + ( + 33333, + "00112233-4455-6677-8899-aabbccddeeff", + "Snowflake", + ("org", "data", "supplier"), + ("org", "data", "target_supplier"), + "all", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[ + ( + 33333, + ((1, 1), (1, 0, "s_address,s_phone"), False), + (False, "remorph", ""), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame( + data=[ + ( + 33333, + "mismatch", + False, + [ + { + "s_suppkey": "2", + "s_nationkey": "22", + "s_address_base": "address-2", + "s_address_compare": "address-22", + "s_address_match": "false", + "s_name_base": "name-2", + "s_name_compare": "name-2", + "s_name_match": "true", + "s_phone_base": "222-2", + "s_phone_compare": "222", + "s_phone_match": "false", + } + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ( + 33333, + "missing_in_source", + False, + [ + { + "s_address": "address-4", + "s_name": "name-4", + "s_nationkey": "44", + "s_phone": "444", + "s_suppkey": "4", + } + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ( + 33333, + "missing_in_target", + False, + [ + { + "s_address": "address-3", + "s_name": "name-3", + "s_nationkey": "33", + "s_phone": "333", + "s_suppkey": "3", + } + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ( + 33333, + "schema", + False, + [ + { + "source_column": "s_suppkey", + "source_datatype": "number", + "databricks_column": "s_suppkey_t", + "databricks_datatype": "number", + "is_valid": "false", + }, + { + "source_column": "s_name", + "source_datatype": "varchar", + "databricks_column": "s_name", + "databricks_datatype": "varchar", + "is_valid": "false", + }, + { + "source_column": "s_address", + "source_datatype": "varchar", + 
"databricks_column": "s_address_t", + "databricks_datatype": "varchar", + "is_valid": "false", + }, + { + "source_column": "s_nationkey", + "source_datatype": "number", + "databricks_column": "s_nationkey_t", + "databricks_datatype": "number", + "is_valid": "false", + }, + { + "source_column": "s_phone", + "source_datatype": "varchar", + "databricks_column": "s_phone_t", + "databricks_datatype": "varchar", + "is_valid": "false", + }, + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ], + schema=details_schema, + ) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + +@pytest.fixture +def mock_for_report_type_row(table_conf_with_opts, table_schema, mock_spark, query_store, setup_metadata_table): + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + table_recon = TableRecon( + source_catalog="org", + source_schema="data", + target_catalog="org", + target_schema="data", + tables=[table_conf_with_opts], + ) + src_schema, tgt_schema = table_schema + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.row_queries.source_row_query, + ): mock_spark.createDataFrame( + [ + Row( + hash_value_recon="a1b", + s_address="address-1", + s_name="name-1", + s_nationkey=11, + s_phone="111", + s_suppkey=1, + ), + Row( + hash_value_recon="c2d", + s_address="address-2", + s_name="name-2", + s_nationkey=22, + s_phone="222-2", + s_suppkey=2, + ), + Row( + hash_value_recon="e3g", + s_address="address-3", + s_name="name-3", + s_nationkey=33, + s_phone="333", + s_suppkey=3, + ), + ] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.source_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.row_queries.target_row_query, + ): mock_spark.createDataFrame( + [ + Row( + hash_value_recon="a1b", + s_address="address-1", + s_name="name-1", + s_nationkey=11, + s_phone="111", + s_suppkey=1, + ), + Row( + hash_value_recon="c2de", + s_address="address-2", + s_name="name-2", + s_nationkey=22, + s_phone="222", + s_suppkey=2, + ), + Row( + hash_value_recon="h4k", + s_address="address-4", + s_name="name-4", + s_nationkey=44, + s_phone="444", + s_suppkey=4, + ), + ] + ), + (CATALOG, SCHEMA, query_store.record_count_queries.target_record_count_query): mock_spark.createDataFrame( + [Row(count=3)] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + source = MockDataSource(source_dataframe_repository, source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + reconcile_config_row = ReconcileConfig( + data_source="snowflake", + report_type="row", + secret_scope="remorph_snowflake", + database_config=DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ), + metadata_config=ReconcileMetadataConfig(), + ) + + return source, target, table_recon, reconcile_config_row + + +def test_recon_for_report_type_is_row( + mock_workspace_client, + mock_spark, + mock_for_report_type_row, + report_tables_schema, + tmp_path: Path, +): + 
recon_schema, metrics_schema, details_schema = report_tables_schema + source, target, table_recon, reconcile_config_row = mock_for_report_type_row + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=33333 + ), + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + ): + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + with pytest.raises(ReconciliationException) as exc_info: + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config_row, local_test_run=True) + + if exc_info.value.reconcile_output is not None: + assert exc_info.value.reconcile_output.recon_id == "00112233-4455-6677-8899-aabbccddeeff" + + expected_remorph_recon = mock_spark.createDataFrame( + data=[ + ( + 33333, + "00112233-4455-6677-8899-aabbccddeeff", + "Snowflake", + ("org", "data", "supplier"), + ("org", "data", "target_supplier"), + "row", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[ + ( + 33333, + ((2, 2), None, None), + (False, "remorph", ""), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame( + data=[ + ( + 33333, + "missing_in_source", + False, + [ + { + 's_address': 'address-2', + 's_name': 'name-2', + 's_nationkey': '22', + 's_phone': '222', + 's_suppkey': '2', + }, + { + 's_address': 'address-4', + 's_name': 'name-4', + 's_nationkey': '44', + 's_phone': '444', + 's_suppkey': '4', + }, + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ( + 33333, + "missing_in_target", + False, + [ + { + 's_address': 'address-2', + 's_name': 'name-2', + 's_nationkey': '22', + 's_phone': '222-2', + 's_suppkey': '2', + }, + { + 's_address': 'address-3', + 's_name': 'name-3', + 's_nationkey': '33', + 's_phone': '333', + 's_suppkey': '3', + }, + ], + datetime(2024, 5, 23, 9, 21, 25, 122185), + ), + ], + schema=details_schema, + ) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + +@pytest.fixture +def mock_for_recon_exception(table_conf_with_opts, setup_metadata_table): + table_conf_with_opts.drop_columns = ["s_acctbal"] + table_conf_with_opts.column_thresholds = None + table_conf_with_opts.join_columns = None + table_recon = TableRecon( + source_catalog="org", + source_schema="data", + target_catalog="org", + target_schema="data", + tables=[table_conf_with_opts], + ) + source = MockDataSource({}, {}) + target = MockDataSource({}, {}) + reconcile_config_exception = ReconcileConfig( + data_source="snowflake", + report_type="all", + 
secret_scope="remorph_snowflake", + database_config=DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ), + metadata_config=ReconcileMetadataConfig(), + ) + + return table_recon, source, target, reconcile_config_exception + + +def test_schema_recon_with_data_source_exception( + mock_workspace_client, + mock_spark, + report_tables_schema, + mock_for_recon_exception, + tmp_path: Path, +): + recon_schema, metrics_schema, details_schema = report_tables_schema + table_recon, source, target, reconcile_config_exception = mock_for_recon_exception + reconcile_config_exception.report_type = "schema" + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=33333 + ), + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + pytest.raises(ReconciliationException, match="00112233-4455-6677-8899-aabbccddeeff"), + ): + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config_exception, local_test_run=True) + + expected_remorph_recon = mock_spark.createDataFrame( + data=[ + ( + 33333, + "00112233-4455-6677-8899-aabbccddeeff", + "Snowflake", + ("org", "data", "supplier"), + ("org", "data", "target_supplier"), + "schema", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[ + ( + 33333, + (None, None, None), + ( + False, + "remorph", + "Runtime exception occurred while fetching schema using (org, data, supplier) : Mock Exception", + ), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame(data=[], schema=details_schema) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + +def test_schema_recon_with_general_exception( + mock_workspace_client, + mock_spark, + report_tables_schema, + mock_for_report_type_schema, + tmp_path: Path, +): + recon_schema, metrics_schema, details_schema = report_tables_schema + table_recon, source, target, reconcile_config_schema = mock_for_report_type_schema + reconcile_config_schema.data_source = "snowflake" + reconcile_config_schema.secret_scope = "remorph_snowflake" + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + 
patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=33333 + ), + patch("databricks.labs.remorph.reconcile.execute.Reconciliation.reconcile_schema") as schema_source_mock, + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + pytest.raises(ReconciliationException, match="00112233-4455-6677-8899-aabbccddeeff"), + ): + schema_source_mock.side_effect = PySparkException("Unknown Error") + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config_schema, local_test_run=True) + + expected_remorph_recon = mock_spark.createDataFrame( + data=[ + ( + 33333, + "00112233-4455-6677-8899-aabbccddeeff", + "Snowflake", + ("org", "data", "supplier"), + ("org", "data", "target_supplier"), + "schema", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[ + ( + 33333, + (None, None, None), + ( + False, + "remorph", + "Unknown Error", + ), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame(data=[], schema=details_schema) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + +def test_data_recon_with_general_exception( + mock_workspace_client, + mock_spark, + report_tables_schema, + mock_for_report_type_schema, + tmp_path: Path, +): + recon_schema, metrics_schema, details_schema = report_tables_schema + table_recon, source, target, reconcile_config = mock_for_report_type_schema + reconcile_config.data_source = "snowflake" + reconcile_config.secret_scope = "remorph_snowflake" + reconcile_config.report_type = "data" + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=33333 + ), + patch("databricks.labs.remorph.reconcile.execute.Reconciliation.reconcile_data") as data_source_mock, + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + pytest.raises(ReconciliationException, match="00112233-4455-6677-8899-aabbccddeeff"), + ): + data_source_mock.side_effect = DataSourceRuntimeException("Unknown Error") + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config, local_test_run=True) + + expected_remorph_recon = 
mock_spark.createDataFrame( + data=[ + ( + 33333, + "00112233-4455-6677-8899-aabbccddeeff", + "Snowflake", + ("org", "data", "supplier"), + ("org", "data", "target_supplier"), + "data", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[ + ( + 33333, + (None, None, None), + ( + False, + "remorph", + "Unknown Error", + ), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame(data=[], schema=details_schema) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + +def test_data_recon_with_source_exception( + mock_workspace_client, + mock_spark, + report_tables_schema, + mock_for_report_type_schema, + tmp_path: Path, +): + recon_schema, metrics_schema, details_schema = report_tables_schema + table_recon, source, target, reconcile_config = mock_for_report_type_schema + reconcile_config.data_source = "snowflake" + reconcile_config.secret_scope = "remorph_snowflake" + reconcile_config.report_type = "data" + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=33333 + ), + patch("databricks.labs.remorph.reconcile.execute.Reconciliation.reconcile_data") as data_source_mock, + patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)), + pytest.raises(ReconciliationException, match="00112233-4455-6677-8899-aabbccddeeff"), + ): + data_source_mock.side_effect = DataSourceRuntimeException("Source Runtime Error") + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config, local_test_run=True) + + expected_remorph_recon = mock_spark.createDataFrame( + data=[ + ( + 33333, + "00112233-4455-6677-8899-aabbccddeeff", + "Snowflake", + ("org", "data", "supplier"), + ("org", "data", "target_supplier"), + "data", + "reconcile", + datetime(2024, 5, 23, 9, 21, 25, 122185), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=recon_schema, + ) + expected_remorph_recon_metrics = mock_spark.createDataFrame( + data=[ + ( + 33333, + (None, None, None), + ( + False, + "remorph", + "Source Runtime Error", + ), + datetime(2024, 5, 23, 9, 21, 25, 122185), + ) + ], + schema=metrics_schema, + ) + expected_remorph_recon_details = mock_spark.createDataFrame(data=[], schema=details_schema) + + assertDataFrameEqual(mock_spark.sql("SELECT * FROM DEFAULT.MAIN"), expected_remorph_recon, ignoreNullable=True) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.METRICS"), 
expected_remorph_recon_metrics, ignoreNullable=True + ) + assertDataFrameEqual( + mock_spark.sql("SELECT * FROM DEFAULT.DETAILS"), expected_remorph_recon_details, ignoreNullable=True + ) + + +def test_initialise_data_source(mock_workspace_client, mock_spark): + src_engine = get_dialect("snowflake") + secret_scope = "test" + + source, target = initialise_data_source(mock_workspace_client, mock_spark, src_engine, secret_scope) + + snowflake_data_source = SnowflakeDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ + databricks_data_source = DatabricksDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ + + assert isinstance(source, snowflake_data_source) + assert isinstance(target, databricks_data_source) + + +def test_recon_for_wrong_report_type(mock_workspace_client, mock_spark, mock_for_report_type_row, report_tables_schema): + source, target, table_recon, reconcile_config = mock_for_report_type_row + reconcile_config.report_type = "ro" + with ( + patch("databricks.labs.remorph.reconcile.execute.datetime") as mock_datetime, + patch("databricks.labs.remorph.reconcile.recon_capture.datetime") as recon_datetime, + patch("databricks.labs.remorph.reconcile.execute.initialise_data_source", return_value=(source, target)), + patch("databricks.labs.remorph.reconcile.execute.uuid4", return_value="00112233-4455-6677-8899-aabbccddeeff"), + patch( + "databricks.labs.remorph.reconcile.recon_capture.ReconCapture._generate_recon_main_id", return_value=33333 + ), + pytest.raises(InvalidInputException), + ): + mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185) + recon(mock_workspace_client, mock_spark, table_recon, reconcile_config, local_test_run=True) + + +def test_reconcile_data_with_threshold_and_row_report_type( + mock_spark, + table_conf_with_opts, + table_schema, + query_store, + tmp_path: Path, +): + src_schema, tgt_schema = table_schema + source_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.row_queries.source_row_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + ] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.source_threshold_query): mock_spark.createDataFrame( + [Row(s_nationkey=11, s_suppkey=1, s_acctbal=100)] + ), + } + source_schema_repository = {(CATALOG, SCHEMA, SRC_TABLE): src_schema} + + target_dataframe_repository = { + ( + CATALOG, + SCHEMA, + query_store.row_queries.target_row_query, + ): mock_spark.createDataFrame( + [ + Row(hash_value_recon="a1b", s_nationkey=11, s_suppkey=1), + Row(hash_value_recon="c2d", s_nationkey=22, s_suppkey=2), + ] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.target_threshold_query): mock_spark.createDataFrame( + [Row(s_nationkey=11, s_suppkey=1, s_acctbal=110)] + ), + (CATALOG, SCHEMA, query_store.threshold_queries.threshold_comparison_query): mock_spark.createDataFrame( + [ + Row( + s_acctbal_source=100, + s_acctbal_databricks=110, + s_acctbal_match="Warning", + s_nationkey_source=11, + s_suppkey_source=1, + ) + ] + ), + } + + target_schema_repository = {(CATALOG, SCHEMA, TGT_TABLE): tgt_schema} + database_config = DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ) + schema_comparator = SchemaCompare(mock_spark) + source = MockDataSource(source_dataframe_repository, 
source_schema_repository) + target = MockDataSource(target_dataframe_repository, target_schema_repository) + + with patch("databricks.labs.remorph.reconcile.execute.generate_volume_path", return_value=str(tmp_path)): + actual = Reconciliation( + source, + target, + database_config, + "row", + schema_comparator, + get_dialect("databricks"), + mock_spark, + ReconcileMetadataConfig(), + ).reconcile_data(table_conf_with_opts, src_schema, tgt_schema) + + assert actual.mismatch_count == 0 + assert actual.missing_in_src_count == 0 + assert actual.missing_in_tgt_count == 0 + assert actual.threshold_output.threshold_df is None + assert actual.threshold_output.threshold_mismatch_count == 0 + + +@patch('databricks.labs.remorph.reconcile.execute.generate_final_reconcile_output') +def test_recon_output_without_exception(mock_gen_final_recon_output): + mock_workspace_client = MagicMock() + mock_spark = MagicMock() + mock_table_recon = MagicMock() + mock_gen_final_recon_output.return_value = ReconcileOutput( + recon_id="00112233-4455-6677-8899-aabbccddeeff", + results=[ + ReconcileTableOutput( + target_table_name="supplier", + source_table_name="target_supplier", + status=StatusOutput( + row=True, + column=True, + schema=True, + ), + exception_message=None, + ) + ], + ) + reconcile_config = ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake", + database_config=DatabaseConfig( + source_catalog=CATALOG, + source_schema=SCHEMA, + target_catalog=CATALOG, + target_schema=SCHEMA, + ), + metadata_config=ReconcileMetadataConfig(), + ) + + try: + recon( + mock_workspace_client, + mock_spark, + mock_table_recon, + reconcile_config, + ) + except ReconciliationException as e: + msg = f"An exception {e} was raised when it should not have been" + pytest.fail(msg) + + +def test_generate_volume_path(table_conf_with_opts): + volume_path = generate_volume_path(table_conf_with_opts, ReconcileMetadataConfig()) + assert ( + volume_path + == f"/Volumes/remorph/reconcile/reconcile_volume/{table_conf_with_opts.source_name}_{table_conf_with_opts.target_name}/" + ) diff --git a/tests/unit/reconcile/test_recon_capture.py b/tests/unit/reconcile/test_recon_capture.py new file mode 100644 index 0000000000..ed98a7ebd5 --- /dev/null +++ b/tests/unit/reconcile/test_recon_capture.py @@ -0,0 +1,1013 @@ +from pathlib import Path +import datetime +import json + +import pytest +from pyspark.sql import Row, SparkSession +from pyspark.sql.functions import countDistinct +from pyspark.sql.types import BooleanType, StringType, StructField, StructType + +from databricks.labs.remorph.config import DatabaseConfig, get_dialect, ReconcileMetadataConfig +from databricks.labs.remorph.reconcile.exception import WriteToTableException, ReadAndWriteWithVolumeException +from databricks.labs.remorph.reconcile.recon_capture import ( + ReconCapture, + generate_final_reconcile_output, + ReconIntermediatePersist, +) +from databricks.labs.remorph.reconcile.recon_config import ( + DataReconcileOutput, + MismatchOutput, + ReconcileOutput, + ReconcileProcessDuration, + ReconcileTableOutput, + SchemaReconcileOutput, + StatusOutput, + Table, + ThresholdOutput, + ReconcileRecordCount, + TableThresholds, + TableThresholdBoundsException, +) + + +def data_prep(spark: SparkSession): + # Mismatch DataFrame + data = [ + Row(id=1, name_source='source1', name_target='target1', name_match='match1'), + Row(id=2, name_source='source2', name_target='target2', name_match='match2'), + ] + mismatch_df = spark.createDataFrame(data) + + # 
Missing DataFrames + data1 = [Row(id=1, name='name1'), Row(id=2, name='name2'), Row(id=3, name='name3')] + data2 = [Row(id=1, name='name1'), Row(id=2, name='name2'), Row(id=3, name='name3'), Row(id=4, name='name4')] + df1 = spark.createDataFrame(data1) + df2 = spark.createDataFrame(data2) + + # Schema Compare Dataframe + schema = StructType( + [ + StructField("source_column", StringType(), True), + StructField("source_datatype", StringType(), True), + StructField("databricks_column", StringType(), True), + StructField("databricks_datatype", StringType(), True), + StructField("is_valid", BooleanType(), True), + ] + ) + + data = [ + Row( + source_column="source_column1", + source_datatype="source_datatype1", + databricks_column="databricks_column1", + databricks_datatype="databricks_datatype1", + is_valid=True, + ), + Row( + source_column="source_column2", + source_datatype="source_datatype2", + databricks_column="databricks_column2", + databricks_datatype="databricks_datatype2", + is_valid=True, + ), + Row( + source_column="source_column3", + source_datatype="source_datatype3", + databricks_column="databricks_column3", + databricks_datatype="databricks_datatype3", + is_valid=True, + ), + Row( + source_column="source_column4", + source_datatype="source_datatype4", + databricks_column="databricks_column4", + databricks_datatype="databricks_datatype4", + is_valid=True, + ), + ] + + schema_df = spark.createDataFrame(data, schema) + + data_rows = [ + Row(id=1, sal_source=1000, sal_target=1100, sal_match=True), + Row(id=2, sal_source=2000, sal_target=2100, sal_match=False), + ] + threshold_df = spark.createDataFrame(data_rows) + + # Prepare output dataclasses + mismatch = MismatchOutput(mismatch_df=mismatch_df, mismatch_columns=["name"]) + threshold = ThresholdOutput(threshold_df, threshold_mismatch_count=2) + reconcile_output = DataReconcileOutput( + mismatch_count=2, + missing_in_src_count=3, + missing_in_tgt_count=4, + mismatch=mismatch, + missing_in_src=df1, + missing_in_tgt=df2, + threshold_output=threshold, + ) + schema_output = SchemaReconcileOutput(is_valid=True, compare_df=schema_df) + table_conf = Table(source_name="supplier", target_name="target_supplier") + reconcile_process = ReconcileProcessDuration( + start_ts=str(datetime.datetime.now()), end_ts=str(datetime.datetime.now()) + ) + + # Drop old data + spark.sql("DROP TABLE IF EXISTS DEFAULT.main") + spark.sql("DROP TABLE IF EXISTS DEFAULT.metrics") + spark.sql("DROP TABLE IF EXISTS DEFAULT.details") + + row_count = ReconcileRecordCount(source=5, target=5) + + return reconcile_output, schema_output, table_conf, reconcile_process, row_count + + +def test_recon_capture_start_snowflake_all(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert main + remorph_recon_df = spark.sql("select * from DEFAULT.main") + row 
= remorph_recon_df.collect()[0] + assert remorph_recon_df.count() == 1 + assert row.recon_id == "73b44582-dbb7-489f-bad1-6a7e8f4821b1" + assert row.source_table.catalog == "source_test_catalog" + assert row.source_table.schema == "source_test_schema" + assert row.source_table.table_name == "supplier" + assert row.target_table.catalog == "target_test_catalog" + assert row.target_table.schema == "target_test_schema" + assert row.target_table.table_name == "target_supplier" + assert row.report_type == "all" + assert row.source_type == "Snowflake" + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert remorph_recon_metrics_df.count() == 1 + assert row.recon_metrics.row_comparison.missing_in_source == 3 + assert row.recon_metrics.row_comparison.missing_in_target == 4 + assert row.recon_metrics.column_comparison.absolute_mismatch == 2 + assert row.recon_metrics.column_comparison.threshold_mismatch == 2 + assert row.recon_metrics.column_comparison.mismatch_columns == "name" + assert row.recon_metrics.schema_comparison is True + assert row.run_metrics.status is False + assert row.run_metrics.run_by_user == "remorph" + assert row.run_metrics.exception_message == "" + + # assert details + remorph_recon_details_df = spark.sql("select * from DEFAULT.details") + assert remorph_recon_details_df.count() == 5 + assert remorph_recon_details_df.select("recon_type").distinct().count() == 5 + assert ( + remorph_recon_details_df.select("recon_table_id", "status") + .groupby("recon_table_id") + .agg(countDistinct("status").alias("count_stat")) + .collect()[0] + .count_stat + == 2 + ) + assert json.dumps(remorph_recon_details_df.where("recon_type = 'mismatch'").select("data").collect()[0].data) == ( + "[{\"id\": \"1\", \"name_source\": \"source1\", \"name_target\": \"target1\", " + "\"name_match\": \"match1\"}, {\"id\": \"2\", \"name_source\": \"source2\", " + "\"name_target\": \"target2\", \"name_match\": \"match2\"}]" + ) + + rows = remorph_recon_details_df.orderBy("recon_type").collect() + assert rows[0].recon_type == "mismatch" + assert rows[0].status is False + assert rows[1].recon_type == "missing_in_source" + assert rows[1].status is False + assert rows[2].recon_type == "missing_in_target" + assert rows[2].status is False + assert rows[3].recon_type == "schema" + assert rows[3].status is True + assert rows[4].recon_type == "threshold_mismatch" + assert rows[4].status is False + + +def test_test_recon_capture_start_databricks_data(mock_workspace_client, mock_spark): + database_config = DatabaseConfig("source_test_schema", "target_test_catalog", "target_test_schema") + ws = mock_workspace_client + source_type = get_dialect("databricks") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "data", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + schema_output.compare_df = None + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert main + remorph_recon_df = spark.sql("select * from DEFAULT.main") + row = remorph_recon_df.collect()[0] + assert remorph_recon_df.count() == 1 + assert row.source_table.catalog is None + assert 
row.report_type == "data" + assert row.source_type == "Databricks" + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.recon_metrics.schema_comparison is None + assert row.run_metrics.status is False + + # assert details + remorph_recon_details_df = spark.sql("select * from DEFAULT.details") + assert remorph_recon_details_df.count() == 4 + assert remorph_recon_details_df.select("recon_type").distinct().count() == 4 + + +def test_test_recon_capture_start_databricks_row(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("databricks") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "row", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + reconcile_output.mismatch_count = 0 + reconcile_output.mismatch = MismatchOutput() + reconcile_output.threshold_output = ThresholdOutput() + schema_output.compare_df = None + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert main + remorph_recon_df = spark.sql("select * from DEFAULT.main") + row = remorph_recon_df.collect()[0] + assert remorph_recon_df.count() == 1 + assert row.report_type == "row" + assert row.source_type == "Databricks" + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.recon_metrics.column_comparison is None + assert row.recon_metrics.schema_comparison is None + assert row.run_metrics.status is False + + # assert details + remorph_recon_details_df = spark.sql("select * from DEFAULT.details") + assert remorph_recon_details_df.count() == 2 + assert remorph_recon_details_df.select("recon_type").distinct().count() == 2 + + +def test_recon_capture_start_oracle_schema(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("oracle") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "schema", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + reconcile_output.threshold_output = ThresholdOutput() + reconcile_output.mismatch_count = 0 + reconcile_output.mismatch = MismatchOutput() + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert main + remorph_recon_df = spark.sql("select * from DEFAULT.main") + row = remorph_recon_df.collect()[0] + assert remorph_recon_df.count() == 1 + assert row.report_type == "schema" + assert row.source_type == "Oracle" + 
+ # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.recon_metrics.row_comparison is None + assert row.recon_metrics.column_comparison is None + assert row.recon_metrics.schema_comparison is True + assert row.run_metrics.status is True + + # assert details + remorph_recon_details_df = spark.sql("select * from DEFAULT.details") + assert remorph_recon_details_df.count() == 1 + assert remorph_recon_details_df.select("recon_type").distinct().count() == 1 + + +def test_recon_capture_start_oracle_with_exception(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("oracle") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + reconcile_output.threshold_output = ThresholdOutput() + reconcile_output.mismatch_count = 0 + reconcile_output.mismatch = MismatchOutput() + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + reconcile_output.exception = "Test exception" + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert main + remorph_recon_df = spark.sql("select * from DEFAULT.main") + row = remorph_recon_df.collect()[0] + assert remorph_recon_df.count() == 1 + assert row.report_type == "all" + assert row.source_type == "Oracle" + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.recon_metrics.schema_comparison is None + assert row.run_metrics.status is False + assert row.run_metrics.exception_message == "Test exception" + + +def test_recon_capture_start_with_exception(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + with pytest.raises(WriteToTableException): + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + +def test_generate_final_reconcile_output_row(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", + "target_test_catalog", + "target_test_schema", + ) + ws = mock_workspace_client + source_type = get_dialect("databricks") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "row", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, 
row_count = data_prep(spark) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + final_output = generate_final_reconcile_output( + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + mock_spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + + assert final_output == ReconcileOutput( + recon_id='73b44582-dbb7-489f-bad1-6a7e8f4821b1', + results=[ + ReconcileTableOutput( + target_table_name='target_test_catalog.target_test_schema.target_supplier', + source_table_name='source_test_schema.supplier', + status=StatusOutput(row=False, column=None, schema=None), + exception_message='', + ) + ], + ) + + +def test_generate_final_reconcile_output_data(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", + "target_test_catalog", + "target_test_schema", + ) + ws = mock_workspace_client + source_type = get_dialect("databricks") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "data", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + final_output = generate_final_reconcile_output( + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + mock_spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + + assert final_output == ReconcileOutput( + recon_id='73b44582-dbb7-489f-bad1-6a7e8f4821b1', + results=[ + ReconcileTableOutput( + target_table_name='target_test_catalog.target_test_schema.target_supplier', + source_table_name='source_test_schema.supplier', + status=StatusOutput(row=False, column=False, schema=None), + exception_message='', + ) + ], + ) + + +def test_generate_final_reconcile_output_schema(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", + "target_test_catalog", + "target_test_schema", + ) + ws = mock_workspace_client + source_type = get_dialect("databricks") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "schema", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + final_output = generate_final_reconcile_output( + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + mock_spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + + assert final_output == ReconcileOutput( + recon_id='73b44582-dbb7-489f-bad1-6a7e8f4821b1', + results=[ + ReconcileTableOutput( + target_table_name='target_test_catalog.target_test_schema.target_supplier', + source_table_name='source_test_schema.supplier', + status=StatusOutput(row=None, column=None, schema=True), + exception_message='', + ) + ], + ) + + +def 
test_generate_final_reconcile_output_all(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", + "target_test_catalog", + "target_test_schema", + ) + ws = mock_workspace_client + source_type = get_dialect("databricks") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + final_output = generate_final_reconcile_output( + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + mock_spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + + assert final_output == ReconcileOutput( + recon_id='73b44582-dbb7-489f-bad1-6a7e8f4821b1', + results=[ + ReconcileTableOutput( + target_table_name='target_test_catalog.target_test_schema.target_supplier', + source_table_name='source_test_schema.supplier', + status=StatusOutput(row=False, column=False, schema=True), + exception_message='', + ) + ], + ) + + +def test_generate_final_reconcile_output_exception(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", + "target_test_catalog", + "target_test_schema", + ) + ws = mock_workspace_client + source_type = get_dialect("databricks") + spark = mock_spark + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + reconcile_output.exception = "Test exception" + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + final_output = generate_final_reconcile_output( + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + mock_spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + + assert final_output == ReconcileOutput( + recon_id='73b44582-dbb7-489f-bad1-6a7e8f4821b1', + results=[ + ReconcileTableOutput( + target_table_name='target_test_catalog.target_test_schema.target_supplier', + source_table_name='source_test_schema.supplier', + status=StatusOutput(row=None, column=None, schema=None), + exception_message='Test exception', + ) + ], + ) + + +def test_write_and_read_unmatched_df_with_volumes_with_exception(tmp_path: Path, mock_spark, mock_workspace_client): + data = [Row(id=1, name='John', sal=5000), Row(id=2, name='Jane', sal=6000), Row(id=3, name='Doe', sal=7000)] + df = mock_spark.createDataFrame(data) + + path = str(tmp_path) + df = ReconIntermediatePersist(mock_spark, path).write_and_read_unmatched_df_with_volumes(df) + assert df.count() == 3 + + path = "/path/that/does/not/exist" + with pytest.raises(ReadAndWriteWithVolumeException): + ReconIntermediatePersist(mock_spark, path).write_and_read_unmatched_df_with_volumes(df) + + +def test_clean_unmatched_df_from_volume_with_exception(mock_spark): + path = "/path/that/does/not/exist" + with pytest.raises(Exception): + 
ReconIntermediatePersist(mock_spark, path).clean_unmatched_df_from_volume() + + +def test_apply_threshold_for_mismatch_with_true_absolute(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + reconcile_output.missing_in_src = None + reconcile_output.missing_in_tgt = None + table_conf.table_thresholds = [ + TableThresholds(lower_bound="0", upper_bound="4", model="mismatch"), + ] + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.run_metrics.status is True + + +def test_apply_threshold_for_mismatch_with_missing(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + table_conf.table_thresholds = [ + TableThresholds(lower_bound="0", upper_bound="4", model="mismatch"), + ] + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.run_metrics.status is False + + +def test_apply_threshold_for_mismatch_with_schema_fail(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + table_conf.table_thresholds = [ + TableThresholds(lower_bound="0", upper_bound="4", model="mismatch"), + ] + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + schema_output = SchemaReconcileOutput(is_valid=False, compare_df=None) + + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + 
recon_process_duration=reconcile_process, + record_count=row_count, + ) + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.run_metrics.status is False + + +def test_apply_threshold_for_mismatch_with_wrong_absolute_bound(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + table_conf.table_thresholds = [ + TableThresholds(lower_bound="0", upper_bound="1", model="mismatch"), + ] + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + reconcile_output.threshold_output = ThresholdOutput() + reconcile_output.missing_in_src = None + reconcile_output.missing_in_tgt = None + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.run_metrics.status is False + + +def test_apply_threshold_for_mismatch_with_wrong_percentage_bound(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + table_conf.table_thresholds = [ + TableThresholds(lower_bound="0%", upper_bound="20%", model="mismatch"), + ] + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + reconcile_output.threshold_output = ThresholdOutput() + reconcile_output.missing_in_src = None + reconcile_output.missing_in_tgt = None + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.run_metrics.status is False + + +def test_apply_threshold_for_mismatch_with_true_percentage_bound(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + table_conf.table_thresholds = [ + TableThresholds(lower_bound="0%", upper_bound="90%", model="mismatch"), + ] + reconcile_output.missing_in_src_count 
= 0 + reconcile_output.missing_in_tgt_count = 0 + reconcile_output.missing_in_src = None + reconcile_output.missing_in_tgt = None + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert metrics + remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.run_metrics.status is True + + +def test_apply_threshold_for_mismatch_with_invalid_bounds(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + reconcile_output.threshold_output = ThresholdOutput() + reconcile_output.missing_in_src = None + reconcile_output.missing_in_tgt = None + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + with pytest.raises(TableThresholdBoundsException): + table_conf.table_thresholds = [ + TableThresholds(lower_bound="-0%", upper_bound="-40%", model="mismatch"), + ] + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + with pytest.raises(TableThresholdBoundsException): + table_conf.table_thresholds = [ + TableThresholds(lower_bound="10%", upper_bound="5%", model="mismatch"), + ] + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + +def test_apply_threshold_for_only_threshold_mismatch_with_true_absolute(mock_workspace_client, mock_spark): + database_config = DatabaseConfig( + "source_test_schema", "target_test_catalog", "target_test_schema", "source_test_catalog" + ) + ws = mock_workspace_client + source_type = get_dialect("snowflake") + spark = mock_spark + reconcile_output, schema_output, table_conf, reconcile_process, row_count = data_prep(spark) + reconcile_output.mismatch_count = 0 + reconcile_output.missing_in_src_count = 0 + reconcile_output.missing_in_tgt_count = 0 + reconcile_output.missing_in_src = None + reconcile_output.missing_in_tgt = None + table_conf.table_thresholds = [ + TableThresholds(lower_bound="0", upper_bound="2", model="mismatch"), + ] + recon_capture = ReconCapture( + database_config, + "73b44582-dbb7-489f-bad1-6a7e8f4821b1", + "all", + source_type, + ws, + spark, + metadata_config=ReconcileMetadataConfig(schema="default"), + local_test_run=True, + ) + recon_capture.start( + data_reconcile_output=reconcile_output, + schema_reconcile_output=schema_output, + table_conf=table_conf, + recon_process_duration=reconcile_process, + record_count=row_count, + ) + + # assert metrics + 
remorph_recon_metrics_df = spark.sql("select * from DEFAULT.metrics") + row = remorph_recon_metrics_df.collect()[0] + assert row.run_metrics.status is True diff --git a/tests/unit/reconcile/test_recon_config.py b/tests/unit/reconcile/test_recon_config.py new file mode 100644 index 0000000000..959685d73f --- /dev/null +++ b/tests/unit/reconcile/test_recon_config.py @@ -0,0 +1,38 @@ +def test_table_without_join_column(table_conf_mock): + table_conf = table_conf_mock() + assert table_conf.get_join_columns("source") is None + assert table_conf.get_drop_columns("source") == set() + assert table_conf.get_partition_column("source") == set() + assert table_conf.get_partition_column("target") == set() + assert table_conf.get_filter("source") is None + assert table_conf.get_filter("target") is None + assert table_conf.get_threshold_columns("source") == set() + + +def test_table_with_all_options(table_conf_with_opts): + ## layer == source + + assert table_conf_with_opts.get_join_columns("source") == {"s_nationkey", "s_suppkey"} + assert table_conf_with_opts.get_drop_columns("source") == {"s_comment"} + assert table_conf_with_opts.get_partition_column("source") == {"s_nationkey"} + assert table_conf_with_opts.get_partition_column("target") == set() + assert table_conf_with_opts.get_filter("source") == "s_name='t' and s_address='a'" + assert table_conf_with_opts.get_threshold_columns("source") == {"s_acctbal"} + + ## layer == target + assert table_conf_with_opts.get_join_columns("target") == {"s_nationkey_t", "s_suppkey_t"} + assert table_conf_with_opts.get_drop_columns("target") == {"s_comment_t"} + assert table_conf_with_opts.get_partition_column("target") == set() + assert table_conf_with_opts.get_filter("target") == "s_name='t' and s_address_t='a'" + assert table_conf_with_opts.get_threshold_columns("target") == {"s_acctbal_t"} + + +def test_table_without_column_mapping(table_conf_mock, column_mapping): + table_conf = table_conf_mock() + + assert table_conf.get_tgt_to_src_col_mapping_list(["s_address", "s_name"]) == {"s_address", "s_name"} + assert table_conf.get_layer_tgt_to_src_col_mapping("s_address_t", "target") == "s_address_t" + assert table_conf.get_layer_tgt_to_src_col_mapping("s_address", "source") == "s_address" + assert table_conf.get_src_to_tgt_col_mapping_list(["s_address", "s_name"], "source") == {"s_address", "s_name"} + assert table_conf.get_src_to_tgt_col_mapping_list(["s_address", "s_name"], "target") == {"s_address", "s_name"} + assert table_conf.get_layer_src_to_tgt_col_mapping("s_address", "source") == "s_address" diff --git a/tests/unit/reconcile/test_runner.py b/tests/unit/reconcile/test_runner.py new file mode 100644 index 0000000000..92a7d23066 --- /dev/null +++ b/tests/unit/reconcile/test_runner.py @@ -0,0 +1,454 @@ +from unittest.mock import create_autospec, Mock +import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.tui import MockPrompts +from databricks.sdk import WorkspaceClient +from databricks.labs.remorph.reconcile.runner import ReconcileRunner +from databricks.labs.remorph.deployment.recon import RECON_JOB_NAME + + +def test_run_with_missing_recon_config(): + ws = create_autospec(WorkspaceClient) + installation = MockInstallation() + install_state = InstallState.from_installation(installation) + prompts = MockPrompts({}) + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + with pytest.raises(SystemExit): + 
recon_runner.run() + + +def test_run_with_corrupt_recon_config(): + ws = create_autospec(WorkspaceClient) + prompts = MockPrompts({}) + installation = MockInstallation( + { + "reconcile.yml": { + "source": "oracle", # Invalid key + "report_type": "all", + "secret_scope": "remorph_oracle2", + "database_config": { + "source_schema": "tpch_sf10002", + "target_catalog": "tpch2", + "target_schema": "1000gb2", + }, + "metadata_config": { + "catalog": "remorph2", + "schema": "reconcile2", + "volume": "reconcile_volume2", + }, + "version": 1, + } + } + ) + install_state = InstallState.from_installation(installation) + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + with pytest.raises(SystemExit): + recon_runner.run() + + +def test_run_with_missing_table_config(): + ws = create_autospec(WorkspaceClient) + installation = MockInstallation( + { + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "abc", + "source_schema": "def", + "target_catalog": "tgt", + "target_schema": "sch", + }, + "report_type": "all", + "secret_scope": "remorph", + "tables": { + "filter_type": "all", + "tables_list": ["*"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "job_id": "45t34wer32", + "version": 1, + } + } + ) + install_state = InstallState.from_installation(installation) + prompts = MockPrompts({}) + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + with pytest.raises(SystemExit): + recon_runner.run() + + +def test_run_with_corrupt_table_config(): + ws = create_autospec(WorkspaceClient) + installation = MockInstallation( + { + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "abc", + "source_schema": "def", + "target_catalog": "tgt", + "target_schema": "sch", + }, + "report_type": "all", + "secret_scope": "remorph", + "tables": { + "filter_type": "all", + "tables_list": ["*"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "job_id": "45t34wer32", + "version": 1, + }, + "recon_config_snowflake_abc_all.json": { + "source_catalog": "abc", + "source": "def", # Invalid key + "tables": [ + { + "column_mapping": [ + {"source_name": "p_id", "target_name": "product_id"}, + {"source_name": "p_name", "target_name": "product_name"}, + ], + "join_columns": ["p_id"], + "select_columns": ["p_id", "p_name"], + "source_name": "product", + "target_name": "product_delta", + } + ], + "target_catalog": "tgt", + "target_schema": "sch", + }, + } + ) + install_state = InstallState.from_installation(installation) + prompts = MockPrompts({}) + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + with pytest.raises(SystemExit): + recon_runner.run() + + +def test_run_with_missing_job_id(): + ws = create_autospec(WorkspaceClient) + installation = MockInstallation( + { + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "abc", + "source_schema": "def", + "target_catalog": "tgt", + "target_schema": "sch", + }, + "report_type": "all", + "secret_scope": "remorph", + "tables": { + "filter_type": "all", + "tables_list": ["*"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + "recon_config_snowflake_abc_all.json": { + "source_catalog": "abc", + "source_schema": "def", + "tables": [ + { + "column_mapping": [ + {"source_name": "p_id", 
"target_name": "product_id"}, + {"source_name": "p_name", "target_name": "product_name"}, + ], + "join_columns": ["p_id"], + "select_columns": ["p_id", "p_name"], + "source_name": "product", + "target_name": "product_delta", + } + ], + "target_catalog": "tgt", + "target_schema": "sch", + }, + } + ) + install_state = InstallState.from_installation(installation) + prompts = MockPrompts({}) + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + with pytest.raises(SystemExit): + recon_runner.run() + + +def test_run_with_job_id_in_config(): + ws = create_autospec(WorkspaceClient) + prompts = MockPrompts( + { + r"Would you like to open the job run URL .*": "no", + } + ) + installation = MockInstallation( + { + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "abc", + "source_schema": "def", + "target_catalog": "tgt", + "target_schema": "sch", + }, + "report_type": "all", + "secret_scope": "remorph", + "tables": { + "filter_type": "all", + "tables_list": ["*"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "job_id": "1234", + "version": 1, + }, + "recon_config_snowflake_abc_all.json": { + "source_catalog": "abc", + "source_schema": "def", + "tables": [ + { + "column_mapping": [ + {"source_name": "p_id", "target_name": "product_id"}, + {"source_name": "p_name", "target_name": "product_name"}, + ], + "join_columns": ["p_id"], + "select_columns": ["p_id", "p_name"], + "source_name": "product", + "target_name": "product_delta", + } + ], + "target_catalog": "tgt", + "target_schema": "sch", + }, + } + ) + install_state = InstallState.from_installation(installation) + wait = Mock() + wait.run_id = "rid" + ws.jobs.run_now.return_value = wait + + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + recon_runner.run() + ws.jobs.run_now.assert_called_once_with(1234, job_parameters={'operation_name': 'reconcile'}) + + +def test_run_with_job_id_in_state(monkeypatch): + monkeypatch.setattr("webbrowser.open", lambda url: None) + ws = create_autospec(WorkspaceClient) + prompts = MockPrompts( + { + r"Would you like to open the job run URL .*": "yes", + } + ) + installation = MockInstallation( + { + "state.json": { + "resources": {"jobs": {RECON_JOB_NAME: "1234"}}, + "version": 1, + }, + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "abc", + "source_schema": "def", + "target_catalog": "tgt", + "target_schema": "sch", + }, + "report_type": "all", + "secret_scope": "remorph", + "tables": { + "filter_type": "all", + "tables_list": ["*"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + "recon_config_snowflake_abc_all.json": { + "source_catalog": "abc", + "source_schema": "def", + "tables": [ + { + "column_mapping": [ + {"source_name": "p_id", "target_name": "product_id"}, + {"source_name": "p_name", "target_name": "product_name"}, + ], + "join_columns": ["p_id"], + "select_columns": ["p_id", "p_name"], + "source_name": "product", + "target_name": "product_delta", + } + ], + "target_catalog": "tgt", + "target_schema": "sch", + }, + } + ) + install_state = InstallState.from_installation(installation) + wait = Mock() + wait.run_id = "rid" + ws.jobs.run_now.return_value = wait + + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + recon_runner.run() + ws.jobs.run_now.assert_called_once_with(1234, 
job_parameters={'operation_name': 'reconcile'}) + + +def test_run_with_failed_execution(): + ws = create_autospec(WorkspaceClient) + installation = MockInstallation( + { + "state.json": { + "resources": {"jobs": {RECON_JOB_NAME: "1234"}}, + "version": 1, + }, + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "abc", + "source_schema": "def", + "target_catalog": "tgt", + "target_schema": "sch", + }, + "report_type": "all", + "secret_scope": "remorph", + "tables": { + "filter_type": "all", + "tables_list": ["*"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + "recon_config_snowflake_abc_all.json": { + "source_catalog": "abc", + "source_schema": "def", + "tables": [ + { + "column_mapping": [ + {"source_name": "p_id", "target_name": "product_id"}, + {"source_name": "p_name", "target_name": "product_name"}, + ], + "join_columns": ["p_id"], + "select_columns": ["p_id", "p_name"], + "source_name": "product", + "target_name": "product_delta", + } + ], + "target_catalog": "tgt", + "target_schema": "sch", + }, + } + ) + install_state = InstallState.from_installation(installation) + prompts = MockPrompts({}) + wait = Mock() + wait.run_id = None + ws.jobs.run_now.return_value = wait + + recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + with pytest.raises(SystemExit): + recon_runner.run() + ws.jobs.run_now.assert_called_once_with(1234, job_parameters={'operation_name': 'reconcile'}) + + +def test_aggregates_reconcile_run_with_job_id_in_state(monkeypatch): + monkeypatch.setattr("webbrowser.open", lambda url: None) + ws = create_autospec(WorkspaceClient) + prompts = MockPrompts( + { + r"Would you like to open the job run URL .*": "yes", + } + ) + state = { + "resources": {"jobs": {RECON_JOB_NAME: "1234"}}, + "version": 1, + } + + reconcile = { + "data_source": "snowflake", + "database_config": { + "source_catalog": "abc", + "source_schema": "def", + "target_catalog": "tgt", + "target_schema": "sch", + }, + "report_type": "all", + "secret_scope": "remorph", + "tables": { + "filter_type": "all", + "tables_list": ["*"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + } + + sf_recon_config = { + "source_catalog": "abc", + "source_schema": "def", + "tables": [ + { + "aggregates": [ + {"type": "MIN", "agg_columns": ["discount"], "group_by_columns": ["p_id"]}, + {"type": "AVG", "agg_columns": ["discount"], "group_by_columns": ["p_id"]}, + {"type": "MAX", "agg_columns": ["p_id"], "group_by_columns": ["creation_date"]}, + {"type": "MAX", "agg_columns": ["p_name"]}, + {"type": "SUM", "agg_columns": ["p_id"]}, + {"type": "MAX", "agg_columns": ["creation_date"]}, + {"type": "MAX", "agg_columns": ["p_id"], "group_by_columns": ["creation_date"]}, + ], + "column_mapping": [ + {"source_name": "p_id", "target_name": "product_id"}, + {"source_name": "p_name", "target_name": "product_name"}, + ], + "join_columns": ["p_id"], + "select_columns": ["p_id", "p_name"], + "source_name": "product", + "target_name": "product_delta", + } + ], + "target_catalog": "tgt", + "target_schema": "sch", + } + + installation = MockInstallation( + { + "state.json": state, + "reconcile.yml": reconcile, + "recon_config_snowflake_abc_all.json": sf_recon_config, + } + ) + install_state = InstallState.from_installation(installation) + wait = Mock() + wait.run_id = "rid" + ws.jobs.run_now.return_value = wait + + 
recon_runner = ReconcileRunner(ws, installation, install_state, prompts) + recon_runner.run(operation_name="aggregates-reconcile") + ws.jobs.run_now.assert_called_once_with(1234, job_parameters={'operation_name': 'aggregates-reconcile'}) diff --git a/tests/unit/reconcile/test_schema_compare.py b/tests/unit/reconcile/test_schema_compare.py new file mode 100644 index 0000000000..4aa78bb5ab --- /dev/null +++ b/tests/unit/reconcile/test_schema_compare.py @@ -0,0 +1,287 @@ +import pytest + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.recon_config import ColumnMapping, Schema, Table +from databricks.labs.remorph.reconcile.schema_compare import SchemaCompare + + +def snowflake_databricks_schema(): + src_schema = [ + Schema("col_boolean", "boolean"), + Schema("col_char", "varchar(1)"), + Schema("col_varchar", "varchar(16777216)"), + Schema("col_string", "varchar(16777216)"), + Schema("col_text", "varchar(16777216)"), + Schema("col_binary", "binary(8388608)"), + Schema("col_varbinary", "binary(8388608)"), + Schema("col_int", "number(38,0)"), + Schema("col_bigint", "number(38,0)"), + Schema("col_smallint", "number(38,0)"), + Schema("col_float", "float"), + Schema("col_float4", "float"), + Schema("col_double", "float"), + Schema("col_real", "float"), + Schema("col_date", "date"), + Schema("col_time", "time(9)"), + Schema("col_timestamp", "timestamp_ntz(9)"), + Schema("col_timestamp_ltz", "timestamp_ltz(9)"), + Schema("col_timestamp_ntz", "timestamp_ntz(9)"), + Schema("col_timestamp_tz", "timestamp_tz(9)"), + Schema("col_variant", "variant"), + Schema("col_object", "object"), + Schema("col_array", "array"), + Schema("col_geography", "geography"), + Schema("col_num10", "number(10,1)"), + Schema("col_dec", "number(20,2)"), + Schema("col_numeric_2", "numeric(38,0)"), + Schema("dummy", "string"), + ] + tgt_schema = [ + Schema("col_boolean", "boolean"), + Schema("char", "string"), + Schema("col_varchar", "string"), + Schema("col_string", "string"), + Schema("col_text", "string"), + Schema("col_binary", "binary"), + Schema("col_varbinary", "binary"), + Schema("col_int", "decimal(38,0)"), + Schema("col_bigint", "decimal(38,0)"), + Schema("col_smallint", "decimal(38,0)"), + Schema("col_float", "double"), + Schema("col_float4", "double"), + Schema("col_double", "double"), + Schema("col_real", "double"), + Schema("col_date", "date"), + Schema("col_time", "timestamp"), + Schema("col_timestamp", "timestamp_ntz"), + Schema("col_timestamp_ltz", "timestamp"), + Schema("col_timestamp_ntz", "timestamp_ntz"), + Schema("col_timestamp_tz", "timestamp"), + Schema("col_variant", "variant"), + Schema("col_object", "string"), + Schema("array_col", "array"), + Schema("col_geography", "string"), + Schema("col_num10", "decimal(10,1)"), + Schema("col_dec", "decimal(20,1)"), + Schema("col_numeric_2", "decimal(38,0)"), + ] + return src_schema, tgt_schema + + +def databricks_databricks_schema(): + src_schema = [ + Schema("col_boolean", "boolean"), + Schema("col_char", "string"), + Schema("col_int", "int"), + Schema("col_string", "string"), + Schema("col_bigint", "int"), + Schema("col_num10", "decimal(10,1)"), + Schema("col_dec", "decimal(20,2)"), + Schema("col_numeric_2", "decimal(38,0)"), + Schema("dummy", "string"), + ] + tgt_schema = [ + Schema("col_boolean", "boolean"), + Schema("char", "string"), + Schema("col_int", "int"), + Schema("col_string", "string"), + Schema("col_bigint", "int"), + Schema("col_num10", "decimal(10,1)"), + Schema("col_dec", "decimal(20,1)"), + 
Schema("col_numeric_2", "decimal(38,0)"), + ] + return src_schema, tgt_schema + + +def oracle_databricks_schema(): + src_schema = [ + Schema("col_xmltype", "xmltype"), + Schema("col_char", "char(1)"), + Schema("col_nchar", "nchar(255)"), + Schema("col_varchar", "varchar2(255)"), + Schema("col_varchar2", "varchar2(255)"), + Schema("col_nvarchar", "nvarchar2(255)"), + Schema("col_nvarchar2", "nvarchar2(255)"), + Schema("col_character", "char(255)"), + Schema("col_clob", "clob"), + Schema("col_nclob", "nclob"), + Schema("col_long", "long"), + Schema("col_number", "number(10,2)"), + Schema("col_float", "float"), + Schema("col_binary_float", "binary_float"), + Schema("col_binary_double", "binary_double"), + Schema("col_date", "date"), + Schema("col_timestamp", "timestamp(6)"), + Schema("col_time_with_tz", "timestamp(6) with time zone"), + Schema("col_timestamp_with_tz", "timestamp(6) with time zone"), + Schema("col_timestamp_with_local_tz", "timestamp(6) with local time zone"), + Schema("col_blob", "blob"), + Schema("col_rowid", "rowid"), + Schema("col_urowid", "urowid"), + Schema("col_anytype", "anytype"), + Schema("col_anydata", "anydata"), + Schema("col_anydataset", "anydataset"), + Schema("dummy", "string"), + ] + + tgt_schema = [ + Schema("col_xmltype", "string"), + Schema("char", "string"), + Schema("col_nchar", "string"), + Schema("col_varchar", "string"), + Schema("col_varchar2", "string"), + Schema("col_nvarchar", "string"), + Schema("col_nvarchar2", "string"), + Schema("col_character", "string"), + Schema("col_clob", "string"), + Schema("col_nclob", "string"), + Schema("col_long", "string"), + Schema("col_number", "DECIMAL(10,2)"), + Schema("col_float", "double"), + Schema("col_binary_float", "double"), + Schema("col_binary_double", "double"), + Schema("col_date", "date"), + Schema("col_timestamp", "timestamp"), + Schema("col_time_with_tz", "timestamp"), + Schema("col_timestamp_with_tz", "timestamp"), + Schema("col_timestamp_with_local_tz", "timestamp"), + Schema("col_blob", "binary"), + Schema("col_rowid", "string"), + Schema("col_urowid", "string"), + Schema("col_anytype", "string"), + Schema("col_anydata", "string"), + Schema("col_anydataset", "string"), + ] + + return src_schema, tgt_schema + + +@pytest.fixture +def schemas(): + return { + "snowflake_databricks_schema": snowflake_databricks_schema(), + "databricks_databricks_schema": databricks_databricks_schema(), + "oracle_databricks_schema": oracle_databricks_schema(), + } + + +def test_snowflake_schema_compare(schemas, mock_spark): + src_schema, tgt_schema = schemas["snowflake_databricks_schema"] + spark = mock_spark + table_conf = Table( + source_name="supplier", + target_name="supplier", + drop_columns=["dummy"], + column_mapping=[ + ColumnMapping(source_name="col_char", target_name="char"), + ColumnMapping(source_name="col_array", target_name="array_col"), + ], + ) + + schema_compare_output = SchemaCompare(spark).compare( + src_schema, + tgt_schema, + get_dialect("snowflake"), + table_conf, + ) + df = schema_compare_output.compare_df + + assert not schema_compare_output.is_valid + assert df.count() == 27 + assert df.filter("is_valid = 'true'").count() == 25 + assert df.filter("is_valid = 'false'").count() == 2 + + +def test_databricks_schema_compare(schemas, mock_spark): + src_schema, tgt_schema = schemas["databricks_databricks_schema"] + spark = mock_spark + table_conf = Table( + source_name="supplier", + target_name="supplier", + select_columns=[ + "col_boolean", + "col_char", + "col_int", + "col_string", + "col_bigint", 
+ "col_num10", + "col_dec", + "col_numeric_2", + ], + column_mapping=[ + ColumnMapping(source_name="col_char", target_name="char"), + ColumnMapping(source_name="col_array", target_name="array_col"), + ], + ) + schema_compare_output = SchemaCompare(spark).compare( + src_schema, + tgt_schema, + get_dialect("databricks"), + table_conf, + ) + df = schema_compare_output.compare_df + + assert not schema_compare_output.is_valid + assert df.count() == 8 + assert df.filter("is_valid = 'true'").count() == 7 + assert df.filter("is_valid = 'false'").count() == 1 + + +def test_oracle_schema_compare(schemas, mock_spark): + src_schema, tgt_schema = schemas["oracle_databricks_schema"] + spark = mock_spark + table_conf = Table( + source_name="supplier", + target_name="supplier", + drop_columns=["dummy"], + column_mapping=[ + ColumnMapping(source_name="col_char", target_name="char"), + ColumnMapping(source_name="col_array", target_name="array_col"), + ], + ) + schema_compare_output = SchemaCompare(spark).compare( + src_schema, + tgt_schema, + get_dialect("oracle"), + table_conf, + ) + df = schema_compare_output.compare_df + + assert schema_compare_output.is_valid + assert df.count() == 26 + assert df.filter("is_valid = 'true'").count() == 26 + assert df.filter("is_valid = 'false'").count() == 0 + + +def test_schema_compare(mock_spark): + src_schema = [ + Schema("col1", "int"), + Schema("col2", "string"), + ] + tgt_schema = [ + Schema("col1", "int"), + Schema("col2", "string"), + ] + spark = mock_spark + table_conf = Table( + source_name="supplier", + target_name="supplier", + drop_columns=["dummy"], + column_mapping=[ + ColumnMapping(source_name="col_char", target_name="char"), + ColumnMapping(source_name="col_array", target_name="array_col"), + ], + ) + + schema_compare_output = SchemaCompare(spark).compare( + src_schema, + tgt_schema, + get_dialect("databricks"), + table_conf, + ) + df = schema_compare_output.compare_df + + assert schema_compare_output.is_valid + assert df.count() == 2 + assert df.filter("is_valid = 'true'").count() == 2 + assert df.filter("is_valid = 'false'").count() == 0 diff --git a/tests/unit/reconcile/test_source_adapter.py b/tests/unit/reconcile/test_source_adapter.py new file mode 100644 index 0000000000..68d715d424 --- /dev/null +++ b/tests/unit/reconcile/test_source_adapter.py @@ -0,0 +1,57 @@ +from unittest.mock import create_autospec + +import pytest + +from databricks.connect import DatabricksSession +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource +from databricks.labs.remorph.reconcile.connectors.oracle import OracleDataSource +from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource +from databricks.labs.remorph.reconcile.connectors.source_adapter import create_adapter +from databricks.sdk import WorkspaceClient + + +def test_create_adapter_for_snowflake_dialect(): + spark = create_autospec(DatabricksSession) + engine = get_dialect("snowflake") + ws = create_autospec(WorkspaceClient) + scope = "scope" + + data_source = create_adapter(engine, spark, ws, scope) + snowflake_data_source = SnowflakeDataSource(engine, spark, ws, scope).__class__ + + assert isinstance(data_source, snowflake_data_source) + + +def test_create_adapter_for_oracle_dialect(): + spark = create_autospec(DatabricksSession) + engine = get_dialect("oracle") + ws = create_autospec(WorkspaceClient) + scope = "scope" + + data_source = create_adapter(engine, spark, ws, scope) + 
oracle_data_source = OracleDataSource(engine, spark, ws, scope).__class__ + + assert isinstance(data_source, oracle_data_source) + + +def test_create_adapter_for_databricks_dialect(): + spark = create_autospec(DatabricksSession) + engine = get_dialect("databricks") + ws = create_autospec(WorkspaceClient) + scope = "scope" + + data_source = create_adapter(engine, spark, ws, scope) + databricks_data_source = DatabricksDataSource(engine, spark, ws, scope).__class__ + + assert isinstance(data_source, databricks_data_source) + + +def test_raise_exception_for_unknown_dialect(): + spark = create_autospec(DatabricksSession) + engine = get_dialect("trino") + ws = create_autospec(WorkspaceClient) + scope = "scope" + + with pytest.raises(ValueError, match=f"Unsupported source type --> {engine}"): + create_adapter(engine, spark, ws, scope) diff --git a/tests/unit/reconcile/test_table.py b/tests/unit/reconcile/test_table.py new file mode 100644 index 0000000000..4316e98eec --- /dev/null +++ b/tests/unit/reconcile/test_table.py @@ -0,0 +1,21 @@ +from databricks.labs.remorph.reconcile.recon_config import Filters + + +def test_table_column_mapping(table_conf_mock): + table_conf = table_conf_mock( + join_columns=["s_suppkey", "s_nationkey"], + filters=Filters(source="s_nationkey=1"), + ) + + assert table_conf.to_src_col_map is None + assert table_conf.to_src_col_map is None + + +def test_table_select_columns(table_conf_mock, table_schema): + schema, _ = table_schema + table_conf = table_conf_mock( + select_columns=["s_nationkey", "s_suppkey"], + ) + + assert table_conf.get_select_columns(schema, "source") == {"s_nationkey", "s_suppkey"} + assert len(table_conf.get_select_columns(schema, "source")) == 2 diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py new file mode 100644 index 0000000000..5efbf63a20 --- /dev/null +++ b/tests/unit/test_cli.py @@ -0,0 +1,545 @@ +import datetime +import io +from unittest.mock import create_autospec, patch + +import pytest +import yaml + +from databricks.labs.blueprint.tui import MockPrompts +from databricks.labs.remorph import cli +from databricks.labs.remorph.config import TranspileConfig +from databricks.labs.remorph.helpers.recon_config_utils import ReconConfigPrompts +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound +from databricks.labs.blueprint.installation import MockInstallation +from databricks.sdk.config import Config + + +@pytest.fixture +def mock_workspace_client_cli(): + state = { + "/Users/foo/.remorph/config.yml": yaml.dump( + { + 'version': 1, + 'catalog_name': 'transpiler', + 'schema_name': 'remorph', + 'source_dialect': 'snowflake', + 'sdk_config': {'cluster_id': 'test_cluster'}, + } + ), + "/Users/foo/.remorph/recon_config.yml": yaml.dump( + { + 'version': 1, + 'source_schema': "src_schema", + 'target_catalog': "src_catalog", + 'target_schema': "tgt_schema", + 'tables': [ + { + "source_name": 'src_table', + "target_name": 'tgt_table', + "join_columns": ['id'], + "jdbc_reader_options": None, + "select_columns": None, + "drop_columns": None, + "column_mapping": None, + "transformations": None, + "thresholds": None, + "filters": None, + } + ], + 'source_catalog': "src_catalog", + } + ), + } + + def download(path: str) -> io.StringIO | io.BytesIO: + if path not in state: + raise NotFound(path) + if ".csv" in path: + return io.BytesIO(state[path].encode('utf-8')) + return io.StringIO(state[path]) + + workspace_client = create_autospec(WorkspaceClient) + workspace_client.current_user.me().user_name = "foo" + 
workspace_client.workspace.download = download + config = create_autospec(Config) + config.warehouse_id = None + config.cluster_id = None + workspace_client.config = config + return workspace_client + + +@pytest.fixture +def mock_installation_reconcile(): + return MockInstallation( + { + "reconcile.yml": { + "data_source": "snowflake", + "database_config": { + "source_catalog": "sf_test", + "source_schema": "mock", + "target_catalog": "hive_metastore", + "target_schema": "default", + }, + "report_type": "include", + "secret_scope": "remorph_snowflake", + "tables": { + "filter_type": "include", + "tables_list": ["product"], + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "job_id": "54321", + "version": 1, + }, + "recon_config_snowflake_sf_test_include.json": { + "source_catalog": "sf_functions_test", + "source_schema": "mock", + "tables": [ + { + "column_mapping": [ + {"source_name": "p_id", "target_name": "product_id"}, + {"source_name": "p_name", "target_name": "product_name"}, + ], + "drop_columns": None, + "filters": None, + "jdbc_reader_options": None, + "join_columns": ["p_id"], + "select_columns": ["p_id", "p_name"], + "source_name": "product", + "target_name": "product_delta", + "thresholds": None, + "transformations": [ + { + "column_name": "creation_date", + "source": "creation_date", + "target": "to_date(creation_date,'yyyy-mm-dd')", + } + ], + } + ], + "target_catalog": "hive_metastore", + "target_schema": "default", + }, + } + ) + + +@pytest.fixture +def temp_dirs_for_lineage(tmpdir): + input_dir = tmpdir.mkdir("input") + output_dir = tmpdir.mkdir("output") + + sample_sql_file = input_dir.join("sample.sql") + sample_sql_content = """ + create table table1 select * from table2 inner join + table3 on table2.id = table3.id where table2.id in (select id from table4); + create table table2 select * from table4; + create table table5 select * from table3 join table4 on table3.id = table4.id; + """ + sample_sql_file.write(sample_sql_content) + + return input_dir, output_dir + + +def test_transpile_with_missing_installation(): + workspace_client = create_autospec(WorkspaceClient) + with ( + patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context, + pytest.raises(SystemExit), + ): + mock_app_context.return_value.workspace_client = workspace_client + mock_app_context.return_value.transpile_config = None + cli.transpile( + workspace_client, + "snowflake", + "/path/to/sql/file.sql", + "/path/to/output", + "true", + "my_catalog", + "my_schema", + "current", + ) + + +def test_transpile_with_no_sdk_config(): + workspace_client = create_autospec(WorkspaceClient) + with ( + patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context, + patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + patch("os.path.exists", return_value=True), + ): + default_config = TranspileConfig( + sdk_config=None, + source_dialect="snowflake", + input_source="/path/to/sql/file.sql", + output_folder="/path/to/output", + skip_validation=True, + catalog_name="my_catalog", + schema_name="my_schema", + mode="current", + ) + mock_app_context.return_value.transpile_config = default_config + mock_app_context.return_value.workspace_client = workspace_client + cli.transpile( + workspace_client, + "snowflake", + "/path/to/sql/file.sql", + "/path/to/output", + "true", + "my_catalog", + "my_schema", + "current", + ) + mock_transpile.assert_called_once_with( + 
workspace_client, + TranspileConfig( + sdk_config=None, + source_dialect="snowflake", + input_source="/path/to/sql/file.sql", + output_folder="/path/to/output", + skip_validation=True, + catalog_name="my_catalog", + schema_name="my_schema", + mode="current", + ), + ) + + +def test_transpile_with_warehouse_id_in_sdk_config(): + workspace_client = create_autospec(WorkspaceClient) + with ( + patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context, + patch("os.path.exists", return_value=True), + patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + ): + sdk_config = {"warehouse_id": "w_id"} + default_config = TranspileConfig( + sdk_config=sdk_config, + source_dialect="snowflake", + input_source="/path/to/sql/file.sql", + output_folder="/path/to/output", + skip_validation=True, + catalog_name="my_catalog", + schema_name="my_schema", + mode="current", + ) + mock_app_context.return_value.workspace_client = workspace_client + mock_app_context.return_value.transpile_config = default_config + cli.transpile( + workspace_client, + "snowflake", + "/path/to/sql/file.sql", + "/path/to/output", + "true", + "my_catalog", + "my_schema", + "current", + ) + mock_transpile.assert_called_once_with( + workspace_client, + TranspileConfig( + sdk_config=sdk_config, + source_dialect="snowflake", + input_source="/path/to/sql/file.sql", + output_folder="/path/to/output", + skip_validation=True, + catalog_name="my_catalog", + schema_name="my_schema", + mode="current", + ), + ) + + +def test_transpile_with_cluster_id_in_sdk_config(): + workspace_client = create_autospec(WorkspaceClient) + with ( + patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context, + patch("os.path.exists", return_value=True), + patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + ): + sdk_config = {"cluster_id": "c_id"} + default_config = TranspileConfig( + sdk_config=sdk_config, + source_dialect="snowflake", + input_source="/path/to/sql/file.sql", + output_folder="/path/to/output", + skip_validation=True, + catalog_name="my_catalog", + schema_name="my_schema", + mode="current", + ) + mock_app_context.return_value.workspace_client = workspace_client + mock_app_context.return_value.transpile_config = default_config + cli.transpile( + workspace_client, + "snowflake", + "/path/to/sql/file.sql", + "/path/to/output", + "true", + "my_catalog", + "my_schema", + "current", + ) + mock_transpile.assert_called_once_with( + workspace_client, + TranspileConfig( + sdk_config=sdk_config, + source_dialect="snowflake", + input_source="/path/to/sql/file.sql", + output_folder="/path/to/output", + skip_validation=True, + catalog_name="my_catalog", + schema_name="my_schema", + mode="current", + ), + ) + + +def test_transpile_with_invalid_dialect(mock_workspace_client_cli): + with pytest.raises(Exception, match="Error: Invalid value for '--source-dialect'"): + cli.transpile( + mock_workspace_client_cli, + "invalid_dialect", + "/path/to/sql/file.sql", + "/path/to/output", + "true", + "my_catalog", + "my_schema", + "current", + ) + + +def test_transpile_with_invalid_skip_validation(mock_workspace_client_cli): + with ( + patch("os.path.exists", return_value=True), + pytest.raises(Exception, match="Error: Invalid value for '--skip-validation'"), + ): + cli.transpile( + mock_workspace_client_cli, + "snowflake", + "/path/to/sql/file.sql", + "/path/to/output", + "invalid_value", + "my_catalog", + "my_schema", + "current", + ) + + +def 
test_transpile_with_invalid_input_source(mock_workspace_client_cli): + with ( + patch("os.path.exists", return_value=False), + pytest.raises(Exception, match="Error: Invalid value for '--input-source'"), + ): + cli.transpile( + mock_workspace_client_cli, + "snowflake", + "/path/to/invalid/sql/file.sql", + "/path/to/output", + "true", + "my_catalog", + "my_schema", + "current", + ) + + +def test_transpile_with_valid_input(mock_workspace_client_cli): + source = "snowflake" + input_sql = "/path/to/sql/file.sql" + output_folder = "/path/to/output" + skip_validation = "true" + catalog_name = "my_catalog" + schema_name = "my_schema" + mode = "current" + sdk_config = {'cluster_id': 'test_cluster'} + + with ( + patch("os.path.exists", return_value=True), + patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + ): + cli.transpile( + mock_workspace_client_cli, + source, + input_sql, + output_folder, + skip_validation, + catalog_name, + schema_name, + mode, + ) + mock_transpile.assert_called_once_with( + mock_workspace_client_cli, + TranspileConfig( + sdk_config=sdk_config, + source_dialect=source, + input_source=input_sql, + output_folder=output_folder, + skip_validation=True, + catalog_name=catalog_name, + schema_name=schema_name, + mode=mode, + ), + ) + + +def test_transpile_empty_output_folder(mock_workspace_client_cli): + source_dialect = "snowflake" + input_sql = "/path/to/sql/file2.sql" + output_folder = "" + skip_validation = "false" + catalog_name = "my_catalog" + schema_name = "my_schema" + + mode = "current" + sdk_config = {'cluster_id': 'test_cluster'} + + with ( + patch("os.path.exists", return_value=True), + patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + ): + cli.transpile( + mock_workspace_client_cli, + source_dialect, + input_sql, + output_folder, + skip_validation, + catalog_name, + schema_name, + mode, + ) + mock_transpile.assert_called_once_with( + mock_workspace_client_cli, + TranspileConfig( + sdk_config=sdk_config, + source_dialect=source_dialect, + input_source=input_sql, + output_folder="", + skip_validation=False, + catalog_name=catalog_name, + schema_name=schema_name, + mode=mode, + ), + ) + + +def test_transpile_with_invalid_mode(mock_workspace_client_cli): + with ( + patch("os.path.exists", return_value=True), + pytest.raises(Exception, match="Error: Invalid value for '--mode':"), + ): + source = "snowflake" + input_sql = "/path/to/sql/file2.sql" + output_folder = "" + skip_validation = "false" + catalog_name = "my_catalog" + schema_name = "my_schema" + mode = "preview" + + cli.transpile( + mock_workspace_client_cli, + source, + input_sql, + output_folder, + skip_validation, + catalog_name, + schema_name, + mode, + ) + + +def test_generate_lineage_valid_input(temp_dirs_for_lineage, mock_workspace_client_cli): + input_dir, output_dir = temp_dirs_for_lineage + cli.generate_lineage( + mock_workspace_client_cli, + source_dialect="snowflake", + input_source=str(input_dir), + output_folder=str(output_dir), + ) + + date_str = datetime.datetime.now().strftime("%d%m%y") + output_filename = f"lineage_{date_str}.dot" + output_file = output_dir.join(output_filename) + assert output_file.check(file=1) + expected_output = """ + flowchart TD + Table1 --> Table2 + Table1 --> Table3 + Table1 --> Table4 + Table2 --> Table4 + Table3 + Table4 + Table5 --> Table3 + Table5 --> Table4 + """ + actual_output = output_file.read() + assert actual_output.strip() == expected_output.strip() + + +def 
test_generate_lineage_with_invalid_dialect(mock_workspace_client_cli): + with pytest.raises(Exception, match="Error: Invalid value for '--source-dialect'"): + cli.generate_lineage( + mock_workspace_client_cli, + source_dialect="invalid_dialect", + input_source="/path/to/sql/file.sql", + output_folder="/path/to/output", + ) + + +def test_generate_lineage_invalid_input_source(mock_workspace_client_cli): + with ( + patch("os.path.exists", return_value=False), + pytest.raises(Exception, match="Error: Invalid value for '--input-source'"), + ): + cli.generate_lineage( + mock_workspace_client_cli, + source_dialect="snowflake", + input_source="/path/to/invalid/sql/file.sql", + output_folder="/path/to/output", + ) + + +def test_generate_lineage_invalid_output_dir(mock_workspace_client_cli, monkeypatch): + input_sql = "/path/to/sql/file.sql" + output_folder = "/path/to/output" + monkeypatch.setattr("os.path.exists", lambda x: x == input_sql) + with pytest.raises(Exception, match="Error: Invalid value for '--output-folder'"): + cli.generate_lineage( + mock_workspace_client_cli, + source_dialect="snowflake", + input_source=input_sql, + output_folder=output_folder, + ) + + +def test_configure_secrets_databricks(mock_workspace_client): + source_dict = {"databricks": "0", "netezza": "1", "oracle": "2", "snowflake": "3"} + prompts = MockPrompts( + { + r"Select the source": source_dict["databricks"], + } + ) + + recon_conf = ReconConfigPrompts(mock_workspace_client, prompts) + recon_conf.prompt_source() + + recon_conf.prompt_and_save_connection_details() + + +def test_cli_configure_secrets_config(mock_workspace_client): + with patch("databricks.labs.remorph.cli.ReconConfigPrompts") as mock_recon_config: + cli.configure_secrets(mock_workspace_client) + mock_recon_config.assert_called_once_with(mock_workspace_client) + + +def test_cli_reconcile(mock_workspace_client): + with patch("databricks.labs.remorph.reconcile.runner.ReconcileRunner.run", return_value=True): + cli.reconcile(mock_workspace_client) + + +def test_cli_aggregates_reconcile(mock_workspace_client): + with patch("databricks.labs.remorph.reconcile.runner.ReconcileRunner.run", return_value=True): + cli.aggregates_reconcile(mock_workspace_client) diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py new file mode 100644 index 0000000000..d3791e15ea --- /dev/null +++ b/tests/unit/test_install.py @@ -0,0 +1,1010 @@ +from unittest.mock import create_autospec, patch + +import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import iam +from databricks.labs.blueprint.tui import MockPrompts +from databricks.labs.remorph.config import RemorphConfigs, ReconcileConfig, DatabaseConfig, ReconcileMetadataConfig +from databricks.labs.remorph.contexts.application import ApplicationContext +from databricks.labs.remorph.deployment.configurator import ResourceConfigurator +from databricks.labs.remorph.deployment.installation import WorkspaceInstallation +from databricks.labs.remorph.install import WorkspaceInstaller, MODULES +from databricks.labs.remorph.config import TranspileConfig +from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2 +from databricks.labs.remorph.config import SQLGLOT_DIALECTS +from databricks.labs.remorph.reconcile.constants import ReconSourceType, ReconReportType + +RECONCILE_DATA_SOURCES = sorted([source_type.value for source_type in ReconSourceType]) +RECONCILE_REPORT_TYPES = sorted([report_type.value for 
report_type in ReconReportType]) + + +@pytest.fixture +def ws(): + w = create_autospec(WorkspaceClient) + w.current_user.me.side_effect = lambda: iam.User( + user_name="me@example.com", groups=[iam.ComplexValue(display="admins")] + ) + return w + + +def test_workspace_installer_run_raise_error_in_dbr(ws): + ctx = ApplicationContext(ws) + environ = {"DATABRICKS_RUNTIME_VERSION": "8.3.x-scala2.12"} + with pytest.raises(SystemExit): + WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + environ=environ, + ) + + +def test_workspace_installer_run_install_not_called_in_test(ws): + ws_installation = create_autospec(WorkspaceInstallation) + ctx = ApplicationContext(ws) + ctx.replace( + product_info=ProductInfo.for_testing(RemorphConfigs), + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=ws_installation, + ) + + provided_config = RemorphConfigs() + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + returned_config = workspace_installer.run(config=provided_config) + assert returned_config == provided_config + ws_installation.install.assert_not_called() + + +def test_workspace_installer_run_install_called_with_provided_config(ws): + ws_installation = create_autospec(WorkspaceInstallation) + ctx = ApplicationContext(ws) + ctx.replace( + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=ws_installation, + ) + provided_config = RemorphConfigs() + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + returned_config = workspace_installer.run(config=provided_config) + assert returned_config == provided_config + ws_installation.install.assert_called_once_with(provided_config) + + +def test_configure_error_if_invalid_module_selected(ws): + ctx = ApplicationContext(ws) + ctx.replace( + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=create_autospec(WorkspaceInstallation), + ) + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + + with pytest.raises(ValueError): + workspace_installer.configure(module="invalid_module") + + +def test_workspace_installer_run_install_called_with_generated_config(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Do you want to override the existing installation?": "no", + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "no", + r"Open .* in the browser?": "no", + } + ) + installation = MockInstallation() + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + 
ctx.resource_configurator, + ctx.workspace_installation, + ) + workspace_installer.run() + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler", + "skip_validation": True, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + +def test_configure_transpile_no_existing_installation(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Do you want to override the existing installation?": "no", + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "no", + r"Open .* in the browser?": "no", + } + ) + installation = MockInstallation() + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=create_autospec(WorkspaceInstallation), + ) + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_morph_config = TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + skip_validation=True, + catalog_name="remorph", + schema_name="transpiler", + mode="current", + ) + expected_config = RemorphConfigs(transpile=expected_morph_config) + assert config == expected_config + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler", + "skip_validation": True, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + +def test_configure_transpile_installation_no_override(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Do you want to override the existing installation?": "no", + } + ) + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=create_autospec(WorkspaceInstallation), + installation=MockInstallation( + { + "config.yml": { + "source_dialect": "snowflake", + "catalog_name": "transpiler_test", + "input_source": "sf_queries", + "output_folder": "out_dir", + "schema_name": "convertor_test", + "sdk_config": { + "warehouse_id": "abc", + }, + "version": 1, + } + } + ), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + with pytest.raises(SystemExit): + workspace_installer.configure() + + +def test_configure_transpile_installation_config_error_continue_install(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Do you want to override the existing installation?": "no", + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "no", + r"Open .* in the browser?": 
"no", + } + ) + installation = MockInstallation( + { + "config.yml": { + "source_name": "snowflake", # Invalid key + "catalog_name": "transpiler_test", + "input_source": "sf_queries", + "output_folder": "out_dir", + "schema_name": "convertor_test", + "sdk_config": { + "warehouse_id": "abc", + }, + "version": 1, + } + } + ) + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=create_autospec(WorkspaceInstallation), + ) + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_morph_config = TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + skip_validation=True, + catalog_name="remorph", + schema_name="transpiler", + mode="current", + ) + expected_config = RemorphConfigs(transpile=expected_morph_config) + assert config == expected_config + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler", + "skip_validation": True, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + +@patch("webbrowser.open") +def test_configure_transpile_installation_with_no_validation(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "no", + r"Open .* in the browser?": "yes", + } + ) + installation = MockInstallation() + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_morph_config = TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + skip_validation=True, + catalog_name="remorph", + schema_name="transpiler", + mode="current", + ) + expected_config = RemorphConfigs(transpile=expected_morph_config) + assert config == expected_config + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler", + "skip_validation": True, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + +def test_configure_transpile_installation_with_validation_and_cluster_id_in_config(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "yes", + r"Do you want to use SQL Warehouse for 
validation?": "no", + r"Open .* in the browser?": "no", + } + ) + installation = MockInstallation() + ws.config.cluster_id = "1234" + + resource_configurator = create_autospec(ResourceConfigurator) + resource_configurator.prompt_for_catalog_setup.return_value = "remorph_test" + resource_configurator.prompt_for_schema_setup.return_value = "transpiler_test" + + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=resource_configurator, + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_config = RemorphConfigs( + transpile=TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + catalog_name="remorph_test", + schema_name="transpiler_test", + mode="current", + sdk_config={"cluster_id": "1234"}, + ) + ) + assert config == expected_config + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph_test", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler_test", + "sdk_config": {"cluster_id": "1234"}, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + +def test_configure_transpile_installation_with_validation_and_cluster_id_from_prompt(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "yes", + r"Do you want to use SQL Warehouse for validation?": "no", + r"Enter a valid cluster_id to proceed": "1234", + r"Open .* in the browser?": "no", + } + ) + installation = MockInstallation() + ws.config.cluster_id = None + + resource_configurator = create_autospec(ResourceConfigurator) + resource_configurator.prompt_for_catalog_setup.return_value = "remorph_test" + resource_configurator.prompt_for_schema_setup.return_value = "transpiler_test" + + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=resource_configurator, + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_config = RemorphConfigs( + transpile=TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + catalog_name="remorph_test", + schema_name="transpiler_test", + mode="current", + sdk_config={"cluster_id": "1234"}, + ) + ) + assert config == expected_config + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph_test", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler_test", + "sdk_config": {"cluster_id": "1234"}, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + +def 
test_configure_transpile_installation_with_validation_and_warehouse_id_from_prompt(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "yes", + r"Do you want to use SQL Warehouse for validation?": "yes", + r"Open .* in the browser?": "no", + } + ) + installation = MockInstallation() + resource_configurator = create_autospec(ResourceConfigurator) + resource_configurator.prompt_for_catalog_setup.return_value = "remorph_test" + resource_configurator.prompt_for_schema_setup.return_value = "transpiler_test" + resource_configurator.prompt_for_warehouse_setup.return_value = "w_id" + + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=resource_configurator, + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_config = RemorphConfigs( + transpile=TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + catalog_name="remorph_test", + schema_name="transpiler_test", + mode="current", + sdk_config={"warehouse_id": "w_id"}, + ) + ) + assert config == expected_config + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph_test", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler_test", + "sdk_config": {"warehouse_id": "w_id"}, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + +def test_configure_reconcile_installation_no_override(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("reconcile"), + r"Do you want to override the existing installation?": "no", + } + ) + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=create_autospec(WorkspaceInstallation), + installation=MockInstallation( + { + "reconcile.yml": { + "data_source": "snowflake", + "report_type": "all", + "secret_scope": "remorph_snowflake", + "database_config": { + "source_catalog": "snowflake_sample_data", + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + } + } + ), + ) + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + with pytest.raises(SystemExit): + workspace_installer.configure() + + +def test_configure_reconcile_installation_config_error_continue_install(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("reconcile"), + r"Select the Data Source": RECONCILE_DATA_SOURCES.index("oracle"), + r"Select the report type": RECONCILE_REPORT_TYPES.index("all"), + r"Enter Secret scope name to store .* connection details / secrets": "remorph_oracle", + r"Enter source 
database name for .*": "tpch_sf1000", + r"Enter target catalog name for Databricks": "tpch", + r"Enter target schema name for Databricks": "1000gb", + r"Open .* in the browser?": "no", + } + ) + installation = MockInstallation( + { + "reconcile.yml": { + "source_dialect": "oracle", # Invalid key + "report_type": "all", + "secret_scope": "remorph_oracle", + "database_config": { + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + } + } + ) + + resource_configurator = create_autospec(ResourceConfigurator) + resource_configurator.prompt_for_catalog_setup.return_value = "remorph" + resource_configurator.prompt_for_schema_setup.return_value = "reconcile" + resource_configurator.prompt_for_volume_setup.return_value = "reconcile_volume" + + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=resource_configurator, + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_config = RemorphConfigs( + reconcile=ReconcileConfig( + data_source="oracle", + report_type="all", + secret_scope="remorph_oracle", + database_config=DatabaseConfig( + source_schema="tpch_sf1000", + target_catalog="tpch", + target_schema="1000gb", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph", + schema="reconcile", + volume="reconcile_volume", + ), + ) + ) + assert config == expected_config + installation.assert_file_written( + "reconcile.yml", + { + "data_source": "oracle", + "report_type": "all", + "secret_scope": "remorph_oracle", + "database_config": { + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + ) + + +@patch("webbrowser.open") +def test_configure_reconcile_no_existing_installation(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("reconcile"), + r"Select the Data Source": RECONCILE_DATA_SOURCES.index("snowflake"), + r"Select the report type": RECONCILE_REPORT_TYPES.index("all"), + r"Enter Secret scope name to store .* connection details / secrets": "remorph_snowflake", + r"Enter source catalog name for .*": "snowflake_sample_data", + r"Enter source schema name for .*": "tpch_sf1000", + r"Enter target catalog name for Databricks": "tpch", + r"Enter target schema name for Databricks": "1000gb", + r"Open .* in the browser?": "yes", + } + ) + installation = MockInstallation() + resource_configurator = create_autospec(ResourceConfigurator) + resource_configurator.prompt_for_catalog_setup.return_value = "remorph" + resource_configurator.prompt_for_schema_setup.return_value = "reconcile" + resource_configurator.prompt_for_volume_setup.return_value = "reconcile_volume" + + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=resource_configurator, + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + 
ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_config = RemorphConfigs( + reconcile=ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake", + database_config=DatabaseConfig( + source_schema="tpch_sf1000", + target_catalog="tpch", + target_schema="1000gb", + source_catalog="snowflake_sample_data", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph", + schema="reconcile", + volume="reconcile_volume", + ), + ) + ) + assert config == expected_config + installation.assert_file_written( + "reconcile.yml", + { + "data_source": "snowflake", + "report_type": "all", + "secret_scope": "remorph_snowflake", + "database_config": { + "source_catalog": "snowflake_sample_data", + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + ) + + +def test_configure_all_override_installation(ws): + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("all"), + r"Do you want to override the existing installation?": "yes", + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "no", + r"Open .* in the browser?": "no", + r"Select the Data Source": RECONCILE_DATA_SOURCES.index("snowflake"), + r"Select the report type": RECONCILE_REPORT_TYPES.index("all"), + r"Enter Secret scope name to store .* connection details / secrets": "remorph_snowflake", + r"Enter source catalog name for .*": "snowflake_sample_data", + r"Enter source schema name for .*": "tpch_sf1000", + r"Enter target catalog name for Databricks": "tpch", + r"Enter target schema name for Databricks": "1000gb", + } + ) + installation = MockInstallation( + { + "config.yml": { + "source_dialect": "snowflake", + "catalog_name": "transpiler_test", + "input_source": "sf_queries", + "output_folder": "out_dir", + "schema_name": "convertor_test", + "sdk_config": { + "warehouse_id": "abc", + }, + "version": 1, + }, + "reconcile.yml": { + "data_source": "snowflake", + "report_type": "all", + "secret_scope": "remorph_snowflake", + "database_config": { + "source_catalog": "snowflake_sample_data", + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + } + ) + + resource_configurator = create_autospec(ResourceConfigurator) + resource_configurator.prompt_for_catalog_setup.return_value = "remorph" + resource_configurator.prompt_for_schema_setup.return_value = "reconcile" + resource_configurator.prompt_for_volume_setup.return_value = "reconcile_volume" + + ctx = ApplicationContext(ws) + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=resource_configurator, + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + config = workspace_installer.configure() + expected_morph_config = TranspileConfig( + source_dialect="snowflake", + 
input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + skip_validation=True, + catalog_name="remorph", + schema_name="transpiler", + mode="current", + ) + + expected_reconcile_config = ReconcileConfig( + data_source="snowflake", + report_type="all", + secret_scope="remorph_snowflake", + database_config=DatabaseConfig( + source_schema="tpch_sf1000", + target_catalog="tpch", + target_schema="1000gb", + source_catalog="snowflake_sample_data", + ), + metadata_config=ReconcileMetadataConfig( + catalog="remorph", + schema="reconcile", + volume="reconcile_volume", + ), + ) + expected_config = RemorphConfigs(transpile=expected_morph_config, reconcile=expected_reconcile_config) + assert config == expected_config + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph", + "input_source": "/tmp/queries/snow", + "mode": "current", + "output_folder": "/tmp/queries/databricks", + "schema_name": "transpiler", + "skip_validation": True, + "source_dialect": "snowflake", + "version": 1, + }, + ) + + installation.assert_file_written( + "reconcile.yml", + { + "data_source": "snowflake", + "report_type": "all", + "secret_scope": "remorph_snowflake", + "database_config": { + "source_catalog": "snowflake_sample_data", + "source_schema": "tpch_sf1000", + "target_catalog": "tpch", + "target_schema": "1000gb", + }, + "metadata_config": { + "catalog": "remorph", + "schema": "reconcile", + "volume": "reconcile_volume", + }, + "version": 1, + }, + ) + + +def test_runs_upgrades_on_more_recent_version(ws): + installation = MockInstallation( + { + 'version.json': {'version': '0.3.0', 'wheel': '...', 'date': '...'}, + 'state.json': { + 'resources': { + 'dashboards': {'Reconciliation Metrics': 'abc'}, + 'jobs': {'Reconciliation Runner': '12345'}, + } + }, + 'config.yml': { + "source_dialect": "snowflake", + "catalog_name": "upgrades", + "input_source": "queries", + "output_folder": "out", + "schema_name": "test", + "sdk_config": { + "warehouse_id": "dummy", + }, + "version": 1, + }, + } + ) + + ctx = ApplicationContext(ws) + prompts = MockPrompts( + { + r"Select a module to configure:": MODULES.index("transpile"), + r"Do you want to override the existing installation?": "yes", + r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), + r"Enter input SQL path.*": "/tmp/queries/snow", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Would you like to validate.*": "no", + r"Open .* in the browser?": "no", + } + ) + wheels = create_autospec(WheelsV2) + + mock_workspace_installation = create_autospec(WorkspaceInstallation) + + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=create_autospec(ResourceConfigurator), + workspace_installation=mock_workspace_installation, + wheels=wheels, + ) + + workspace_installer = WorkspaceInstaller( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + ) + + workspace_installer.run() + + mock_workspace_installation.install.assert_called_once_with( + RemorphConfigs( + transpile=TranspileConfig( + source_dialect="snowflake", + input_source="/tmp/queries/snow", + output_folder="/tmp/queries/databricks", + catalog_name="remorph", + schema_name="transpiler", + mode="current", + skip_validation=True, + ) + ) + ) diff --git a/tests/unit/test_uninstall.py b/tests/unit/test_uninstall.py new file mode 100644 index 0000000000..42380b7e2d --- /dev/null +++ b/tests/unit/test_uninstall.py 
@@ -0,0 +1,30 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import iam + +from databricks.labs.remorph import uninstall +from databricks.labs.remorph.config import RemorphConfigs +from databricks.labs.remorph.contexts.application import ApplicationContext +from databricks.labs.remorph.deployment.installation import WorkspaceInstallation + + +@pytest.fixture +def ws(): + w = create_autospec(WorkspaceClient) + w.current_user.me.side_effect = lambda: iam.User( + user_name="me@example.com", groups=[iam.ComplexValue(display="admins")] + ) + return w + + +def test_uninstaller_run(ws): + ws_installation = create_autospec(WorkspaceInstallation) + ctx = ApplicationContext(ws) + ctx.replace( + workspace_installation=ws_installation, + remorph_config=RemorphConfigs(), + ) + uninstall.run(ctx) + ws_installation.uninstall.assert_called_once() diff --git a/tests/unit/transpiler/__init__.py b/tests/unit/transpiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/transpiler/helpers/__init__.py b/tests/unit/transpiler/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/transpiler/helpers/functional_test_cases.py b/tests/unit/transpiler/helpers/functional_test_cases.py new file mode 100644 index 0000000000..a9c0614e0b --- /dev/null +++ b/tests/unit/transpiler/helpers/functional_test_cases.py @@ -0,0 +1,64 @@ +from sqlglot import ParseError, UnsupportedError +from sqlglot.errors import SqlglotError + + +class FunctionalTestFile: + """A single test case with the required options""" + + def __init__(self, databricks_sql: str, source: str, test_name: str, target: str): + self.databricks_sql = databricks_sql + self.source = source + self.test_name = test_name + self.target = target + + +class FunctionalTestFileWithExpectedException(FunctionalTestFile): + """A single test case with the required options and expected exceptions""" + + def __init__( + self, + databricks_sql: str, + source: str, + test_name: str, + expected_exception: SqlglotError, + target: str, + ): + self.expected_exception = expected_exception + super().__init__(databricks_sql, source, test_name, target) + + +# This dict has the details about which tests have expected exceptions (Either UnsupportedError or ParseError) + +expected_exceptions: dict[str, type[SqlglotError]] = { + 'test_regexp_replace_2': ParseError, + 'test_monthname_8': ParseError, + 'test_monthname_9': ParseError, + 'test_regexp_substr_2': ParseError, + 'test_try_cast_3': ParseError, + 'test_array_slice_3': UnsupportedError, + 'test_right_2': ParseError, + 'test_arrayagg_8': ParseError, + 'test_repeat_2': ParseError, + 'test_nvl2_3': ParseError, + 'test_array_contains_2': ParseError, + 'test_iff_2': ParseError, + 'test_nullif_2': ParseError, + 'test_timestampadd_6': ParseError, + 'test_dayname_4': ParseError, + 'test_date_part_2': ParseError, + 'test_approx_percentile_2': ParseError, + 'test_date_trunc_5': ParseError, + 'test_position_2': ParseError, + 'test_split_part_8': ParseError, + 'test_split_part_7': ParseError, + 'test_trunc_2': UnsupportedError, + 'test_to_number_9': UnsupportedError, + 'test_to_number_10': ParseError, + 'test_startswith_2': ParseError, + 'test_regexp_like_2': ParseError, + 'test_left_2': ParseError, + 'test_parse_json_extract_path_text_4': ParseError, + 'test_extract_2': ParseError, + 'test_approx_percentile_5': ParseError, + 'test_approx_percentile_7': ParseError, +} diff --git 
a/tests/unit/transpiler/test_bigquery.py b/tests/unit/transpiler/test_bigquery.py new file mode 100644 index 0000000000..a95fc0d4fd --- /dev/null +++ b/tests/unit/transpiler/test_bigquery.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import pytest + +from ..conftest import FunctionalTestFile, get_functional_test_files_from_directory + +path = Path(__file__).parent / Path('../../resources/functional/bigquery/') +functional_tests = get_functional_test_files_from_directory(path, "bigquery", "databricks", False) +test_names = [f.test_name for f in functional_tests] + + +@pytest.mark.parametrize("sample", functional_tests, ids=test_names) +def test_bigquery(dialect_context, sample: FunctionalTestFile): + validate_source_transpile, _ = dialect_context + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"bigquery": sample.source}) diff --git a/tests/unit/transpiler/test_bigquery_expected_exceptions.py b/tests/unit/transpiler/test_bigquery_expected_exceptions.py new file mode 100644 index 0000000000..a62e8b9ec0 --- /dev/null +++ b/tests/unit/transpiler/test_bigquery_expected_exceptions.py @@ -0,0 +1,22 @@ +# Logic for processing test cases with expected exceptions, can be removed if not needed. +from pathlib import Path + +import pytest + +from ..conftest import ( + FunctionalTestFileWithExpectedException, + get_functional_test_files_from_directory, +) + +path_expected_exceptions = Path(__file__).parent / Path('../../resources/functional/bigquery_expected_exceptions/') +functional_tests_expected_exceptions = get_functional_test_files_from_directory( + path_expected_exceptions, "bigquery", "databricks", True +) +test_names_expected_exceptions = [f.test_name for f in functional_tests_expected_exceptions] + + +@pytest.mark.parametrize("sample", functional_tests_expected_exceptions, ids=test_names_expected_exceptions) +def test_bigquery_expected_exceptions(dialect_context, sample: FunctionalTestFileWithExpectedException): + validate_source_transpile, _ = dialect_context + with pytest.raises(type(sample.expected_exception)): + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"bigquery": sample.source}) diff --git a/tests/unit/transpiler/test_databricks.py b/tests/unit/transpiler/test_databricks.py new file mode 100644 index 0000000000..7341344468 --- /dev/null +++ b/tests/unit/transpiler/test_databricks.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import pytest + +from ..conftest import FunctionalTestFile, get_functional_test_files_from_directory + +path = Path(__file__).parent / Path('../../resources/functional/snowflake/') +functional_tests = get_functional_test_files_from_directory(path, "snowflake", "databricks", False) +test_names = [f.test_name for f in functional_tests] + + +@pytest.mark.parametrize("sample", functional_tests, ids=test_names) +def test_databricks(dialect_context, sample: FunctionalTestFile): + validate_source_transpile, _ = dialect_context + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"snowflake": sample.source}, pretty=True) diff --git a/tests/unit/transpiler/test_databricks_expected_exceptions.py b/tests/unit/transpiler/test_databricks_expected_exceptions.py new file mode 100644 index 0000000000..4691979d66 --- /dev/null +++ b/tests/unit/transpiler/test_databricks_expected_exceptions.py @@ -0,0 +1,22 @@ +# Logic for processing test cases with expected exceptions, can be removed if not needed. 
+from pathlib import Path + +import pytest + +from ..conftest import ( + FunctionalTestFileWithExpectedException, + get_functional_test_files_from_directory, +) + +path_expected_exceptions = Path(__file__).parent / Path('../../resources/functional/snowflake_expected_exceptions/') +functional_tests_expected_exceptions = get_functional_test_files_from_directory( + path_expected_exceptions, "snowflake", "databricks", True +) +test_names_expected_exceptions = [f.test_name for f in functional_tests_expected_exceptions] + + +@pytest.mark.parametrize("sample", functional_tests_expected_exceptions, ids=test_names_expected_exceptions) +def test_databricks_expected_exceptions(dialect_context, sample: FunctionalTestFileWithExpectedException): + validate_source_transpile, _ = dialect_context + with pytest.raises(type(sample.expected_exception)): + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"snowflake": sample.source}) diff --git a/tests/unit/transpiler/test_execute.py b/tests/unit/transpiler/test_execute.py new file mode 100644 index 0000000000..a8b03a82eb --- /dev/null +++ b/tests/unit/transpiler/test_execute.py @@ -0,0 +1,493 @@ +import re +import shutil +from pathlib import Path +from unittest.mock import create_autospec, patch + +import pytest + +from databricks.connect import DatabricksSession +from databricks.labs.lsql.backends import MockBackend +from databricks.labs.lsql.core import Row +from databricks.labs.remorph.config import TranspileConfig, ValidationResult +from databricks.labs.remorph.helpers.file_utils import make_dir +from databricks.labs.remorph.helpers.validation import Validator +from databricks.labs.remorph.transpiler.execute import ( + transpile, + transpile_column_exp, + transpile_sql, +) +from databricks.sdk.core import Config + +# pylint: disable=unspecified-encoding + + +def safe_remove_dir(dir_path: Path): + if dir_path.exists(): + shutil.rmtree(dir_path) + + +def safe_remove_file(file_path: Path): + if file_path.exists(): + file_path.unlink() + + +def write_data_to_file(path: Path, content: str): + with path.open("w") as writable: + writable.write(content) + + +@pytest.fixture +def initial_setup(tmp_path: Path): + input_dir = tmp_path / "remorph_transpile" + query_1_sql_file = input_dir / "query1.sql" + query_2_sql_file = input_dir / "query2.sql" + query_3_sql_file = input_dir / "query3.sql" + stream_1_sql_file = input_dir / "stream1.sql" + call_center_ddl_file = input_dir / "call_center.ddl" + file_text = input_dir / "file.txt" + safe_remove_dir(input_dir) + make_dir(input_dir) + + query_1_sql = """select i_manufact, sum(ss_ext_sales_price) ext_price from date_dim, store_sales where + d_date_sk = ss_sold_date_sk and substr(ca_zip,1,5) <> substr(s_zip,1,5) group by i_manufact order by i_manufact + limit 100 ;""" + + call_center_ddl = """create table call_center + ( + cc_call_center_sk int , + cc_call_center_id varchar(16) + ) + + CLUSTER BY(cc_call_center_sk) + """ + + query_2_sql = """select wswscs.d_week_seq d_week_seq1,sun_sales sun_sales1,mon_sales mon_sales1 from wswscs, + date_dim where date_dim.d_week_seq = wswscs.d_week_seq and d_year = 2001""" + + query_3_sql = """with wscs as + (select sold_date_sk + ,sales_price + from (select ws_sold_date_sk sold_date_sk + ,ws_ext_sales_price sales_price + from web_sales + union all + select cs_sold_date_sk sold_date_sk + ,cs_ext_sales_price sales_price + from catalog_sales)), + wswscs as + (select d_week_seq, + sum(case when (d_day_name='Sunday') then sales_price else null end) sun_sales, + 
sum(case when (d_day_name='Monday') then sales_price else null end) mon_sales, + sum(case when (d_day_name='Tuesday') then sales_price else null end) tue_sales, + sum(case when (d_day_name='Wednesday') then sales_price else null end) wed_sales, + sum(case when (d_day_name='Thursday') then sales_price else null end) thu_sales, + sum(case when (d_day_name='Friday') then sales_price else null end) fri_sales, + sum(case when (d_day_name='Saturday') then sales_price else null end) sat_sales + from wscs + ,date_dim + where d_date_sk = sold_date_sk + group by d_week_seq) + select d_week_seq1 + ,round(sun_sales1/sun_sales2,2) + ,round(mon_sales1/mon_sales2,2) + ,round(tue_sales1/tue_sales2,2) + ,round(wed_sales1/wed_sales2,2) + ,round(thu_sales1/thu_sales2,2) + ,round(fri_sales1/fri_sales2,2) + ,round(sat_sales1/sat_sales2,2) + from + (select wswscs.d_week_seq d_week_seq1 + ,sun_sales sun_sales1 + ,mon_sales mon_sales1 + ,tue_sales tue_sales1 + ,wed_sales wed_sales1 + ,thu_sales thu_sales1 + ,fri_sales fri_sales1 + ,sat_sales sat_sales1 + from wswscs,date_dim + where date_dim.d_week_seq = wswscs.d_week_seq and + d_year = 2001) y, + (select wswscs.d_week_seq d_week_seq2 + ,sun_sales sun_sales2 + ,mon_sales mon_sales2 + ,tue_sales tue_sales2 + ,wed_sales wed_sales2 + ,thu_sales thu_sales2 + ,fri_sales fri_sales2 + ,sat_sales sat_sales2 + from wswscs + ,date_dim + where date_dim.d_week_seq = wswscs.d_week_seq2 and + d_year = 2001+1) z + where d_week_seq1=d_week_seq2-53 + order by d_week_seq1; + """ + + stream_1_sql = """CREATE STREAM unsupported_stream AS SELECT * FROM some_table;""" + + write_data_to_file(query_1_sql_file, query_1_sql) + write_data_to_file(call_center_ddl_file, call_center_ddl) + write_data_to_file(query_2_sql_file, query_2_sql) + write_data_to_file(query_3_sql_file, query_3_sql) + write_data_to_file(stream_1_sql_file, stream_1_sql) + write_data_to_file(file_text, "This is a test file") + + return input_dir + + +def test_with_dir_skip_validation(initial_setup, mock_workspace_client): + input_dir = initial_setup + config = TranspileConfig( + input_source=str(input_dir), + output_folder="None", + sdk_config=None, + source_dialect="snowflake", + skip_validation=True, + ) + + # call morph + with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()): + status = transpile(mock_workspace_client, config) + # assert the status + assert status is not None, "Status returned by morph function is None" + assert isinstance(status, list), "Status returned by morph function is not a list" + assert len(status) > 0, "Status returned by morph function is an empty list" + for stat in status: + assert stat["total_files_processed"] == 6, "total_files_processed does not match expected value" + assert stat["total_queries_processed"] == 5, "total_queries_processed does not match expected value" + assert ( + stat["no_of_sql_failed_while_parsing"] == 0 + ), "no_of_sql_failed_while_parsing does not match expected value" + assert ( + stat["no_of_sql_failed_while_validating"] == 1 + ), "no_of_sql_failed_while_validating does not match expected value" + assert stat["error_log_file"], "error_log_file is None or empty" + assert Path(stat["error_log_file"]).name.startswith("err_") and Path(stat["error_log_file"]).name.endswith( + ".lst" + ), "error_log_file does not match expected pattern 'err_*.lst'" + + expected_file_name = f"{input_dir}/query3.sql" + expected_exception = f"Unsupported operation found in file {input_dir}/query3.sql." 
+ pattern = r"ValidationError\(file_name='(?P[^']+)', exception='(?P[^']+)'\)" + + with open(Path(status[0]["error_log_file"])) as file: + for line in file: + # Skip empty lines + if line.strip() == "": + continue + + match = re.match(pattern, line) + + if match: + # Extract information using group names from the pattern + error_info = match.groupdict() + # Perform assertions + assert error_info["file_name"] == expected_file_name, "File name does not match the expected value" + assert expected_exception in error_info["exception"], "Exception does not match the expected value" + else: + print("No match found.") + # cleanup + safe_remove_dir(input_dir) + safe_remove_file(Path(status[0]["error_log_file"])) + + +def test_with_dir_with_output_folder_skip_validation(initial_setup, mock_workspace_client): + input_dir = initial_setup + config = TranspileConfig( + input_source=str(input_dir), + output_folder=str(input_dir / "output_transpiled"), + sdk_config=None, + source_dialect="snowflake", + skip_validation=True, + ) + with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()): + status = transpile(mock_workspace_client, config) + # assert the status + assert status is not None, "Status returned by morph function is None" + assert isinstance(status, list), "Status returned by morph function is not a list" + assert len(status) > 0, "Status returned by morph function is an empty list" + for stat in status: + assert stat["total_files_processed"] == 6, "total_files_processed does not match expected value" + assert stat["total_queries_processed"] == 5, "total_queries_processed does not match expected value" + assert ( + stat["no_of_sql_failed_while_parsing"] == 0 + ), "no_of_sql_failed_while_parsing does not match expected value" + assert ( + stat["no_of_sql_failed_while_validating"] == 1 + ), "no_of_sql_failed_while_validating does not match expected value" + assert stat["error_log_file"], "error_log_file is None or empty" + assert Path(stat["error_log_file"]).name.startswith("err_") and Path(stat["error_log_file"]).name.endswith( + ".lst" + ), "error_log_file does not match expected pattern 'err_*.lst'" + + expected_file_name = f"{input_dir}/query3.sql" + expected_exception = f"Unsupported operation found in file {input_dir}/query3.sql." 
+ pattern = r"ValidationError\(file_name='(?P[^']+)', exception='(?P[^']+)'\)" + + with open(Path(status[0]["error_log_file"])) as file: + for line in file: + # Skip empty lines + if line.strip() == "": + continue + + match = re.match(pattern, line) + + if match: + # Extract information using group names from the pattern + error_info = match.groupdict() + # Perform assertions + assert error_info["file_name"] == expected_file_name, "File name does not match the expected value" + assert expected_exception in error_info["exception"], "Exception does not match the expected value" + else: + print("No match found.") + + # cleanup + safe_remove_dir(input_dir) + safe_remove_file(Path(status[0]["error_log_file"])) + + +def test_with_file(initial_setup, mock_workspace_client): + input_dir = initial_setup + sdk_config = create_autospec(Config) + spark = create_autospec(DatabricksSession) + config = TranspileConfig( + input_source=str(input_dir / "query1.sql"), + output_folder="None", + sdk_config=sdk_config, + source_dialect="snowflake", + skip_validation=False, + ) + mock_validate = create_autospec(Validator) + mock_validate.spark = spark + mock_validate.validate_format_result.return_value = ValidationResult( + """ Mock validated query """, "Mock validation error" + ) + + with ( + patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + return_value=MockBackend(), + ), + patch("databricks.labs.remorph.transpiler.execute.Validator", return_value=mock_validate), + ): + status = transpile(mock_workspace_client, config) + + # assert the status + assert status is not None, "Status returned by morph function is None" + assert isinstance(status, list), "Status returned by morph function is not a list" + assert len(status) > 0, "Status returned by morph function is an empty list" + for stat in status: + assert stat["total_files_processed"] == 1, "total_files_processed does not match expected value" + assert stat["total_queries_processed"] == 1, "total_queries_processed does not match expected value" + assert ( + stat["no_of_sql_failed_while_parsing"] == 0 + ), "no_of_sql_failed_while_parsing does not match expected value" + assert ( + stat["no_of_sql_failed_while_validating"] == 1 + ), "no_of_sql_failed_while_validating does not match expected value" + assert Path(stat["error_log_file"]).name.startswith("err_") and Path(stat["error_log_file"]).name.endswith( + ".lst" + ), "error_log_file does not match expected pattern 'err_*.lst'" + + expected_content = f""" +ValidationError(file_name='{input_dir}/query1.sql', exception='Mock validation error') + """.strip() + + with open(Path(status[0]["error_log_file"])) as file: + content = file.read().strip() + assert content == expected_content, "File content does not match the expected content" + # cleanup + safe_remove_dir(input_dir) + safe_remove_file(Path(status[0]["error_log_file"])) + + +def test_with_file_with_output_folder_skip_validation(initial_setup, mock_workspace_client): + input_dir = initial_setup + config = TranspileConfig( + input_source=str(input_dir / "query1.sql"), + output_folder=str(input_dir / "output_transpiled"), + sdk_config=None, + source_dialect="snowflake", + skip_validation=True, + ) + + with patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + return_value=MockBackend(), + ): + status = transpile(mock_workspace_client, config) + + # assert the status + assert status is not None, "Status returned by morph function is None" + assert isinstance(status, list), "Status returned by morph function is not a list" + 
assert len(status) > 0, "Status returned by morph function is an empty list" + for stat in status: + assert stat["total_files_processed"] == 1, "total_files_processed does not match expected value" + assert stat["total_queries_processed"] == 1, "total_queries_processed does not match expected value" + assert ( + stat["no_of_sql_failed_while_parsing"] == 0 + ), "no_of_sql_failed_while_parsing does not match expected value" + assert ( + stat["no_of_sql_failed_while_validating"] == 0 + ), "no_of_sql_failed_while_validating does not match expected value" + assert stat["error_log_file"] == "None", "error_log_file does not match expected value" + # cleanup + safe_remove_dir(input_dir) + + +def test_with_not_a_sql_file_skip_validation(initial_setup, mock_workspace_client): + input_dir = initial_setup + config = TranspileConfig( + input_source=str(input_dir / "file.txt"), + output_folder="None", + sdk_config=None, + source_dialect="snowflake", + skip_validation=True, + ) + + with patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + return_value=MockBackend(), + ): + status = transpile(mock_workspace_client, config) + + # assert the status + assert status is not None, "Status returned by morph function is None" + assert isinstance(status, list), "Status returned by morph function is not a list" + assert len(status) > 0, "Status returned by morph function is an empty list" + for stat in status: + assert stat["total_files_processed"] == 0, "total_files_processed does not match expected value" + assert stat["total_queries_processed"] == 0, "total_queries_processed does not match expected value" + assert ( + stat["no_of_sql_failed_while_parsing"] == 0 + ), "no_of_sql_failed_while_parsing does not match expected value" + assert ( + stat["no_of_sql_failed_while_validating"] == 0 + ), "no_of_sql_failed_while_validating does not match expected value" + assert stat["error_log_file"] == "None", "error_log_file does not match expected value" + # cleanup + safe_remove_dir(input_dir) + + +def test_with_not_existing_file_skip_validation(initial_setup, mock_workspace_client): + input_dir = initial_setup + config = TranspileConfig( + input_source=str(input_dir / "file_not_exist.txt"), + output_folder="None", + sdk_config=None, + source_dialect="snowflake", + skip_validation=True, + ) + with pytest.raises(FileNotFoundError): + with patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + return_value=MockBackend(), + ): + transpile(mock_workspace_client, config) + + # cleanup + safe_remove_dir(input_dir) + + +def test_morph_sql(mock_workspace_client): + config = TranspileConfig( + source_dialect="snowflake", + skip_validation=False, + catalog_name="catalog", + schema_name="schema", + ) + query = """select col from table;""" + + with patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + return_value=MockBackend( + rows={ + "EXPLAIN SELECT": [Row(plan="== Physical Plan ==")], + } + ), + ): + transpiler_result, validation_result = transpile_sql(mock_workspace_client, config, query) + assert transpiler_result.transpiled_sql[0] == 'SELECT\n col\nFROM table' + assert validation_result.exception_msg is None + + +def test_morph_column_exp(mock_workspace_client): + config = TranspileConfig( + source_dialect="snowflake", + skip_validation=True, + catalog_name="catalog", + schema_name="schema", + ) + query = ["case when col1 is null then 1 else 0 end", "col2 * 2", "current_timestamp()"] + + with patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + 
return_value=MockBackend( + rows={ + "EXPLAIN SELECT": [Row(plan="== Physical Plan ==")], + } + ), + ): + result = transpile_column_exp(mock_workspace_client, config, query) + assert len(result) == 3 + assert result[0][0].transpiled_sql[0] == 'CASE WHEN col1 IS NULL THEN 1 ELSE 0 END' + assert result[1][0].transpiled_sql[0] == 'col2 * 2' + assert result[2][0].transpiled_sql[0] == 'CURRENT_TIMESTAMP()' + assert result[0][0].parse_error_list == [] + assert result[1][0].parse_error_list == [] + assert result[2][0].parse_error_list == [] + assert result[0][1] is None + assert result[1][1] is None + assert result[2][1] is None + + +def test_with_file_with_success(initial_setup, mock_workspace_client): + input_dir = initial_setup + sdk_config = create_autospec(Config) + spark = create_autospec(DatabricksSession) + config = TranspileConfig( + input_source=str(input_dir / "query1.sql"), + output_folder="None", + sdk_config=sdk_config, + source_dialect="snowflake", + skip_validation=False, + ) + mock_validate = create_autospec(Validator) + mock_validate.spark = spark + mock_validate.validate_format_result.return_value = ValidationResult(""" Mock validated query """, None) + + with ( + patch( + 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', + return_value=MockBackend(), + ), + patch("databricks.labs.remorph.transpiler.execute.Validator", return_value=mock_validate), + ): + status = transpile(mock_workspace_client, config) + # assert the status + assert status is not None, "Status returned by morph function is None" + assert isinstance(status, list), "Status returned by morph function is not a list" + assert len(status) > 0, "Status returned by morph function is an empty list" + for stat in status: + assert stat["total_files_processed"] == 1, "total_files_processed does not match expected value" + assert stat["total_queries_processed"] == 1, "total_queries_processed does not match expected value" + assert ( + stat["no_of_sql_failed_while_parsing"] == 0 + ), "no_of_sql_failed_while_parsing does not match expected value" + assert ( + stat["no_of_sql_failed_while_validating"] == 0 + ), "no_of_sql_failed_while_validating does not match expected value" + assert stat["error_log_file"] == "None", "error_log_file does not match expected value" + + +def test_with_input_sql_none(initial_setup, mock_workspace_client): + config = TranspileConfig( + input_source=None, + output_folder="None", + sdk_config=None, + source_dialect="snowflake", + skip_validation=True, + ) + + with pytest.raises(ValueError, match="Input SQL path is not provided"): + transpile(mock_workspace_client, config) diff --git a/tests/unit/transpiler/test_lca_utils.py b/tests/unit/transpiler/test_lca_utils.py new file mode 100644 index 0000000000..f662495034 --- /dev/null +++ b/tests/unit/transpiler/test_lca_utils.py @@ -0,0 +1,315 @@ +from unittest.mock import patch + +from sqlglot import parse_one + +from databricks.labs.remorph.config import get_dialect +from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks +from databricks.labs.remorph.transpiler.sqlglot.lca_utils import check_for_unsupported_lca + + +def test_query_with_no_unsupported_lca_usage(): + dialect = get_dialect("snowflake") + sql = """ + SELECT + t.col1, + t.col2, + t.col3 AS ca, + FROM table1 t + """ + filename = "test_file1.sql" + + error = check_for_unsupported_lca(dialect, sql, filename) + assert not error + + +def test_query_with_valid_alias_usage(): + dialect = get_dialect("snowflake") + sql = """ + WITH web_v1 as ( + select + 
ws_item_sk item_sk, d_date, + sum(sum(ws_sales_price)) + over (partition by ws_item_sk order by d_date rows between unbounded preceding and current row) cume_sales + from web_sales + ,date_dim + where ws_sold_date_sk=d_date_sk + and d_month_seq between 1212 and 1212+11 + and ws_item_sk is not NULL + group by ws_item_sk, d_date), + store_v1 as ( + select + ss_item_sk item_sk, d_date, + sum(sum(ss_sales_price)) + over ( + partition by ss_item_sk order by d_date rows between unbounded preceding and current row + ) cume_sales + from store_sales + ,date_dim + where ss_sold_date_sk=d_date_sk + and d_month_seq between 1212 and 1212+11 + and ss_item_sk is not NULL + group by ss_item_sk, d_date) + select * + from (select item_sk + ,d_date + ,web_sales + ,store_sales + ,max(web_sales) + over ( + partition by item_sk order by d_date rows between unbounded preceding and current row + ) web_cumulative + ,max(store_sales) + over ( + partition by item_sk order by d_date rows between unbounded preceding and current row + ) store_cumulative + from (select case when web.item_sk is not null then web.item_sk else store.item_sk end item_sk + ,case when web.d_date is not null then web.d_date else store.d_date end d_date + ,web.cume_sales web_sales + ,store.cume_sales store_sales + from web_v1 web full outer join store_v1 store on (web.item_sk = store.item_sk + and web.d_date = store.d_date) + )x )y + where web_cumulative > store_cumulative + order by item_sk + ,d_date + limit 100; + """ + filename = "test_file1.sql" + + error = check_for_unsupported_lca(dialect, sql, filename) + assert not error + + +def test_query_with_lca_in_where(): + dialect = get_dialect("snowflake") + sql = """ + SELECT + t.col1, + t.col2, + t.col3 AS ca, + FROM table1 t + WHERE ca in ('v1', 'v2') + """ + filename = "test_file2.sql" + + error = check_for_unsupported_lca(dialect, sql, filename) + assert error + + +def test_query_with_lca_in_window(): + dialect = get_dialect("snowflake") + sql = """ + SELECT + t.col1, + t.col2, + t.col3 AS ca, + ROW_NUMBER() OVER (PARTITION by ca ORDER BY t.col2 DESC) AS rn + FROM table1 t + """ + filename = "test_file3.sql" + + error = check_for_unsupported_lca(dialect, sql, filename) + assert error + + +def test_query_with_error(): + dialect = get_dialect("snowflake") + sql = """ + SELECT + t.col1 + t.col2, + t.col3 AS ca, + FROM table1 t + """ + filename = "test_file4.sql" + + error = check_for_unsupported_lca(dialect, sql, filename) + assert not error + + +def test_query_with_same_alias_and_column_name(): + dialect = get_dialect("snowflake") + sql = """ + select ca_zip + from ( + SELECT + substr(ca_zip,1,5) ca_zip, + trim(name) as name, + count(*) over( partition by ca_zip) + FROM customer_address + WHERE substr(ca_zip,1,5) IN ('89436', '30868')); + """ + filename = "test_file5.sql" + + error = check_for_unsupported_lca(dialect, sql, filename) + assert not error + + +def test_fix_lca_with_valid_lca_usage(normalize_string): + input_sql = """ + SELECT + t.col1, + t.col2, + t.col3 AS ca + FROM table1 t + """ + expected_sql = """ + SELECT + t.col1, + t.col2, + t.col3 AS ca + FROM table1 AS t + """ + ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_lca_with_lca_in_where(normalize_string): + input_sql = """ + SELECT column_a as customer_id + FROM my_table + WHERE customer_id = '123' + """ + expected_sql = """ + SELECT column_a as customer_id + FROM my_table + WHERE column_a = '123' + """ 
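+    # Generating through the Databricks dialect is expected to replace the lateral column
+    # alias referenced in the WHERE clause with its underlying column expression.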
+ ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_lca_with_lca_in_window(normalize_string): + input_sql = """ + SELECT + t.col1, + t.col2, + t.col3 AS ca, + ROW_NUMBER() OVER (PARTITION by ca ORDER BY t.col2 DESC) AS rn + FROM table1 t + """ + expected_sql = """ + SELECT + t.col1, + t.col2, + t.col3 AS ca, + ROW_NUMBER() OVER (PARTITION by t.col3 ORDER BY t.col2 DESC) AS rn + FROM table1 AS t + """ + ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_lca_with_lca_in_subquery(normalize_string): + input_sql = """ + SELECT column_a as cid + FROM my_table + WHERE cid in (select cid as customer_id from customer_table where customer_id = '123') + """ + expected_sql = """ + SELECT column_a as cid + FROM my_table + WHERE column_a in (select cid as customer_id from customer_table where cid = '123') + """ + ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_lca_with_lca_in_derived_table(normalize_string): + input_sql = """ + SELECT column_a as cid + FROM (select column_x as column_a, column_y as y from my_table where y = '456') + WHERE cid = '123' + """ + expected_sql = """ + SELECT column_a as cid + FROM (select column_x as column_a, column_y as y from my_table where column_y = '456') + WHERE column_a = '123' + """ + ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_lca_with_lca_in_subquery_and_derived_table(normalize_string): + input_sql = """ + SELECT column_a as cid + FROM (select column_x as column_a, column_y as y from my_table where y = '456') + WHERE cid in (select cid as customer_id from customer_table where customer_id = '123') + """ + expected_sql = """ + SELECT column_a as cid + FROM (select column_x as column_a, column_y as y from my_table where column_y = '456') + WHERE column_a in (select cid as customer_id from customer_table where cid = '123') + """ + ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_lca_in_cte(normalize_string): + input_sql = """ + WITH cte AS (SELECT column_a as customer_id + FROM my_table + WHERE customer_id = '123') + SELECT * FROM cte + """ + expected_sql = """ + WITH cte AS (SELECT column_a as customer_id + FROM my_table + WHERE column_a = '123') + SELECT * FROM cte + """ + ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_nested_lca(normalize_string): + input_sql = """ + SELECT + b * c as new_b, + a - new_b as ab_diff + FROM my_table + WHERE ab_diff >= 0 + """ + expected_sql = """ + SELECT + b * c as new_b, + a - new_b as ab_diff + FROM my_table + WHERE a - b * c >= 0 + """ + ast = parse_one(input_sql) + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) + + +def test_fix_nested_lca_with_no_scope(normalize_string): + # This test is to check if the code can handle the case where the scope is not available + # In this case we will not fix the invalid LCA and return the 
original query + input_sql = """ + SELECT + b * c as new_b, + a - new_b as ab_diff + FROM my_table + WHERE ab_diff >= 0 + """ + expected_sql = """ + SELECT + b * c as new_b, + a - new_b as ab_diff + FROM my_table + WHERE ab_diff >= 0 + """ + ast = parse_one(input_sql) + with patch( + 'databricks.labs.remorph.transpiler.sqlglot.lca_utils.build_scope', + return_value=None, + ): + generated_sql = ast.sql(Databricks, pretty=False) + assert normalize_string(generated_sql) == normalize_string(expected_sql) diff --git a/tests/unit/transpiler/test_oracle.py b/tests/unit/transpiler/test_oracle.py new file mode 100644 index 0000000000..acb196cf0a --- /dev/null +++ b/tests/unit/transpiler/test_oracle.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import pytest + +from ..conftest import FunctionalTestFile, get_functional_test_files_from_directory + +path = Path(__file__).parent / Path('../../resources/functional/oracle/') +functional_tests = get_functional_test_files_from_directory(path, "oracle", "databricks", False) +test_names = [f.test_name for f in functional_tests] + + +@pytest.mark.parametrize("sample", functional_tests, ids=test_names) +def test_oracle(dialect_context, sample: FunctionalTestFile): + validate_source_transpile, _ = dialect_context + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"oracle": sample.source}) diff --git a/tests/unit/transpiler/test_presto.py b/tests/unit/transpiler/test_presto.py new file mode 100644 index 0000000000..7958c264b9 --- /dev/null +++ b/tests/unit/transpiler/test_presto.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import pytest + +from ..conftest import FunctionalTestFile, get_functional_test_files_from_directory + +path = Path(__file__).parent / Path('../../resources/functional/presto/') +functional_tests = get_functional_test_files_from_directory(path, "presto", "databricks", False) +test_names = [f.test_name for f in functional_tests] + + +@pytest.mark.parametrize("sample", functional_tests, ids=test_names) +def test_presto(dialect_context, sample: FunctionalTestFile): + validate_source_transpile, _ = dialect_context + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"presto": sample.source}) diff --git a/tests/unit/transpiler/test_presto_expected_exceptions.py b/tests/unit/transpiler/test_presto_expected_exceptions.py new file mode 100644 index 0000000000..85821ab847 --- /dev/null +++ b/tests/unit/transpiler/test_presto_expected_exceptions.py @@ -0,0 +1,22 @@ +# Logic for processing test cases with expected exceptions, can be removed if not needed. 
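+# Each functional test file in this directory pairs a source query with the exception the
+# transpiler is expected to raise; the parametrized test below asserts that exception type is raised.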
+from pathlib import Path + +import pytest + +from ..conftest import ( + FunctionalTestFileWithExpectedException, + get_functional_test_files_from_directory, +) + +path_expected_exceptions = Path(__file__).parent / Path('../../resources/functional/presto_expected_exceptions/') +functional_tests_expected_exceptions = get_functional_test_files_from_directory( + path_expected_exceptions, "presto", "databricks", True +) +test_names_expected_exceptions = [f.test_name for f in functional_tests_expected_exceptions] + + +@pytest.mark.parametrize("sample", functional_tests_expected_exceptions, ids=test_names_expected_exceptions) +def test_presto_expected_exceptions(dialect_context, sample: FunctionalTestFileWithExpectedException): + validate_source_transpile, _ = dialect_context + with pytest.raises(type(sample.expected_exception)): + validate_source_transpile(databricks_sql=sample.databricks_sql, source={"presto": sample.source}) diff --git a/tests/unit/transpiler/test_snow.py b/tests/unit/transpiler/test_snow.py new file mode 100644 index 0000000000..a03c7f93a1 --- /dev/null +++ b/tests/unit/transpiler/test_snow.py @@ -0,0 +1,71 @@ +""" + Test Cases to validate source Snowflake dialect +""" + + +def test_parse_parameter(dialect_context): + """ + Function to assert conversion from source: `snowflake(read)` to target: `Databricks(sql)` + """ + validate_source_transpile, _ = dialect_context + sql = """ + SELECT DISTINCT + ABC, + CAST('${SCHEMA_NM}_MV' AS STRING), + sys_dt, + ins_ts, + upd_ts + FROM ${SCHEMA_NM}_MV.${TBL_NM} + WHERE + xyz IS NOT NULL AND src_ent = '${SCHEMA_NM}_MV.COL' + """ + + validate_source_transpile( + databricks_sql=sql, + source={ + "snowflake": """ + SELECT DISTINCT + ABC, + CAST('${SCHEMA_NM}_MV' AS VARCHAR(261)), + sys_dt, + ins_ts, + upd_ts + FROM ${SCHEMA_NM}_MV.${TBL_NM} WHERE xyz IS NOT NULL AND + src_ent = '${SCHEMA_NM}_MV.COL'; + """, + }, + pretty=True, + ) + + +def test_decimal_keyword(dialect_context): + """ + Function to test dec as alias name + """ + validate_source_transpile, _ = dialect_context + + sql = """ + SELECT + dec.id, + dec.key, + xy.value + FROM table AS dec + JOIN table2 AS xy + ON dec.id = xy.id + """ + + validate_source_transpile( + databricks_sql=sql, + source={ + "snowflake": """ + SELECT + dec.id, + dec.key, + xy.value + FROM table dec + JOIN table2 xy + ON dec.id = xy.id + """, + }, + pretty=True, + ) diff --git a/tests/unit/transpiler/test_sql_transpiler.py b/tests/unit/transpiler/test_sql_transpiler.py new file mode 100644 index 0000000000..56636c818e --- /dev/null +++ b/tests/unit/transpiler/test_sql_transpiler.py @@ -0,0 +1,91 @@ +import pytest +from sqlglot import expressions + +from databricks.labs.remorph.transpiler.sqlglot import local_expression +from databricks.labs.remorph.transpiler.sqlglot.sqlglot_engine import SqlglotEngine + + +@pytest.fixture +def transpiler(morph_config): + read_dialect = morph_config.get_read_dialect() + return SqlglotEngine(read_dialect) + + +@pytest.fixture +def write_dialect(morph_config): + return morph_config.get_write_dialect() + + +def test_transpile_snowflake(transpiler, write_dialect): + transpiler_result = transpiler.transpile(write_dialect, "SELECT CURRENT_TIMESTAMP(0)", "file.sql", []) + assert transpiler_result.transpiled_sql[0] == "SELECT\n CURRENT_TIMESTAMP()" + + +def test_transpile_exception(transpiler, write_dialect): + transpiler_result = transpiler.transpile( + write_dialect, "SELECT TRY_TO_NUMBER(COLUMN, $99.99, 27) FROM table", "file.sql", [] + ) + assert 
transpiler_result.transpiled_sql[0] == "" + assert transpiler_result.parse_error_list[0].file_name == "file.sql" + assert "Error Parsing args" in transpiler_result.parse_error_list[0].exception + + +def test_parse_query(transpiler): + parsed_query, _ = transpiler.parse("SELECT TRY_TO_NUMBER(COLUMN, $99.99, 27,2) FROM table", "file.sql") + + expected_result = [ + local_expression.TryToNumber( + this=expressions.Column(this=expressions.Identifier(this="COLUMN", quoted=False)), + expression=expressions.Parameter( + this=expressions.Literal(this=99, is_string=False), + suffix=expressions.Literal(this=0.99, is_string=False), + ), + precision=expressions.Literal(this=27, is_string=False), + scale=expressions.Literal(this=2, is_string=False), + ) + ] + + expected_from_result = expressions.From( + this=expressions.Table(this=expressions.Identifier(this="table", quoted=False)) + ) + + for exp in parsed_query: + if exp: + assert repr(exp.args["expressions"]) == repr(expected_result) + assert repr(exp.args["from"]) == repr(expected_from_result) + + +def test_parse_invalid_query(transpiler): + result, error_list = transpiler.parse("invalid sql query", "file.sql") + assert result is None + assert error_list.file_name == "file.sql" + assert "Invalid expression / Unexpected token." in error_list.exception + + +def test_tokenizer_exception(transpiler, write_dialect): + transpiler_result = transpiler.transpile(write_dialect, "1SELECT ~v\ud83d' ", "file.sql", []) + + assert transpiler_result.transpiled_sql == [""] + assert transpiler_result.parse_error_list[0].file_name == "file.sql" + assert "Error tokenizing" in transpiler_result.parse_error_list[0].exception + + +def test_procedure_conversion(transpiler, write_dialect): + procedure_sql = "CREATE OR REPLACE PROCEDURE my_procedure() AS BEGIN SELECT * FROM my_table; END;" + transpiler_result = transpiler.transpile(write_dialect, procedure_sql, "file.sql", []) + assert ( + transpiler_result.transpiled_sql[0] + == "CREATE OR REPLACE PROCEDURE my_procedure() AS BEGIN\nSELECT\n *\nFROM my_table" + ) + + +def test_find_root_tables(transpiler): + expression, _ = transpiler.parse("SELECT * FROM table_name", "test.sql") + # pylint: disable=protected-access + assert transpiler._find_root_tables(expression[0]) == "table_name" + + +def test_parse_sql_content(transpiler): + result = list(transpiler.parse_sql_content("SELECT * FROM table_name", "test.sql")) + assert result[0][0] == "table_name" + assert result[0][1] == "test.sql"