Skip to content

Commit

Permalink
feat(mssql): add multi database ingest support (#5516)
Browse files Browse the repository at this point in the history
* feat(mssql): add multi database ingest support

* Delete older golden file.

* Update s3.md

* fix test setup
  • Loading branch information
MugdhaHardikar-GSLab authored Aug 16, 2022
1 parent dfd0d15 commit a449e8b
Show file tree
Hide file tree
Showing 10 changed files with 5,256 additions and 27 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/docs/sources/s3/s3.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ s3://my-bucket/*/*/{table}/{partition[0]}/{partition[1]}/{partition[2]}/*.* # ta
- s3://my-bucket/hr/**
- **/tests/*.csv
- s3://my-bucket/foo/*/my_table/**
-

### Notes

- {table} represents folder for which dataset will be created.
Expand Down
84 changes: 72 additions & 12 deletions metadata-ingestion/src/datahub/ingestion/source/sql/mssql.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import logging
import urllib.parse
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple

import pydantic

# This import verifies that the dependencies are available.
import sqlalchemy_pytds # noqa: F401
from pydantic.fields import Field
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import ResultProxy, RowProxy

from datahub.configuration.common import AllowDenyPattern
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
SourceCapability,
Expand All @@ -22,8 +25,11 @@
from datahub.ingestion.source.sql.sql_common import (
BasicSQLAlchemyConfig,
SQLAlchemySource,
make_sqlalchemy_uri,
)

logger: logging.Logger = logging.getLogger(__name__)


class SQLServerConfig(BasicSQLAlchemyConfig):
# defaults
Expand All @@ -37,6 +43,19 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
default={},
desscription="Arguments to URL-encode when connecting. See https://docs.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver15.",
)
database_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="Regex patterns for databases to filter in ingestion.",
)
database: Optional[str] = Field(
default=None,
description="database (catalog). If set to Null, all databases will be considered for ingestion.",
)

database_alias: Optional[str] = Field(
default=None,
description="Alias to apply to database when ingesting. Ignored when `database` is not set.",
)

@pydantic.validator("uri_args")
def passwords_match(cls, v, values, **kwargs):
Expand All @@ -46,26 +65,29 @@ def passwords_match(cls, v, values, **kwargs):
raise ValueError("uri_args is not supported when ODBC is disabled")
return v

def get_sql_alchemy_url(self, uri_opts: Optional[Dict[str, Any]] = None) -> str:
def get_sql_alchemy_url(
self,
uri_opts: Optional[Dict[str, Any]] = None,
current_db: Optional[str] = None,
) -> str:
if self.use_odbc:
# Ensure that the import is available.
import pyodbc # noqa: F401

self.scheme = "mssql+pyodbc"

uri: str = super().get_sql_alchemy_url(uri_opts=uri_opts)
uri: str = self.sqlalchemy_uri or make_sqlalchemy_uri(
self.scheme, # type: ignore
self.username,
self.password.get_secret_value() if self.password else None,
self.host_port, # type: ignore
current_db if current_db else self.database,
uri_opts=uri_opts,
)
if self.use_odbc:
uri = f"{uri}?{urllib.parse.urlencode(self.uri_args)}"
return uri

def get_identifier(self, schema: str, table: str) -> str:
    """Return the dotted dataset identifier for ``table`` in ``schema``.

    The ``schema.table`` name is prefixed with ``database_alias`` when one
    is configured, otherwise with ``database`` when that is configured, and
    is returned without any prefix when neither is set.
    """
    # The alias takes precedence over the real database name.
    prefix = self.database_alias or self.database
    qualified = f"{schema}.{table}"
    return f"{prefix}.{qualified}" if prefix else qualified


@platform_name("Microsoft SQL Server", id="mssql")
@config_class(SQLServerConfig)
Expand Down Expand Up @@ -93,8 +115,9 @@ class SQLServerSource(SQLAlchemySource):

def __init__(self, config: SQLServerConfig, ctx: PipelineContext):
super().__init__(config, ctx, "mssql")

# Cache the table and column descriptions
self.config: SQLServerConfig = config
self.current_database = None
self.table_descriptions: Dict[str, str] = {}
self.column_descriptions: Dict[str, str] = {}
for inspector in self.get_inspectors():
Expand Down Expand Up @@ -183,3 +206,40 @@ def _get_columns(
if description:
column["comment"] = description
return columns

def get_inspectors(self) -> Iterable[Inspector]:
    """Yield one SQLAlchemy ``Inspector`` per database to ingest.

    When ``config.database`` is set, a single inspector bound to that
    database's connection is yielded. Otherwise the database names are read
    from ``master.sys.databases`` (the query itself excludes a fixed list of
    system databases), filtered through ``config.database_pattern``, and a
    dedicated connection is opened for every allowed database.
    ``self.current_database`` is set to the database name just before each
    inspector of the multi-database path is yielded.

    Every inspector is only usable while the generator is suspended on its
    ``yield``: the backing connection is closed as soon as iteration
    resumes, so consume the inspectors one at a time instead of collecting
    them into a list first.
    """
    # This method can be overridden in the case that you want to dynamically
    # run on multiple databases.
    url = self.config.get_sql_alchemy_url()
    logger.debug(f"sql_alchemy_url={url}")
    engine = create_engine(url, **self.config.options)
    with engine.connect() as conn:
        # An empty string is falsy, so this also covers `database == ""`.
        if self.config.database:
            yield inspect(conn)
            return

        databases = conn.execute(
            "SELECT name FROM master.sys.databases WHERE name NOT IN "
            "('master', 'model', 'msdb', 'tempdb', 'Resource', "
            "'distribution' , 'reportserver', 'reportservertempdb'); "
        )
        for db in databases:
            db_name = db["name"]
            if not self.config.database_pattern.allowed(db_name):
                continue
            db_url = self.config.get_sql_alchemy_url(current_db=db_name)
            db_engine = create_engine(db_url, **self.config.options)
            try:
                # Scope the per-database connection so it is released once
                # the consumer is done with this inspector; previously it
                # was opened and never closed, leaking one per database.
                with db_engine.connect() as db_conn:
                    inspector = inspect(db_conn)
                    self.current_database = db_name
                    yield inspector
            finally:
                # Closing only returns the connection to the engine's pool;
                # dispose the throwaway engine to close it for real.
                db_engine.dispose()

def get_identifier(
    self, *, schema: str, entity: str, inspector: Inspector, **kwargs: Any
) -> str:
    """Return the dotted dataset identifier for ``entity`` in ``schema``.

    Prefix selection:
    * ``config.database`` set: use ``config.database_alias`` when present,
      otherwise ``config.database``.
    * ``config.database`` unset: use ``self.current_database`` when it is
      set; the alias is ignored on this path.
    * nothing available: return plain ``schema.entity``.

    ``inspector`` and ``kwargs`` are accepted but not used here.
    """
    config = self.config
    if config.database:
        prefix = config.database_alias or config.database
    else:
        prefix = self.current_database

    qualified = f"{schema}.{entity}"
    return f"{prefix}.{qualified}" if prefix else qualified
Loading

0 comments on commit a449e8b

Please sign in to comment.