Skip to content

Commit

Permalink
bring unit tests up to 0.16.0 support
Browse files Browse the repository at this point in the history
Set the circleci config go to v2.1
  • Loading branch information
Jacob Beck committed Mar 30, 2020
1 parent 936aaf5 commit bb9b56b
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version: 2
version: 2.1

jobs:
unit:
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def hive_thrift_connect(host, port, username):

def test_parse_relation(self):
self.maxDiff = None
rel_type = SparkRelation.RelationType.Table
rel_type = SparkRelation.get_relation_type.Table

relation = SparkRelation.create(
database='default_database',
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_parse_relation(self):

def test_parse_relation_with_statistics(self):
self.maxDiff = None
rel_type = SparkRelation.RelationType.Table
rel_type = SparkRelation.get_relation_type.Table

relation = SparkRelation.create(
database='default_database',
Expand Down
93 changes: 84 additions & 9 deletions test/unit/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Unit test utility functions.
Note that all imports should be inside the functions to avoid import/mocking
issues.
"""
Expand All @@ -11,6 +12,7 @@

def normalize(path):
"""On windows, neither is enough on its own:
>>> normcase('C:\\documents/ALL CAPS/subdir\\..')
'c:\\documents\\all caps\\subdir\\..'
>>> normpath('C:\\documents/ALL CAPS/subdir\\..')
Expand All @@ -23,6 +25,7 @@ def normalize(path):

class Obj:
which = 'blah'
single_threaded = False


def mock_connection(name):
Expand All @@ -31,20 +34,63 @@ def mock_connection(name):
return conn


def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'):
from dbt.config import Project, Profile, RuntimeConfig
def profile_from_dict(profile, profile_name, cli_vars='{}'):
from dbt.config import Profile, ConfigRenderer
from dbt.context.base import generate_base_context
from dbt.utils import parse_cli_vars
from copy import deepcopy
if not isinstance(cli_vars, dict):
cli_vars = parse_cli_vars(cli_vars)
if not isinstance(project, Project):
project = Project.from_project_config(deepcopy(project), packages)

renderer = ConfigRenderer(generate_base_context(cli_vars))
return Profile.from_raw_profile_info(
profile,
profile_name,
renderer,
)


def project_from_dict(project, profile, packages=None, cli_vars='{}'):
from dbt.context.target import generate_target_context
from dbt.config import Project, ConfigRenderer
from dbt.utils import parse_cli_vars
if not isinstance(cli_vars, dict):
cli_vars = parse_cli_vars(cli_vars)

renderer = ConfigRenderer(generate_target_context(profile, cli_vars))

project_root = project.pop('project-root', os.getcwd())

return Project.render_from_dict(
project_root, project, packages, renderer
)


def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'):
from dbt.config import Project, Profile, RuntimeConfig
from copy import deepcopy

if isinstance(project, Project):
profile_name = project.profile_name
else:
profile_name = project.get('profile')

if not isinstance(profile, Profile):
profile = Profile.from_raw_profile_info(deepcopy(profile),
project.profile_name,
cli_vars)
profile = profile_from_dict(
deepcopy(profile),
profile_name,
cli_vars,
)

if not isinstance(project, Project):
project = project_from_dict(
deepcopy(project),
profile,
packages,
cli_vars,
)

args = Obj()
args.vars = repr(cli_vars)
args.vars = cli_vars
args.profile_dir = '/dev/null'
return RuntimeConfig.from_parts(
project=project,
Expand Down Expand Up @@ -88,3 +134,32 @@ def assert_fails_validation(self, dct, cls=None):

with self.assertRaises(ValidationError):
cls.from_dict(dct)


def generate_name_macros(package):
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.node_types import NodeType
name_sql = {}
for component in ('database', 'schema', 'alias'):
if component == 'alias':
source = 'node.name'
else:
source = f'target.{component}'
name = f'generate_{component}_name'
sql = f'{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}'
name_sql[name] = sql

all_sql = '\n'.join(name_sql.values())
for name, sql in name_sql.items():
pm = ParsedMacro(
name=name,
resource_type=NodeType.Macro,
unique_id=f'macro.{package}.{name}',
package_name=package,
original_file_path=normalize('macros/macro.sql'),
root_path='./dbt_modules/root',
path=normalize('macros/macro.sql'),
raw_sql=all_sql,
macro_sql=sql,
)
yield pm

0 comments on commit bb9b56b

Please sign in to comment.