Skip to content

Commit

Permalink
feat(ingest): add offline flag to SQL parser CLI (datahub-project#11635)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored and sleeperdeep committed Dec 17, 2024
1 parent 1369198 commit 8042b7d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
35 changes: 29 additions & 6 deletions metadata-ingestion/src/datahub/cli/check_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,13 @@ def sql_format(sql: str, platform: str) -> None:
@click.option(
"--sql",
type=str,
required=True,
help="The SQL query to parse",
)
@click.option(
"--sql-file",
type=click.Path(exists=True, dir_okay=False, readable=True),
help="The SQL file to parse",
)
@click.option(
"--platform",
type=str,
Expand Down Expand Up @@ -218,25 +222,44 @@ def sql_format(sql: str, platform: str) -> None:
type=str,
help="The default schema to use for unqualified table names",
)
@click.option(
"--online/--offline",
type=bool,
is_flag=True,
default=True,
help="Run in offline mode and disable schema-aware parsing.",
)
@telemetry.with_telemetry()
def sql_lineage(
sql: str,
sql: Optional[str],
sql_file: Optional[str],
platform: str,
default_db: Optional[str],
default_schema: Optional[str],
platform_instance: Optional[str],
env: str,
online: bool,
) -> None:
"""Parse the lineage of a SQL query.
This performs schema-aware parsing in order to generate column-level lineage.
If the relevant tables are not in DataHub, this will be less accurate.
In online mode (the default), we perform schema-aware parsing in order to generate column-level lineage.
If offline mode is enabled or if the relevant tables are not in DataHub, this will be less accurate.
"""

graph = get_default_graph()
from datahub.sql_parsing.sqlglot_lineage import create_lineage_sql_parsed_result

if sql is None:
if sql_file is None:
raise click.UsageError("Either --sql or --sql-file must be provided")
sql = pathlib.Path(sql_file).read_text()

graph = None
if online:
graph = get_default_graph()

lineage = graph.parse_sql_lineage(
lineage = create_lineage_sql_parsed_result(
sql,
graph=graph,
platform=platform,
platform_instance=platform_instance,
env=env,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,9 @@ def _get_direct_raw_col_upstreams(
# Parse the column name out of the node name.
# Sqlglot calls .sql(), so we have to do the inverse.
normalized_col = sqlglot.parse_one(node.name).this.name
if node.subfield:
if hasattr(node, "subfield") and node.subfield:
# The hasattr check is necessary, since it lets us be compatible with
# sqlglot versions that don't have the subfield attribute.
normalized_col = f"{normalized_col}.{node.subfield}"

direct_raw_col_upstreams.add(
Expand Down

0 comments on commit 8042b7d

Please sign in to comment.