diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index e7b0fcea239..f0722e40d7b 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -28,7 +28,7 @@ class BigQueryAdapter(PostgresAdapter): - context_functions = [ + config_functions = [ # deprecated -- use versions that take relations instead "query_for_existing", "execute_model", @@ -51,11 +51,11 @@ class BigQueryAdapter(PostgresAdapter): "drop_relation", "rename_relation", - "get_columns_in_table" - ] + "get_columns_in_table", - Relation = BigQueryRelation - Column = dbt.schema.BigQueryColumn + # formerly profile functions + "add_query", + ] SCOPE = ('https://www.googleapis.com/auth/bigquery', 'https://www.googleapis.com/auth/cloud-platform', @@ -68,6 +68,8 @@ class BigQueryAdapter(PostgresAdapter): } QUERY_TIMEOUT = 300 + Relation = BigQueryRelation + Column = dbt.schema.BigQueryColumn @classmethod def handle_error(cls, error, message, sql): @@ -78,20 +80,19 @@ def handle_error(cls, error, message, sql): raise dbt.exceptions.DatabaseException(error_msg) - @classmethod @contextmanager - def exception_handler(cls, config, sql, model_name=None, + def exception_handler(self, sql, model_name=None, connection_name='master'): try: yield except google.cloud.exceptions.BadRequest as e: message = "Bad request while running:\n{sql}" - cls.handle_error(e, message, sql) + self.handle_error(e, message, sql) except google.cloud.exceptions.Forbidden as e: message = "Access denied while running:\n{sql}" - cls.handle_error(e, message, sql) + self.handle_error(e, message, sql) except Exception as e: logger.debug("Unhandled error while running:\n{}".format(sql)) @@ -106,12 +107,10 @@ def type(cls): def date_function(cls): return 'CURRENT_TIMESTAMP()' - @classmethod - def begin(cls, config, name='master'): + def begin(self, name): pass - @classmethod - def commit(cls, config, connection): + def commit(self, connection): pass @classmethod @@ -181,12 +180,11 @@ def close(cls, connection): return connection - @classmethod - def list_relations(cls, config, schema, model_name=None): - connection = cls.get_connection(config, model_name) + def list_relations(self, schema, model_name=None): + connection = self.get_connection(model_name) client = connection.handle - bigquery_dataset = cls.get_dataset(config, schema, model_name) + bigquery_dataset = self.get_dataset(schema, model_name) all_tables = client.list_tables( bigquery_dataset, @@ -203,12 +201,11 @@ def list_relations(cls, config, schema, model_name=None): # This will 404 if the dataset does not exist. 
This behavior mirrors # the implementation of list_relations for other adapters try: - return [cls.bq_table_to_relation(table) for table in all_tables] + return [self.bq_table_to_relation(table) for table in all_tables] except google.api_core.exceptions.NotFound as e: return [] - @classmethod - def get_relation(cls, config, schema=None, identifier=None, + def get_relation(self, schema=None, identifier=None, relations_list=None, model_name=None): if schema is None and relations_list is None: raise dbt.exceptions.RuntimeException( @@ -216,32 +213,27 @@ def get_relation(cls, config, schema=None, identifier=None, 'of relations to use') if relations_list is None and identifier is not None: - table = cls.get_bq_table(config, schema, identifier) + table = self.get_bq_table(schema, identifier) - return cls.bq_table_to_relation(table) + return self.bq_table_to_relation(table) - return super(BigQueryAdapter, cls).get_relation( - config, schema, identifier, relations_list, + return super(BigQueryAdapter, self).get_relation( + schema, identifier, relations_list, model_name) - @classmethod - def drop_relation(cls, config, relation, model_name=None): - conn = cls.get_connection(config, model_name) + def drop_relation(self, relation, model_name=None): + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, relation.schema, model_name) + dataset = self.get_dataset(relation.schema, model_name) relation_object = dataset.table(relation.identifier) client.delete_table(relation_object) - @classmethod - def rename(cls, config, schema, - from_name, to_name, model_name=None): + def rename(self, schema, from_name, to_name, model_name=None): raise dbt.exceptions.NotImplementedException( '`rename` is not implemented for this adapter!') - @classmethod - def rename_relation(cls, config, from_relation, to_relation, - model_name=None): + def rename_relation(self, from_relation, to_relation, model_name=None): raise dbt.exceptions.NotImplementedException( '`rename_relation` is not implemented for this adapter!') @@ -250,13 +242,12 @@ def get_timeout(cls, conn): credentials = conn['credentials'] return credentials.get('timeout_seconds', cls.QUERY_TIMEOUT) - @classmethod - def materialize_as_view(cls, config, dataset, model): + def materialize_as_view(self, dataset, model): model_name = model.get('name') model_alias = model.get('alias') model_sql = model.get('injected_sql') - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle view_ref = dataset.table(model_alias) @@ -266,7 +257,7 @@ def materialize_as_view(cls, config, dataset, model): logger.debug("Model SQL ({}):\n{}".format(model_name, model_sql)) - with cls.exception_handler(config, model_sql, model_name, model_name): + with self.exception_handler(model_sql, model_name, model_name): client.create_table(view) return "CREATE VIEW" @@ -286,26 +277,24 @@ def poll_until_job_completes(cls, job, timeout): elif job.error_result: raise job.exception() - @classmethod - def make_date_partitioned_table(cls, config, dataset_name, identifier, + def make_date_partitioned_table(self, dataset_name, identifier, model_name=None): - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, dataset_name, identifier) + dataset = self.get_dataset(dataset_name, identifier) table_ref = dataset.table(identifier) table = google.cloud.bigquery.Table(table_ref) table.partitioning_type = 'DAY' return 
client.create_table(table) - @classmethod - def materialize_as_table(cls, config, dataset, model, model_sql, + def materialize_as_table(self, dataset, model, model_sql, decorator=None): model_name = model.get('name') model_alias = model.get('alias') - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle if decorator is None: @@ -322,14 +311,13 @@ def materialize_as_table(cls, config, dataset, model, model_sql, query_job = client.query(model_sql, job_config=job_config) # this waits for the job to complete - with cls.exception_handler(config, model_sql, model_alias, - model_name): - query_job.result(timeout=cls.get_timeout(conn)) + with self.exception_handler(model_sql, model_alias, + model_name): + query_job.result(timeout=self.get_timeout(conn)) return "CREATE TABLE" - @classmethod - def execute_model(cls, config, model, + def execute_model(self, model, materialization, sql_override=None, decorator=None, model_name=None): @@ -337,20 +325,19 @@ def execute_model(cls, config, model, sql_override = model.get('injected_sql') if flags.STRICT_MODE: - connection = cls.get_connection(config, model.get('name')) + connection = self.get_connection(model.get('name')) Connection(**connection) model_name = model.get('name') model_schema = model.get('schema') - dataset = cls.get_dataset(config, - model_schema, model_name) + dataset = self.get_dataset(model_schema, model_name) if materialization == 'view': - res = cls.materialize_as_view(config, dataset, model) + res = self.materialize_as_view(dataset, model) elif materialization == 'table': - res = cls.materialize_as_table( - config, dataset, model, + res = self.materialize_as_table( + dataset, model, sql_override, decorator) else: msg = "Invalid relation type: '{}'".format(materialization) @@ -358,9 +345,8 @@ def execute_model(cls, config, model, return res - @classmethod - def raw_execute(cls, config, sql, model_name=None, fetch=False, **kwargs): - conn = cls.get_connection(config, model_name) + def raw_execute(self, sql, model_name=None, fetch=False, **kwargs): + conn = self.get_connection(model_name) client = conn.handle logger.debug('On %s: %s', model_name, sql) @@ -370,20 +356,18 @@ def raw_execute(cls, config, sql, model_name=None, fetch=False, **kwargs): query_job = client.query(sql, job_config) # this blocks until the query has completed - with cls.exception_handler(config, sql, model_name): + with self.exception_handler(sql, model_name): iterator = query_job.result() return query_job, iterator - @classmethod - def create_temporary_table(cls, config, sql, model_name=None, - **kwargs): + def create_temporary_table(self, sql, model_name=None, **kwargs): # BQ queries always return a temp table with their results - query_job, _ = cls.raw_execute(config, sql, model_name) + query_job, _ = self.raw_execute(sql, model_name) bq_table = query_job.destination - return cls.Relation.create( + return self.Relation.create( project=bq_table.project, schema=bq_table.dataset_id, identifier=bq_table.table_id, @@ -393,17 +377,15 @@ def create_temporary_table(cls, config, sql, model_name=None, }, type=BigQueryRelation.Table) - @classmethod - def alter_table_add_columns(cls, config, relation, columns, - model_name=None): + def alter_table_add_columns(self, relation, columns, model_name=None): logger.debug('Adding columns ({}) to table {}".'.format( columns, relation)) - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, 
relation.schema, model_name) + dataset = self.get_dataset(relation.schema, model_name) table_ref = dataset.table(relation.name) table = client.get_table(table_ref) @@ -414,13 +396,11 @@ def alter_table_add_columns(cls, config, relation, columns, new_table = google.cloud.bigquery.Table(table_ref, schema=new_schema) client.update_table(new_table, ['schema']) - @classmethod - def execute(cls, config, sql, model_name=None, fetch=None, **kwargs): - _, iterator = cls.raw_execute(config, sql, model_name, fetch, - **kwargs) + def execute(self, sql, model_name=None, fetch=None, **kwargs): + _, iterator = self.raw_execute(sql, model_name, fetch, **kwargs) if fetch: - res = cls.get_table_from_response(iterator) + res = self.get_table_from_response(iterator) else: res = dbt.clients.agate_helper.empty_table() @@ -428,9 +408,8 @@ def execute(cls, config, sql, model_name=None, fetch=None, **kwargs): status = 'OK' return status, res - @classmethod - def execute_and_fetch(cls, config, sql, model_name, auto_begin=None): - status, table = cls.execute(config, sql, model_name, fetch=True) + def execute_and_fetch(self, sql, model_name, auto_begin=None): + status, table = self.execute(sql, model_name, fetch=True) return status, table @classmethod @@ -441,118 +420,106 @@ def get_table_from_response(cls, resp): # BigQuery doesn't support BEGIN/COMMIT, so stub these out. - @classmethod - def add_begin_query(cls, config, name): + def add_begin_query(self, name): pass - @classmethod - def add_commit_query(cls, config, name): + def add_commit_query(self, name): pass - @classmethod - def create_schema(cls, config, schema, model_name=None): + def create_schema(self, schema, model_name=None): logger.debug('Creating schema "%s".', schema) - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, schema, model_name) + dataset = self.get_dataset(schema, model_name) # Emulate 'create schema if not exists ...' 
try: client.get_dataset(dataset) except google.api_core.exceptions.NotFound: - with cls.exception_handler(config, 'create dataset', model_name): + with self.exception_handler('create dataset', model_name): client.create_dataset(dataset) - @classmethod - def drop_tables_in_schema(cls, config, dataset): - conn = cls.get_connection(config) + def drop_tables_in_schema(self, dataset): + conn = self.get_connection() client = conn.handle for table in client.list_tables(dataset): client.delete_table(table.reference) - @classmethod - def drop_schema(cls, config, schema, model_name=None): + def drop_schema(self, schema, model_name=None): logger.debug('Dropping schema "%s".', schema) - if not cls.check_schema_exists(config, - schema, model_name): + if not self.check_schema_exists(schema, model_name): return - conn = cls.get_connection(config) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, schema, model_name) - with cls.exception_handler(config, 'drop dataset', model_name): - cls.drop_tables_in_schema(config, dataset) + dataset = self.get_dataset(schema, model_name) + with self.exception_handler('drop dataset', model_name): + self.drop_tables_in_schema(dataset) client.delete_dataset(dataset) - @classmethod - def get_existing_schemas(cls, config, model_name=None): - conn = cls.get_connection(config, model_name) + def get_existing_schemas(self, model_name=None): + conn = self.get_connection(model_name) client = conn.handle - with cls.exception_handler(config, 'list dataset', model_name): + with self.exception_handler('list dataset', model_name): all_datasets = client.list_datasets(include_all=True) return [ds.dataset_id for ds in all_datasets] - @classmethod - def get_columns_in_table(cls, config, schema_name, table_name, + def get_columns_in_table(self, schema_name, table_name, database=None, model_name=None): # BigQuery does not have databases -- the database parameter is here # for consistency with the base implementation - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle try: dataset_ref = client.dataset(schema_name) table_ref = dataset_ref.table(table_name) table = client.get_table(table_ref) - return cls.get_dbt_columns_from_bq_table(table) + return self.get_dbt_columns_from_bq_table(table) except (ValueError, google.cloud.exceptions.NotFound) as e: logger.debug("get_columns_in_table error: {}".format(e)) return [] - @classmethod - def get_dbt_columns_from_bq_table(cls, table): + def get_dbt_columns_from_bq_table(self, table): "Translates BQ SchemaField dicts into dbt BigQueryColumn objects" columns = [] for col in table.schema: # BigQuery returns type labels that are not valid type specifiers - dtype = cls.Column.translate_type(col.field_type) - column = cls.Column( + dtype = self.Column.translate_type(col.field_type) + column = self.Column( col.name, dtype, col.fields, col.mode) columns.append(column) return columns - @classmethod - def check_schema_exists(cls, config, schema, model_name=None): - conn = cls.get_connection(config, model_name) + def check_schema_exists(self, schema, model_name=None): + conn = self.get_connection(model_name) client = conn.handle - with cls.exception_handler(config, 'get dataset', model_name): + with self.exception_handler('get dataset', model_name): all_datasets = client.list_datasets(include_all=True) return any([ds.dataset_id == schema for ds in all_datasets]) - @classmethod - def get_dataset(cls, config, dataset_name, model_name=None): - conn = 
cls.get_connection(config, model_name) + def get_dataset(self, dataset_name, model_name=None): + conn = self.get_connection(model_name) dataset_ref = conn.handle.dataset(dataset_name) return google.cloud.bigquery.Dataset(dataset_ref) - @classmethod - def bq_table_to_relation(cls, bq_table): + def bq_table_to_relation(self, bq_table): if bq_table is None: return None - return cls.Relation.create( + return self.Relation.create( project=bq_table.project, schema=bq_table.dataset_id, identifier=bq_table.table_id, @@ -560,13 +527,12 @@ def bq_table_to_relation(cls, bq_table): 'schema': True, 'identifier': True }, - type=cls.RELATION_TYPES.get(bq_table.table_type)) + type=self.RELATION_TYPES.get(bq_table.table_type)) - @classmethod - def get_bq_table(cls, config, dataset_name, identifier, model_name=None): - conn = cls.get_connection(config, model_name) + def get_bq_table(self, dataset_name, identifier, model_name=None): + conn = self.get_connection(model_name) - dataset = cls.get_dataset(config, dataset_name, model_name) + dataset = self.get_dataset(dataset_name, model_name) table_ref = dataset.table(identifier) @@ -576,16 +542,15 @@ def get_bq_table(cls, config, dataset_name, identifier, model_name=None): return None @classmethod - def warning_on_hooks(cls, hook_type): + def warning_on_hooks(hook_type): msg = "{} is not supported in bigquery and will be ignored" dbt.ui.printer.print_timestamped_line(msg.format(hook_type), dbt.ui.printer.COLOR_FG_YELLOW) - @classmethod - def add_query(cls, config, sql, model_name=None, auto_begin=True, + def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): if model_name in ['on-run-start', 'on-run-end']: - cls.warning_on_hooks(model_name) + self.warning_on_hooks(model_name) else: raise dbt.exceptions.NotImplementedException( '`add_query` is not implemented for this adapter!') @@ -598,16 +563,13 @@ def is_cancelable(cls): def quote(cls, identifier): return '`{}`'.format(identifier) - @classmethod - def quote_schema_and_table(cls, config, schema, - table, model_name=None): - return cls.render_relation(config, cls.quote(schema), cls.quote(table)) + def quote_schema_and_table(self, schema, table, model_name=None): + return self.render_relation(self.quote(schema), self.quote(table)) - @classmethod - def render_relation(cls, config, schema, table): - connection = cls.get_connection(config) + def render_relation(cls, schema, table): + connection = self.get_connection() project = connection.credentials.project - return '{}.{}.{}'.format(cls.quote(project), schema, table) + return '{}.{}.{}'.format(self.quote(project), schema, table) @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -636,13 +598,12 @@ def _agate_to_schema(cls, agate_table, column_override): google.cloud.bigquery.SchemaField(col_name, type_)) return bq_schema - @classmethod - def load_dataframe(cls, config, schema, table_name, agate_table, + def load_dataframe(self, schema, table_name, agate_table, column_override, model_name=None): - bq_schema = cls._agate_to_schema(agate_table, column_override) - dataset = cls.get_dataset(config, schema, None) + bq_schema = self._agate_to_schema(agate_table, column_override) + dataset = self.get_dataset(schema, None) table = dataset.table(table_name) - conn = cls.get_connection(config, None) + conn = self.get_connection(None) client = conn.handle load_config = google.cloud.bigquery.LoadJobConfig() @@ -653,21 +614,19 @@ def load_dataframe(cls, config, schema, table_name, agate_table, job = 
client.load_table_from_file(f, table, rewind=True, job_config=load_config) - with cls.exception_handler(config, "LOAD TABLE"): - cls.poll_until_job_completes(job, cls.get_timeout(conn)) + with self.exception_handler("LOAD TABLE"): + self.poll_until_job_completes(job, self.get_timeout(conn)) - @classmethod - def expand_target_column_types(cls, config, temp_table, + def expand_target_column_types(self, temp_table, to_schema, to_table, model_name=None): # This is a no-op on BigQuery pass - @classmethod - def _flat_columns_in_table(cls, table): + def _flat_columns_in_table(self, table): """An iterator over the flattened columns for a given schema and table. Resolves child columns as having the name "parent.child". """ - for col in cls.get_dbt_columns_from_bq_table(table): + for col in self.get_dbt_columns_from_bq_table(table): flattened = col.flatten() for subcol in flattened: yield subcol @@ -727,9 +686,8 @@ def _get_stats_columns(cls, table, relation_type): ) return zip(column_names, column_values) - @classmethod - def get_catalog(cls, config, manifest): - connection = cls.get_connection(config, 'catalog') + def get_catalog(self, manifest): + connection = self.get_connection('catalog') client = connection.handle schemas = { @@ -749,11 +707,11 @@ def get_catalog(cls, config, manifest): 'column_type', 'column_comment', ) - all_names = column_names + cls._get_stats_column_names() + all_names = column_names + self._get_stats_column_names() columns = [] for schema_name in schemas: - relations = cls.list_relations(config, schema_name) + relations = self.list_relations(schema_name) for relation in relations: # This relation contains a subset of the info we care about. @@ -762,9 +720,9 @@ def get_catalog(cls, config, manifest): table_ref = dataset_ref.table(relation.identifier) table = client.get_table(table_ref) - flattened = cls._flat_columns_in_table(table) - relation_stats = dict(cls._get_stats_columns(table, - relation.type)) + flattened = self._flat_columns_in_table(table) + relation_stats = dict(self._get_stats_columns(table, + relation.type)) for index, column in enumerate(flattened, start=1): column_data = ( diff --git a/dbt/adapters/bigquery/relation.py b/dbt/adapters/bigquery/relation.py index f5807dc8e3e..d962d493d72 100644 --- a/dbt/adapters/bigquery/relation.py +++ b/dbt/adapters/bigquery/relation.py @@ -86,7 +86,7 @@ def matches(self, project=None, schema=None, identifier=None): return True @classmethod - def create_from_node(cls, config, node, **kwargs): + def _create_from_node(cls, config, node, **kwargs): return cls.create( project=config.credentials.project, schema=node.get('schema'), diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index 5fdf83667f5..8fe29e4a684 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -52,7 +52,7 @@ def test(row): class DefaultAdapter(object): requires = {} - context_functions = [ + config_functions = [ "get_columns_in_table", "get_missing_columns", "expand_target_column_types", @@ -75,9 +75,8 @@ class DefaultAdapter(object): "drop_relation", "rename_relation", "truncate_relation", - ] - profile_functions = [ + # formerly profile functions "execute", "add_query", ] @@ -88,17 +87,18 @@ class DefaultAdapter(object): "quote", "convert_type" ] - Relation = DefaultRelation Column = Column + def __init__(self, config): + self.config = config + ### # ADAPTER-SPECIFIC FUNCTIONS -- each of these must be overridden in # every adapter ### - @classmethod @contextmanager - def exception_handler(cls, config, sql, 
model_name=None, + def exception_handler(self, sql, model_name=None, connection_name=None): raise dbt.exceptions.NotImplementedException( '`exception_handler` is not implemented for this adapter!') @@ -118,15 +118,12 @@ def get_status(cls, cursor): raise dbt.exceptions.NotImplementedException( '`get_status` is not implemented for this adapter!') - @classmethod - def alter_column_type(cls, config, schema, table, - column_name, new_column_type, model_name=None): + def alter_column_type(self, schema, table, column_name, new_column_type, + model_name=None): raise dbt.exceptions.NotImplementedException( '`alter_column_type` is not implemented for this adapter!') - @classmethod - def query_for_existing(cls, config, schemas, - model_name=None): + def query_for_existing(self, schemas, model_name=None): if not isinstance(schemas, (list, tuple)): schemas = [schemas] @@ -134,23 +131,20 @@ def query_for_existing(cls, config, schemas, for schema in schemas: all_relations.extend( - cls.list_relations(config, schema, model_name)) + self.list_relations(schema, model_name)) return {relation.identifier: relation.type for relation in all_relations} - @classmethod - def get_existing_schemas(cls, config, model_name=None): + def get_existing_schemas(self, model_name=None): raise dbt.exceptions.NotImplementedException( '`get_existing_schemas` is not implemented for this adapter!') - @classmethod - def check_schema_exists(cls, config, schema): + def check_schema_exists(self, schema): raise dbt.exceptions.NotImplementedException( '`check_schema_exists` is not implemented for this adapter!') - @classmethod - def cancel_connection(cls, config, connection): + def cancel_connection(self, connection): raise dbt.exceptions.NotImplementedException( '`cancel_connection` is not implemented for this adapter!') @@ -170,19 +164,17 @@ def get_result_from_cursor(cls, cursor): return dbt.clients.agate_helper.table_from_data(data, column_names) - @classmethod - def drop(cls, config, schema, relation, relation_type, model_name=None): + def drop(self, schema, relation, relation_type, model_name=None): identifier = relation - relation = cls.Relation.create( + relation = self.Relation.create( schema=schema, identifier=identifier, type=relation_type, - quote_policy=config.quoting) + quote_policy=self.config.quoting) - return cls.drop_relation(config, relation, model_name) + return self.drop_relation(relation, model_name) - @classmethod - def drop_relation(cls, config, relation, model_name=None): + def drop_relation(self, relation, model_name=None): if relation.type is None: dbt.exceptions.raise_compiler_error( 'Tried to drop relation {}, but its type is null.' 
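
The change running through all of these hunks is the same one: the adapter methods stop being classmethods, the explicit config parameter disappears from their signatures, and settings are read from self.config, which the adapter now receives in its constructor. A minimal, self-contained sketch of that before/after shape — ExampleConfig, OldStyleAdapter, and NewStyleAdapter are illustrative stand-ins, not the real dbt classes:

    from collections import namedtuple

    # Illustrative stand-ins for dbt's RuntimeConfig and credentials objects.
    Credentials = namedtuple('Credentials', ['schema', 'type'])

    class ExampleConfig(object):
        def __init__(self, schema, quoting):
            self.credentials = Credentials(schema=schema, type='example')
            self.quoting = quoting

    # Before this change: stateless classmethods, with `config` threaded through
    # every call site.
    class OldStyleAdapter(object):
        @classmethod
        def get_default_schema(cls, config):
            return config.credentials.schema

    # After this change: one adapter instance per config; methods read `self.config`.
    class NewStyleAdapter(object):
        def __init__(self, config):
            self.config = config

        def get_default_schema(self):
            return self.config.credentials.schema

    config = ExampleConfig(schema='analytics', quoting={'identifier': True})
    adapter = NewStyleAdapter(config)
    assert adapter.get_default_schema() == OldStyleAdapter.get_default_schema(config)
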
@@ -190,72 +182,65 @@ def drop_relation(cls, config, relation, model_name=None): sql = 'drop {} if exists {} cascade'.format(relation.type, relation) - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) - @classmethod - def truncate(cls, config, schema, table, model_name=None): - relation = cls.Relation.create( + def truncate(self, schema, table, model_name=None): + relation = self.Relation.create( schema=schema, identifier=table, type='table', - quote_policy=config.quoting) + quote_policy=self.config.quoting) - return cls.truncate_relation(config, relation, model_name) + return self.truncate_relation(relation, model_name) - @classmethod - def truncate_relation(cls, config, - relation, model_name=None): + def truncate_relation(self, relation, model_name=None): sql = 'truncate table {}'.format(relation) - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) - @classmethod - def rename(cls, config, schema, - from_name, to_name, model_name=None): - quote_policy = config.quoting - from_relation = cls.Relation.create( + def rename(self, schema, from_name, to_name, model_name=None): + quote_policy = self.config.quoting + from_relation = self.Relation.create( schema=schema, identifier=from_name, quote_policy=quote_policy ) - to_relation = cls.Relation.create( + to_relation = self.Relation.create( identifier=to_name, quote_policy=quote_policy ) - return cls.rename_relation( - config, + return self.rename_relation( from_relation=from_relation, to_relation=to_relation, model_name=model_name) - @classmethod - def rename_relation(cls, config, from_relation, to_relation, + def rename_relation(self, from_relation, to_relation, model_name=None): sql = 'alter table {} rename to {}'.format( from_relation, to_relation.include(schema=False)) - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) @classmethod def is_cancelable(cls): return True - @classmethod - def get_missing_columns(cls, config, - from_schema, from_table, - to_schema, to_table, - model_name=None): + def get_missing_columns(self, from_schema, from_table, + to_schema, to_table, model_name=None): """Returns dict of {column:type} for columns in from_table that are missing from to_table""" - from_columns = {col.name: col for col in - cls.get_columns_in_table( - config, from_schema, from_table, - model_name=model_name)} - to_columns = {col.name: col for col in - cls.get_columns_in_table( - config, to_schema, to_table, - model_name=model_name)} + from_columns = { + col.name: col for col in + self.get_columns_in_table( + from_schema, from_table, + model_name=model_name) + } + to_columns = { + col.name: col for col in + self.get_columns_in_table( + to_schema, to_table, + model_name=model_name) + } missing_columns = set(from_columns.keys()) - set(to_columns.keys()) @@ -287,18 +272,17 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): return sql - @classmethod - def get_columns_in_table(cls, config, schema_name, + def get_columns_in_table(self, schema_name, table_name, database=None, model_name=None): - sql = cls._get_columns_in_table_sql(schema_name, table_name, database) - connection, cursor = cls.add_query(config, sql, model_name) + sql = self._get_columns_in_table_sql(schema_name, table_name, database) + connection, cursor = self.add_query(sql, model_name) data = cursor.fetchall() columns = [] for row 
in data: name, data_type, char_size, numeric_size = row - column = cls.Column(name, data_type, char_size, numeric_size) + column = self.Column(name, data_type, char_size, numeric_size) columns.append(column) return columns @@ -307,20 +291,19 @@ def get_columns_in_table(cls, config, schema_name, def _table_columns_to_dict(cls, columns): return {col.name: col for col in columns} - @classmethod - def expand_target_column_types(cls, config, + def expand_target_column_types(self, temp_table, to_schema, to_table, model_name=None): - reference_columns = cls._table_columns_to_dict( - cls.get_columns_in_table( - config, None, temp_table, model_name=model_name)) + reference_columns = self._table_columns_to_dict( + self.get_columns_in_table(None, temp_table, model_name=model_name) + ) - target_columns = cls._table_columns_to_dict( - cls.get_columns_in_table( - config, to_schema, to_table, - model_name=model_name)) + target_columns = self._table_columns_to_dict( + self.get_columns_in_table(to_schema, to_table, + model_name=model_name) + ) for column_name, reference_column in reference_columns.items(): target_column = target_columns.get(column_name) @@ -328,38 +311,35 @@ def expand_target_column_types(cls, config, if target_column is not None and \ target_column.can_expand_to(reference_column): col_string_size = reference_column.string_size() - new_type = cls.Column.string_type(col_string_size) + new_type = self.Column.string_type(col_string_size) logger.debug("Changing col type from %s to %s in table %s.%s", target_column.data_type, new_type, to_schema, to_table) - cls.alter_column_type(config, to_schema, - to_table, column_name, new_type, - model_name) + self.alter_column_type(to_schema, to_table, column_name, + new_type, model_name) ### # RELATIONS ### - @classmethod - def list_relations(cls, config, schema, model_name=None): + def list_relations(self, schema, model_name=None): raise dbt.exceptions.NotImplementedException( '`list_relations` is not implemented for this adapter!') - @classmethod - def _make_match_kwargs(cls, config, schema, identifier): - if identifier is not None and config.quoting['identifier'] is False: + def _make_match_kwargs(self, schema, identifier): + quoting = self.config.quoting + if identifier is not None and quoting['identifier'] is False: identifier = identifier.lower() - if schema is not None and config.quoting['schema'] is False: + if schema is not None and quoting['schema'] is False: schema = schema.lower() return filter_null_values({'identifier': identifier, 'schema': schema}) - @classmethod - def get_relation(cls, config, schema=None, identifier=None, + def get_relation(self, schema=None, identifier=None, relations_list=None, model_name=None): if schema is None and relations_list is None: raise dbt.exceptions.RuntimeException( @@ -367,11 +347,11 @@ def get_relation(cls, config, schema=None, identifier=None, 'of relations to use') if relations_list is None: - relations_list = cls.list_relations(config, schema, model_name) + relations_list = self.list_relations(schema, model_name) matches = [] - search = cls._make_match_kwargs(config, schema, identifier) + search = self._make_match_kwargs(schema, identifier) for relation in relations_list: if relation.matches(**search): @@ -389,16 +369,14 @@ def get_relation(cls, config, schema=None, identifier=None, ### # SANE ANSI SQL DEFAULTS ### - @classmethod - def get_create_schema_sql(cls, config, schema): - schema = cls._quote_as_configured(config, schema, 'schema') + def get_create_schema_sql(self, schema): + schema = 
self.quote_as_configured(schema, 'schema') return ('create schema if not exists {schema}' .format(schema=schema)) - @classmethod - def get_drop_schema_sql(cls, config, schema): - schema = cls._quote_as_configured(config, schema, 'schema') + def get_drop_schema_sql(self, schema): + schema = self.quote_as_configured(schema, 'schema') return ('drop schema if exists {schema} cascade' .format(schema=schema)) @@ -407,12 +385,10 @@ def get_drop_schema_sql(cls, config, schema): # ODBC FUNCTIONS -- these should not need to change for every adapter, # although some adapters may override them ### - @classmethod - def get_default_schema(cls, config): - return config.credentials.schema + def get_default_schema(self): + return self.config.credentials.schema - @classmethod - def get_connection(cls, config, name=None, recache_if_missing=True): + def get_connection(self, name=None, recache_if_missing=True): global connections_in_use if name is None: @@ -429,22 +405,21 @@ def get_connection(cls, config, name=None, recache_if_missing=True): '(recache_if_missing is off).'.format(name)) logger.debug('Acquiring new {} connection "{}".' - .format(cls.type(), name)) + .format(self.type(), name)) - connection = cls.acquire_connection(config, name) + connection = self.acquire_connection(name) connections_in_use[name] = connection - return cls.get_connection(config, name) + return self.get_connection(name) - @classmethod - def cancel_open_connections(cls, config): + def cancel_open_connections(self): global connections_in_use for name, connection in connections_in_use.items(): if name == 'master': continue - cls.cancel_connection(config, connection) + self.cancel_connection(connection) yield name @classmethod @@ -453,17 +428,16 @@ def total_connections_allocated(cls): return len(connections_in_use) + len(connections_available) - @classmethod - def acquire_connection(cls, config, name): + def acquire_connection(self, name): global connections_available, lock # we add a magic number, 2 because there are overhead connections, # one for pre- and post-run hooks and other misc operations that occur # before the run starts, and one for integration tests. 
- max_connections = config.threads + 2 + max_connections = self.config.threads + 2 with lock: - num_allocated = cls.total_connections_allocated() + num_allocated = self.total_connections_allocated() if len(connections_available) > 0: logger.debug('Re-using an available connection from the pool.') @@ -481,41 +455,37 @@ def acquire_connection(cls, config, name): .format(num_allocated)) result = Connection( - type=cls.type(), + type=self.type(), name=name, state='init', transaction_open=False, handle=None, - credentials=config.credentials + credentials=self.config.credentials ) - return cls.open_connection(result) + return self.open_connection(result) - @classmethod - def release_connection(cls, config, name='master'): + def release_connection(self, name): global connections_in_use, connections_available, lock - if name not in connections_in_use: - return + with lock: - to_release = cls.get_connection(config, name, recache_if_missing=False) + if name not in connections_in_use: + return - try: - lock.acquire() + to_release = self.get_connection(name, recache_if_missing=False) if to_release.state == 'open': if to_release.transaction_open is True: - cls.rollback(to_release) + self.rollback(to_release) to_release.name = None connections_available.append(to_release) else: - cls.close(to_release) + self.close(to_release) del connections_in_use[name] - finally: - lock.release() @classmethod def cleanup_connections(cls): @@ -538,23 +508,18 @@ def cleanup_connections(cls): connections_in_use = {} connections_available = [] - @classmethod - def reload(cls, connection): - return cls.get_connection(connection.credentials, - connection.name) + def reload(self, connection): + return self.get_connection(connection.name) - @classmethod - def add_begin_query(cls, config, name): - return cls.add_query(config, 'BEGIN', name, auto_begin=False) + def add_begin_query(self, name): + return self.add_query('BEGIN', name, auto_begin=False) - @classmethod - def add_commit_query(cls, config, name): - return cls.add_query(config, 'COMMIT', name, auto_begin=False) + def add_commit_query(self, name): + return self.add_query('COMMIT', name, auto_begin=False) - @classmethod - def begin(cls, config, name='master'): + def begin(self, name): global connections_in_use - connection = cls.get_connection(config, name) + connection = self.get_connection(name) if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) @@ -564,15 +529,14 @@ def begin(cls, config, name='master'): 'Tried to begin a new transaction on connection "{}", but ' 'it already had one open!'.format(connection.get('name'))) - cls.add_begin_query(config, name) + self.add_begin_query(name) connection.transaction_open = True connections_in_use[name] = connection return connection - @classmethod - def commit_if_has_connection(cls, config, name): + def commit_if_has_connection(self, name): global connections_in_use if name is None: @@ -581,18 +545,17 @@ def commit_if_has_connection(cls, config, name): if connections_in_use.get(name) is None: return - connection = cls.get_connection(config, name, False) + connection = self.get_connection(name, False) - return cls.commit(config, connection) + return self.commit(connection) - @classmethod - def commit(cls, config, connection): + def commit(self, connection): global connections_in_use if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) - connection = cls.reload(connection) + connection = self.reload(connection) if connection.transaction_open is False: raise dbt.exceptions.InternalException( @@ 
-600,19 +563,18 @@ def commit(cls, config, connection): 'it does not have one open!'.format(connection.name)) logger.debug('On {}: COMMIT'.format(connection.name)) - cls.add_commit_query(config, connection.name) + self.add_commit_query(connection.name) connection.transaction_open = False connections_in_use[connection.name] = connection return connection - @classmethod - def rollback(cls, connection): + def rollback(self, connection): if dbt.flags.STRICT_MODE: Connection(**connection) - connection = cls.reload(connection) + connection = self.reload(connection) if connection.transaction_open is False: raise dbt.exceptions.InternalException( @@ -640,19 +602,18 @@ def close(cls, connection): return connection - @classmethod - def add_query(cls, config, sql, model_name=None, auto_begin=True, + def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): - connection = cls.get_connection(config, model_name) + connection = self.get_connection(model_name) connection_name = connection.name if auto_begin and connection.transaction_open is False: - cls.begin(config, connection_name) + self.begin(connection_name) logger.debug('Using {} connection "{}".' - .format(cls.type(), connection_name)) + .format(self.type(), connection_name)) - with cls.exception_handler(config, sql, model_name, connection_name): + with self.exception_handler(sql, model_name, connection_name): if abridge_sql_log: logger.debug('On %s: %s....', connection_name, sql[0:512]) else: @@ -663,99 +624,82 @@ def add_query(cls, config, sql, model_name=None, auto_begin=True, cursor.execute(sql, bindings) logger.debug("SQL status: %s in %0.2f seconds", - cls.get_status(cursor), (time.time() - pre)) + self.get_status(cursor), (time.time() - pre)) return connection, cursor - @classmethod - def clear_transaction(cls, config, conn_name='master'): - conn = cls.begin(config, conn_name) - cls.commit(config, conn) + def clear_transaction(self, conn_name='master'): + conn = self.begin(conn_name) + self.commit(conn) return conn_name - @classmethod - def execute_one(cls, config, sql, model_name=None, auto_begin=False): - cls.get_connection(config, model_name) + def execute_one(self, sql, model_name=None, auto_begin=False): + self.get_connection(model_name) - return cls.add_query(config, sql, model_name, auto_begin) + return self.add_query(sql, model_name, auto_begin) - @classmethod - def execute_and_fetch(cls, config, sql, model_name=None, + def execute_and_fetch(self, sql, model_name=None, auto_begin=False): - _, cursor = cls.execute_one(config, sql, model_name, auto_begin) + _, cursor = self.execute_one(sql, model_name, auto_begin) - status = cls.get_status(cursor) - table = cls.get_result_from_cursor(cursor) + status = self.get_status(cursor) + table = self.get_result_from_cursor(cursor) return status, table - @classmethod - def execute(cls, config, sql, model_name=None, auto_begin=False, + def execute(self, sql, model_name=None, auto_begin=False, fetch=False): if fetch: - return cls.execute_and_fetch(config, sql, model_name, auto_begin) + return self.execute_and_fetch(sql, model_name, auto_begin) else: - _, cursor = cls.execute_one(config, sql, model_name, auto_begin) - status = cls.get_status(cursor) + _, cursor = self.execute_one(sql, model_name, auto_begin) + status = self.get_status(cursor) return status, dbt.clients.agate_helper.empty_table() - @classmethod - def execute_all(cls, config, sqls, model_name=None): - connection = cls.get_connection(config, model_name) + def execute_all(self, sqls, 
model_name=None): + connection = self.get_connection(model_name) if len(sqls) == 0: return connection for i, sql in enumerate(sqls): - connection, _ = cls.add_query(config, sql, model_name) + connection, _ = self.add_query(sql, model_name) return connection - @classmethod - def create_schema(cls, config, schema, model_name=None): + def create_schema(self, schema, model_name=None): logger.debug('Creating schema "%s".', schema) - sql = cls.get_create_schema_sql(config, schema) - res = cls.add_query(config, sql, model_name) + sql = self.get_create_schema_sql(schema) + res = self.add_query(sql, model_name) - cls.commit_if_has_connection(config, model_name) + self.commit_if_has_connection(model_name) return res - @classmethod - def drop_schema(cls, config, schema, model_name=None): + def drop_schema(self, schema, model_name=None): logger.debug('Dropping schema "%s".', schema) - sql = cls.get_drop_schema_sql(config, schema) - return cls.add_query(config, sql, model_name) + sql = self.get_drop_schema_sql(schema) + return self.add_query(sql, model_name) - @classmethod - def already_exists(cls, config, schema, table, model_name=None): - relation = cls.get_relation(config, schema=schema, identifier=table) + def already_exists(self, schema, table, model_name=None): + relation = self.get_relation(schema=schema, identifier=table) return relation is not None @classmethod def quote(cls, identifier): return '"{}"'.format(identifier) - @classmethod - def _quote_as_configured(cls, config, identifier, quote_key): - """This is the actual implementation of quote_as_configured, without - the extra arguments needed for use inside materialization code. - """ - default = cls.Relation.DEFAULTS['quote_policy'].get(quote_key) - if config.quoting.get(quote_key, default): - return cls.quote(identifier) - else: - return identifier - - @classmethod - def quote_as_configured(cls, config, identifier, quote_key, - model_name=None): + def quote_as_configured(self, identifier, quote_key, model_name=None): """Quote or do not quote the given identifer as configured in the project config for the quote key. The quote key should be one of 'database' (on bigquery, 'profile'), 'identifier', or 'schema', or it will be treated as if you set `True`. """ - return cls._quote_as_configured(config, identifier, quote_key) + default = self.Relation.DEFAULTS['quote_policy'].get(quote_key) + if self.config.quoting.get(quote_key, default): + return self.quote(identifier) + else: + return identifier @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -809,8 +753,7 @@ def convert_agate_type(cls, agate_table, col_idx): ### # Operations involving the manifest ### - @classmethod - def run_operation(cls, config, manifest, operation_name): + def run_operation(self, manifest, operation_name): """Look the operation identified by operation_name up in the manifest and run it. 
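
With _quote_as_configured folded into quote_as_configured above, the quoting decision becomes a single instance method: the project-level quoting config is consulted first, and the Relation class's default quote policy acts as the fallback. A rough, runnable approximation of that logic — ExampleAdapter, ExampleRelation, and FakeConfig are placeholder names, not dbt's real types:

    class ExampleRelation(object):
        # Plays the role of Relation.DEFAULTS['quote_policy'] in the diff;
        # the specific values here are illustrative.
        DEFAULTS = {'quote_policy': {'identifier': True, 'schema': True}}

    class FakeConfig(object):
        # A project that disables schema quoting but says nothing about identifiers.
        quoting = {'schema': False}

    class ExampleAdapter(object):
        Relation = ExampleRelation

        def __init__(self, config):
            self.config = config

        @classmethod
        def quote(cls, identifier):
            return '"{}"'.format(identifier)

        def quote_as_configured(self, identifier, quote_key, model_name=None):
            # Project-level quoting wins; the relation's default policy is the fallback.
            default = self.Relation.DEFAULTS['quote_policy'].get(quote_key)
            if self.config.quoting.get(quote_key, default):
                return self.quote(identifier)
            return identifier

    adapter = ExampleAdapter(FakeConfig())
    print(adapter.quote_as_configured('analytics', 'schema'))     # analytics
    print(adapter.quote_as_configured('my_table', 'identifier'))  # "my_table"
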
@@ -824,7 +767,7 @@ def run_operation(cls, config, manifest, operation_name): import dbt.context.runtime context = dbt.context.runtime.generate( operation, - config, + self.config, manifest, ) @@ -838,13 +781,12 @@ def run_operation(cls, config, manifest, operation_name): def _filter_table(cls, table, manifest): return table.where(_filter_schemas(manifest)) - @classmethod - def get_catalog(cls, config, manifest): + def get_catalog(self, manifest): try: - table = cls.run_operation(config, manifest, - GET_CATALOG_OPERATION_NAME) + table = self.run_operation(manifest, + GET_CATALOG_OPERATION_NAME) finally: - cls.release_connection(config, GET_CATALOG_OPERATION_NAME) + self.release_connection(GET_CATALOG_OPERATION_NAME) - results = cls._filter_table(table, manifest) + results = self._filter_table(table, manifest) return results diff --git a/dbt/adapters/default/relation.py b/dbt/adapters/default/relation.py index b5d1f46d38f..03928bc3939 100644 --- a/dbt/adapters/default/relation.py +++ b/dbt/adapters/default/relation.py @@ -173,14 +173,27 @@ def quoted(self, identifier): identifier=identifier) @classmethod - def create_from_node(cls, project, node, table_name=None, **kwargs): + def _create_from_node(cls, config, node, table_name, quote_policy, + **kwargs): return cls.create( - database=project.credentials.dbname, + database=config.credentials.dbname, schema=node.get('schema'), identifier=node.get('alias'), table_name=table_name, + quote_policy=quote_policy, **kwargs) + @classmethod + def create_from_node(cls, config, node, table_name=None, quote_policy=None, + **kwargs): + if quote_policy is None: + quote_policy = {} + + quote_policy = dbt.utils.merge(config.quoting, quote_policy) + return cls._create_from_node(config=config, quote_policy=quote_policy, + node=node, table_name=table_name, + **kwargs) + @classmethod def create(cls, database=None, schema=None, identifier=None, table_name=None, diff --git a/dbt/adapters/factory.py b/dbt/adapters/factory.py index 7d450dc2b73..b1ec2bbcd0f 100644 --- a/dbt/adapters/factory.py +++ b/dbt/adapters/factory.py @@ -7,21 +7,26 @@ import dbt.exceptions +import threading -adapters = { + +ADAPTER_TYPES = { 'postgres': PostgresAdapter, 'redshift': RedshiftAdapter, 'snowflake': SnowflakeAdapter, 'bigquery': BigQueryAdapter } +_ADAPTERS = {} +_ADAPTER_LOCK = threading.Lock() + -def get_adapter_by_name(adapter_name): - adapter = adapters.get(adapter_name, None) +def get_adapter_class_by_name(adapter_name): + adapter = ADAPTER_TYPES.get(adapter_name, None) if adapter is None: message = "Invalid adapter type {}! Must be one of {}" - adapter_names = ", ".join(adapters.keys()) + adapter_names = ", ".join(ADAPTER_TYPES.keys()) formatted_message = message.format(adapter_name, adapter_names) raise dbt.exceptions.RuntimeException(formatted_message) @@ -30,4 +35,28 @@ def get_adapter_by_name(adapter_name): def get_adapter(config): - return get_adapter_by_name(config.credentials.type) + adapter_name = config.credentials.type + if adapter_name in _ADAPTERS: + return _ADAPTERS[adapter_name] + + adapter_type = get_adapter_class_by_name(adapter_name) + with _ADAPTER_LOCK: + # check again, in case something was setting it before + if adapter_name in _ADAPTERS: + return _ADAPTERS[adapter_name] + + adapter = adapter_type(config) + _ADAPTERS[adapter_name] = adapter + return adapter + + +def reset_adapters(): + """Clear the adapters. This is useful for tests, which change configs. 
+ """ + with _ADAPTER_LOCK: + _ADAPTERS.clear() + + +def get_relation_class_by_name(adapter_name): + adapter = get_adapter_class_by_name(adapter_name) + return adapter.Relation diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 3ad39e3fb46..29355df5a76 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -14,10 +14,8 @@ class PostgresAdapter(dbt.adapters.default.DefaultAdapter): DEFAULT_TCP_KEEPALIVE = 0 # 0 means to use the default value - @classmethod @contextmanager - def exception_handler(cls, config, sql, model_name=None, - connection_name=None): + def exception_handler(self, sql, model_name=None, connection_name=None): try: yield @@ -26,7 +24,7 @@ def exception_handler(cls, config, sql, model_name=None, try: # attempt to release the connection - cls.release_connection(config, connection_name) + self.release_connection(connection_name) except psycopg2.Error: logger.debug("Failed to release connection!") pass @@ -37,7 +35,7 @@ def exception_handler(cls, config, sql, model_name=None, except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.RuntimeException(e) @classmethod @@ -96,8 +94,7 @@ def open_connection(cls, connection): return connection - @classmethod - def cancel_connection(cls, config, connection): + def cancel_connection(self, connection): connection_name = connection.name pid = connection.handle.get_backend_pid() @@ -105,7 +102,7 @@ def cancel_connection(cls, config, connection): logger.debug("Cancelling query '{}' ({})".format(connection_name, pid)) - _, cursor = cls.add_query(config, sql, 'master') + _, cursor = self.add_query(sql, 'master') res = cursor.fetchone() logger.debug("Cancel query '{}': {}".format(connection_name, res)) @@ -113,8 +110,7 @@ def cancel_connection(cls, config, connection): # DATABASE INSPECTION FUNCTIONS # These require the profile AND project, as they need to know # database-specific configs at the project level. - @classmethod - def alter_column_type(cls, config, schema, table, column_name, + def alter_column_type(self, schema, table, column_name, new_column_type, model_name=None): """ 1. Create a new column (w/ temp name and correct type) @@ -123,10 +119,10 @@ def alter_column_type(cls, config, schema, table, column_name, 4. 
Rename the new column to existing column """ - relation = cls.Relation.create( + relation = self.Relation.create( schema=schema, identifier=table, - quote_policy=config.quoting + quote_policy=self.config.quoting ) opts = { @@ -143,12 +139,11 @@ def alter_column_type(cls, config, schema, table, column_name, alter table {relation} rename column "{tmp_column}" to "{old_column}"; """.format(**opts).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) return connection, cursor - @classmethod - def list_relations(cls, config, schema, model_name=None): + def list_relations(self, schema, model_name=None): sql = """ select tablename as name, schemaname as schema, 'table' as type from pg_tables where schemaname ilike '{schema}' @@ -157,13 +152,13 @@ def list_relations(cls, config, schema, model_name=None): where schemaname ilike '{schema}' """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, + auto_begin=False) results = cursor.fetchall() - return [cls.Relation.create( - database=config.credentials.dbname, + return [self.Relation.create( + database=self.config.credentials.dbname, schema=_schema, identifier=name, quote_policy={ @@ -173,24 +168,21 @@ def list_relations(cls, config, schema, model_name=None): type=type) for (name, _schema, type) in results] - @classmethod - def get_existing_schemas(cls, config, model_name=None): + def get_existing_schemas(self, model_name=None): sql = "select distinct nspname from pg_namespace" - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchall() return [row[0] for row in results] - @classmethod - def check_schema_exists(cls, config, schema, model_name=None): + def check_schema_exists(self, schema, model_name=None): sql = """ select count(*) from pg_namespace where nspname = '{schema}' """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, + auto_begin=False) results = cursor.fetchone() return results[0] > 0 diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index a1f36cb297e..a15ec186f0e 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -156,8 +156,7 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): table_schema_filter=table_schema_filter).strip() return sql - @classmethod - def drop_relation(cls, config, relation, model_name=None): + def drop_relation(self, relation, model_name=None): """ In Redshift, DROP TABLE ... CASCADE should not be used inside a transaction. 
Redshift doesn't prevent the CASCADE @@ -178,18 +177,18 @@ def drop_relation(cls, config, relation, model_name=None): with drop_lock: - connection = cls.get_connection(config, model_name) + connection = self.get_connection(model_name) if connection.transaction_open: - cls.commit(config, connection) + self.commit(connection) - cls.begin(config, connection.name) + self.begin(connection.name) - to_return = super(PostgresAdapter, cls).drop_relation( - config, relation, model_name) + to_return = super(RedshiftAdapter, self).drop_relation( + relation, model_name) - cls.commit(config, connection) - cls.begin(config, connection.name) + self.commit(connection) + self.begin(connection.name) return to_return diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index 47bc08efa96..44963c38a22 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -20,11 +20,10 @@ class SnowflakeAdapter(PostgresAdapter): Relation = SnowflakeRelation - @classmethod @contextmanager - def exception_handler(cls, config, sql, model_name=None, + def exception_handler(self, sql, model_name=None, connection_name='master'): - connection = cls.get_connection(config, connection_name) + connection = self.get_connection(connection_name) try: yield @@ -36,7 +35,7 @@ def exception_handler(cls, config, sql, model_name=None, if 'Empty SQL statement' in msg: logger.debug("got empty sql statement, moving on") elif 'This session does not have a current database' in msg: - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.FailedToConnectException( ('{}\n\nThis error sometimes occurs when invalid ' 'credentials are provided, or when your default role ' @@ -44,12 +43,12 @@ def exception_handler(cls, config, sql, model_name=None, 'Please double check your profile and try again.') .format(msg)) else: - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.DatabaseException(msg) except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.RuntimeException(e.msg) @classmethod @@ -102,8 +101,7 @@ def open_connection(cls, connection): return connection - @classmethod - def list_relations(cls, config, schema, model_name=None): + def list_relations(self, schema, model_name=None): sql = """ select table_name as name, table_schema as schema, table_type as type @@ -111,8 +109,7 @@ def list_relations(cls, config, schema, model_name=None): where table_schema ilike '{schema}' """.format(schema=schema).strip() # noqa - _, cursor = cls.add_query( - config, sql, model_name, auto_begin=False) + _, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchall() @@ -121,8 +118,8 @@ def list_relations(cls, config, schema, model_name=None): 'VIEW': 'view' } - return [cls.Relation.create( - database=config.credentials.database, + return [self.Relation.create( + database=self.config.credentials.database, schema=_schema, identifier=name, quote_policy={ @@ -132,38 +129,32 @@ def list_relations(cls, config, schema, model_name=None): type=relation_type_lookup.get(type)) for (name, _schema, type) in results] - @classmethod - def rename_relation(cls, config, from_relation, to_relation, + def rename_relation(self, from_relation, to_relation, model_name=None): sql = 'alter table {} rename to {}'.format( from_relation, 
to_relation) - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) - @classmethod - def add_begin_query(cls, config, name): - return cls.add_query(config, 'BEGIN', name, auto_begin=False) + def add_begin_query(self, name): + return self.add_query('BEGIN', name, auto_begin=False) - @classmethod - def get_existing_schemas(cls, config, model_name=None): + def get_existing_schemas(self, model_name=None): sql = "select distinct schema_name from information_schema.schemata" - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchall() return [row[0] for row in results] - @classmethod - def check_schema_exists(cls, config, schema, model_name=None): + def check_schema_exists(self, schema, model_name=None): sql = """ select count(*) from information_schema.schemata where upper(schema_name) = upper('{schema}') """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchone() return results[0] > 0 @@ -177,8 +168,7 @@ def _split_queries(cls, sql): split_query = snowflake.connector.util_text.split_statements(sql_buf) return [part[0] for part in split_query] - @classmethod - def add_query(cls, config, sql, model_name=None, auto_begin=True, + def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): connection = None @@ -189,7 +179,7 @@ def add_query(cls, config, sql, model_name=None, auto_begin=True, # which allows any iterable thing to be passed as a binding. bindings = tuple(bindings) - queries = cls._split_queries(sql) + queries = self._split_queries(sql) for individual_query in queries: # hack -- after the last ';', remove comments and don't run @@ -202,9 +192,10 @@ def add_query(cls, config, sql, model_name=None, auto_begin=True, if without_comments == "": continue - connection, cursor = super(PostgresAdapter, cls).add_query( - config, individual_query, model_name, auto_begin, - bindings=bindings, abridge_sql_log=abridge_sql_log) + connection, cursor = super(SnowflakeAdapter, self).add_query( + individual_query, model_name, auto_begin, bindings=bindings, + abridge_sql_log=abridge_sql_log + ) if cursor is None: raise dbt.exceptions.RuntimeException( @@ -224,19 +215,18 @@ def _filter_table(cls, table, manifest): ) return super(SnowflakeAdapter, cls)._filter_table(lowered, manifest) - @classmethod - def _make_match_kwargs(cls, config, schema, identifier): - if identifier is not None and config.quoting['identifier'] is False: + def _make_match_kwargs(self, schema, identifier): + quoting = self.config.quoting + if identifier is not None and quoting['identifier'] is False: identifier = identifier.upper() - if schema is not None and config.quoting['schema'] is False: + if schema is not None and quoting['schema'] is False: schema = schema.upper() return filter_null_values({'identifier': identifier, 'schema': schema}) - @classmethod - def cancel_connection(cls, config, connection): + def cancel_connection(self, connection): handle = connection.handle sid = handle.session_id @@ -246,7 +236,7 @@ def cancel_connection(cls, config, connection): logger.debug("Cancelling query '{}' ({})".format(connection_name, sid)) - _, cursor = cls.add_query(config, sql, 'master') + _, cursor = self.add_query(sql, 'master') res = cursor.fetchone() 
logger.debug("Cancel query '{}': {}".format(connection_name, res)) diff --git a/dbt/adapters/snowflake/relation.py b/dbt/adapters/snowflake/relation.py index bd879965404..bf5b61c6485 100644 --- a/dbt/adapters/snowflake/relation.py +++ b/dbt/adapters/snowflake/relation.py @@ -1,4 +1,5 @@ from dbt.adapters.default.relation import DefaultRelation +import dbt.utils class SnowflakeRelation(DefaultRelation): @@ -44,7 +45,7 @@ class SnowflakeRelation(DefaultRelation): } @classmethod - def create_from_node(cls, config, node, **kwargs): + def _create_from_node(cls, config, node, **kwargs): return cls.create( database=config.credentials.database, schema=node.get('schema'), diff --git a/dbt/config.py b/dbt/config.py index d56826f5daf..cadb7d8411f 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -12,6 +12,7 @@ PackageConfig, ProfileConfig from dbt.context.common import env_var from dbt import compat +from dbt.adapters.factory import get_relation_class_by_name from dbt.logger import GLOBAL_LOGGER as logger @@ -19,18 +20,6 @@ DEFAULT_THREADS = 1 DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True DEFAULT_USE_COLORS = True -DEFAULT_QUOTING_GLOBAL = { - 'identifier': True, - 'schema': True, -} -# some adapters need different quoting rules, for example snowflake gets a bit -# weird with quoting on -DEFAULT_QUOTING_ADAPTER = { - 'snowflake': { - 'identifier': False, - 'schema': False, - }, -} DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt') @@ -676,8 +665,8 @@ def from_parts(cls, project, profile, cli_vars): :returns RuntimeConfig: The new configuration. """ quoting = deepcopy( - DEFAULT_QUOTING_ADAPTER.get(profile.credentials.type, - DEFAULT_QUOTING_GLOBAL) + get_relation_class_by_name(profile.credentials.type) + .DEFAULTS['quote_policy'] ) quoting.update(project.quoting) return cls( @@ -725,11 +714,14 @@ def new_project(self, project_root): # load the new project and its packages project = Project.from_project_root(project_root) - return self.from_parts( + cfg = self.from_parts( project=project, profile=profile, cli_vars=deepcopy(self.cli_vars) ) + # force our quoting back onto the new project. + cfg.quoting = deepcopy(self.quoting) + return cfg def serialize(self): """Serialize the full configuration to a single dictionary. For any diff --git a/dbt/context/common.py b/dbt/context/common.py index 8187207270e..0c5dbc23ff6 100644 --- a/dbt/context/common.py +++ b/dbt/context/common.py @@ -1,3 +1,5 @@ +import copy +import functools import json import os @@ -24,51 +26,67 @@ import datetime +class RelationProxy(object): + def __init__(self, adapter): + self.quoting_config = adapter.config.quoting + self.relation_type = adapter.Relation + + def __getattr__(self, key): + return getattr(self.relation_type, key) + + def create(self, *args, **kwargs): + kwargs['quote_policy'] = dbt.utils.merge( + self.quoting_config, + kwargs.pop('quote_policy', {}) + ) + return self.relation_type.create(*args, **kwargs) + + class DatabaseWrapper(object): """ - Wrapper for runtime database interaction. Should only call adapter - functions. + Wrapper for runtime database interaction. Mostly a compatibility layer now. """ - - def __init__(self, model, adapter, config): + def __init__(self, model, adapter): self.model = model self.adapter = adapter - self.config = config - self.Relation = adapter.Relation - - # Fun with metaprogramming - # Most adapter functions take `profile` as the first argument, and - # `model_name` as the last. This automatically injects those arguments. 
- # In model code, these functions can be called without those two args. - for context_function in self.adapter.context_functions: - setattr(self, - context_function, - self.wrap(context_function, (self.config,))) - - for profile_function in self.adapter.profile_functions: - setattr(self, - profile_function, - self.wrap(profile_function, (self.config,))) - - for raw_function in self.adapter.raw_functions: - setattr(self, - raw_function, - getattr(self.adapter, raw_function)) - - def wrap(self, fn, arg_prefix): + self.Relation = RelationProxy(adapter) + + self._wrapped = frozenset( + self.adapter.config_functions + ) + self._proxied = frozenset(self.adapter.raw_functions) + + def wrap(self, name): + func = getattr(self.adapter, name) + + @functools.wraps(func) def wrapped(*args, **kwargs): - args = arg_prefix + args kwargs['model_name'] = self.model.get('name') - return getattr(self.adapter, fn)(*args, **kwargs) + return func(*args, **kwargs) return wrapped + def __getattr__(self, name): + if name in self._wrapped: + return self.wrap(name) + elif name in self._proxied: + return getattr(self.adapter, name) + else: + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, name + ) + ) + + @property + def config(self): + return self.adapter.config + def type(self): return self.adapter.type() def commit(self): - return self.adapter.commit_if_has_connection( - self.config, self.model.get('name')) + return self.adapter.commit_if_has_connection(self.model.get('name')) def _add_macros(context, model, manifest): @@ -303,36 +321,7 @@ def _return(value): def get_this_relation(db_wrapper, config, model): - return db_wrapper.adapter.Relation.create_from_node( - config, model) - - -def create_relation(relation_type, quoting_config): - - class RelationWithContext(relation_type): - @classmethod - def create(cls, *args, **kwargs): - quote_policy = quoting_config - - if 'quote_policy' in kwargs: - quote_policy = dbt.utils.merge( - quote_policy, - kwargs.pop('quote_policy')) - - return relation_type.create(*args, - quote_policy=quote_policy, - **kwargs) - - return RelationWithContext - - -def create_adapter(adapter_type, relation_type): - - class AdapterWithContext(adapter_type): - - Relation = relation_type - - return AdapterWithContext + return db_wrapper.Relation.create_from_node(config, model) def generate_base(model, model_dict, config, manifest, source_config, @@ -357,16 +346,12 @@ def generate_base(model, model_dict, config, manifest, source_config, pre_hooks = None post_hooks = None - relation_type = create_relation(adapter.Relation, - config.quoting) + db_wrapper = DatabaseWrapper(model_dict, adapter) - db_wrapper = DatabaseWrapper(model_dict, - create_adapter(adapter, relation_type), - config) context = dbt.utils.merge(context, { "adapter": db_wrapper, "api": { - "Relation": relation_type, + "Relation": db_wrapper.Relation, "Column": adapter.Column, }, "column": adapter.Column, diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py index f46c51b2107..6e69ac2b1b3 100644 --- a/dbt/contracts/project.py +++ b/dbt/contracts/project.py @@ -115,6 +115,12 @@ 'schema': { 'type': 'boolean', }, + 'database': { + 'type': 'boolean', + }, + 'project': { + 'type': 'boolean', + } }, }, 'models': { diff --git a/dbt/node_runners.py b/dbt/node_runners.py index f67618937a2..20a25a835e1 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -144,7 +144,7 @@ def _safe_release_connection(self): """ node_name = self.node.name try: - 
self.adapter.release_connection(self.config, node_name) + self.adapter.release_connection(node_name) except Exception as exc: logger.debug( 'Error releasing connection for node {}: {!s}\n{}' @@ -229,7 +229,7 @@ def compile(self, manifest): def _compile_node(cls, adapter, config, node, manifest): compiler = dbt.compilation.Compiler(config) node = compiler.compile_node(node, manifest) - node = cls._inject_runtime_config(adapter, config, node) + node = cls._inject_runtime_config(adapter, node) if(node.injected_sql is not None and not (dbt.utils.is_type(node, NodeType.Archive))): @@ -247,30 +247,29 @@ def _compile_node(cls, adapter, config, node, manifest): return node @classmethod - def _inject_runtime_config(cls, adapter, config, node): + def _inject_runtime_config(cls, adapter, node): wrapped_sql = node.wrapped_sql - context = cls._node_context(adapter, config, node) + context = cls._node_context(adapter, node) sql = dbt.clients.jinja.get_rendered(wrapped_sql, context) node.wrapped_sql = sql return node @classmethod - def _node_context(cls, adapter, config, node): + def _node_context(cls, adapter, node): def call_get_columns_in_table(schema_name, table_name): return adapter.get_columns_in_table( - config, schema_name, - table_name, model_name=node.alias) + schema_name, table_name, model_name=node.alias + ) def call_get_missing_columns(from_schema, from_table, to_schema, to_table): return adapter.get_missing_columns( - config, from_schema, from_table, - to_schema, to_table, node.alias) + from_schema, from_table, to_schema, to_table, node.alias + ) def call_already_exists(schema, table): - return adapter.already_exists( - config, schema, table, node.alias) + return adapter.already_exists(schema, table, node.alias) return { "run_started_at": dbt.tracking.active_user.run_started_at, @@ -304,7 +303,7 @@ def run_hooks(cls, config, adapter, manifest, hook_type): # implement a for-loop over these sql statements in jinja-land. # Also, consider configuring psycopg2 (and other adapters?) to # ensure that a transaction is only created if dbt initiates it. - adapter.clear_transaction(config, model_name) + adapter.clear_transaction(model_name) compiled = cls._compile_node(adapter, config, hook, manifest) statement = compiled.wrapped_sql @@ -317,10 +316,10 @@ def run_hooks(cls, config, adapter, manifest, hook_type): sql = hook_dict.get('sql', '') if len(sql.strip()) > 0: - adapter.execute(config, sql, model_name=model_name, - auto_begin=False, fetch=False) + adapter.execute(sql, model_name=model_name, auto_begin=False, + fetch=False) - adapter.release_connection(config, model_name) + adapter.release_connection(model_name) @classmethod def safe_run_hooks(cls, config, adapter, manifest, hook_type): @@ -339,12 +338,12 @@ def create_schemas(cls, config, adapter, manifest): # is the one defined in the profile. Create this schema if it # does not exist, otherwise subsequent queries will fail. Generally, # dbt expects that this schema will exist anyway. 
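        # (Illustrative aside, not part of this patch.) A minimal sketch, under the
        # assumption of a RuntimeConfig loaded elsewhere, of the calling convention
        # the lines below migrate to: the adapter instance carries its own config,
        # so schema/connection helpers no longer take an explicit `config` argument:
        #
        #     adapter = get_adapter(config)              # adapter bound to its config
        #     existing = set(adapter.get_existing_schemas())
        #     for schema in required_schemas - existing:
        #         adapter.create_schema(schema)
        #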
- required_schemas.add(adapter.get_default_schema(config)) + required_schemas.add(adapter.get_default_schema()) - existing_schemas = set(adapter.get_existing_schemas(config)) + existing_schemas = set(adapter.get_existing_schemas()) for schema in (required_schemas - existing_schemas): - adapter.create_schema(config, schema) + adapter.create_schema(schema) @classmethod def before_run(cls, config, adapter, manifest): @@ -443,7 +442,6 @@ def print_start_line(self): def execute_test(self, test): res, table = self.adapter.execute_and_fetch( - self.config, test.wrapped_sql, test.name, auto_begin=True) diff --git a/dbt/parser/base.py b/dbt/parser/base.py index 1ee4bb15178..89ca46a5c33 100644 --- a/dbt/parser/base.py +++ b/dbt/parser/base.py @@ -119,7 +119,7 @@ def parse_node(cls, node, node_path, root_project_config, db_wrapper = context['adapter'] adapter = db_wrapper.adapter runtime_config = db_wrapper.config - adapter.release_connection(runtime_config, parsed_node.name) + adapter.release_connection(parsed_node.name) # Special macro defined in the global project schema_override = config.config.get('schema') diff --git a/dbt/runner.py b/dbt/runner.py index 9a359b17c86..e968ca7d83b 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -150,7 +150,7 @@ def execute_nodes(self, linker, Runner, manifest, node_dependency_list): dbt.ui.printer.print_timestamped_line(msg, yellow) raise - for conn_name in adapter.cancel_open_connections(self.config): + for conn_name in adapter.cancel_open_connections(): dbt.ui.printer.print_cancel_line(conn_name) dbt.ui.printer.print_run_end_messages(node_results, diff --git a/dbt/task/generate.py b/dbt/task/generate.py index 05808f1016f..6cc1df5e02c 100644 --- a/dbt/task/generate.py +++ b/dbt/task/generate.py @@ -213,7 +213,7 @@ def run(self): adapter = get_adapter(self.config) dbt.ui.printer.print_timestamped_line("Building catalog") - results = adapter.get_catalog(self.config, manifest) + results = adapter.get_catalog(manifest) results = [ dict(zip(results.column_names, row)) diff --git a/test/integration/001_simple_copy_test/test_simple_copy.py b/test/integration/001_simple_copy_test/test_simple_copy.py index 6f9930827d3..7cc4ed580e4 100644 --- a/test/integration/001_simple_copy_test/test_simple_copy.py +++ b/test/integration/001_simple_copy_test/test_simple_copy.py @@ -16,6 +16,7 @@ def dir(path): def models(self): return self.dir("models") + class TestSimpleCopy(BaseTestSimpleCopy): @use_profile("postgres") def test__postgres__simple_copy(self): diff --git a/test/integration/004_simple_archive_test/test_simple_archive.py b/test/integration/004_simple_archive_test/test_simple_archive.py index 970990ecaef..b6924bd7c40 100644 --- a/test/integration/004_simple_archive_test/test_simple_archive.py +++ b/test/integration/004_simple_archive_test/test_simple_archive.py @@ -162,8 +162,8 @@ def test__bigquery__archive_with_new_field(self): # A more thorough test would assert that archived == expected, but BigQuery does not support the # "EXCEPT DISTINCT" operator on nested fields! Instead, just check that schemas are congruent. 
- expected_cols = self.adapter.get_columns_in_table(self.config, self.unique_schema(), 'archive_expected') - archived_cols = self.adapter.get_columns_in_table(self.config, self.unique_schema(), 'archive_actual') + expected_cols = self.adapter.get_columns_in_table(self.unique_schema(), 'archive_expected') + archived_cols = self.adapter.get_columns_in_table(self.unique_schema(), 'archive_actual') self.assertTrue(len(expected_cols) > 0, "source table does not exist -- bad test") self.assertEqual(len(expected_cols), len(archived_cols), "actual and expected column lengths are different") diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index ca67a3f6ae3..955e324eb27 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -26,6 +26,11 @@ def __eq__(self, other): return isinstance(other, basestring) and self.contains in other +def _read_file(path): + with open(path) as fp: + return fp.read() + + class TestDocsGenerate(DBTIntegrationTest): def setUp(self): super(TestDocsGenerate,self).setUp() @@ -729,7 +734,7 @@ def expected_seeded_manifest(self): 'path': 'model.sql', 'original_file_path': model_sql_path, 'package_name': 'test', - 'raw_sql': open(model_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(model_sql_path).rstrip('\n'), 'refs': [['seed']], 'depends_on': {'nodes': ['seed.test.seed'], 'macros': []}, 'unique_id': 'model.test.model', @@ -920,7 +925,7 @@ def expected_seeded_manifest(self): def expected_postgres_references_manifest(self): my_schema_name = self.unique_schema() docs_path = self.dir('ref_models/docs.md') - docs_file = open(docs_path).read().lstrip() + docs_file = _read_file(docs_path).lstrip() return { 'nodes': { 'model.test.ephemeral_copy': { @@ -1204,7 +1209,7 @@ def expected_bigquery_complex_manifest(self): 'original_file_path': clustered_sql_path, 'package_name': 'test', 'path': 'clustered.sql', - 'raw_sql': open(clustered_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(clustered_sql_path).rstrip('\n'), 'refs': [['seed']], 'resource_type': 'model', 'root_path': os.getcwd(), @@ -1258,7 +1263,7 @@ def expected_bigquery_complex_manifest(self): 'original_file_path': nested_view_sql_path, 'package_name': 'test', 'path': 'nested_view.sql', - 'raw_sql': open(nested_view_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(nested_view_sql_path).rstrip('\n'), 'refs': [['nested_table']], 'resource_type': 'model', 'root_path': os.getcwd(), @@ -1312,7 +1317,7 @@ def expected_bigquery_complex_manifest(self): 'original_file_path': nested_table_sql_path, 'package_name': 'test', 'path': 'nested_table.sql', - 'raw_sql': open(nested_table_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(nested_table_sql_path).rstrip('\n'), 'refs': [], 'resource_type': 'model', 'root_path': os.getcwd(), @@ -1388,7 +1393,7 @@ def expected_redshift_incremental_view_manifest(self): "path": "model.sql", "original_file_path": model_sql_path, "package_name": "test", - "raw_sql": open(model_sql_path).read().rstrip('\n'), + "raw_sql": _read_file(model_sql_path).rstrip('\n'), "refs": [["seed"]], "depends_on": { "nodes": ["seed.test.seed"], diff --git a/test/integration/base.py b/test/integration/base.py index e83bc2dc32b..9dc0040b95b 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -11,7 +11,7 @@ import dbt.flags as flags -from dbt.adapters.factory import get_adapter +from dbt.adapters.factory 
import get_adapter, reset_adapters from dbt.config import RuntimeConfig from dbt.logger import GLOBAL_LOGGER as logger @@ -267,7 +267,7 @@ def load_config(self): adapter = get_adapter(config) adapter.cleanup_connections() - connection = adapter.acquire_connection(config, '__test') + connection = adapter.acquire_connection('__test') self.handle = connection.handle self.adapter_type = connection.type self.adapter = adapter @@ -277,9 +277,7 @@ def load_config(self): self._create_schema() def quote_as_configured(self, value, quote_key): - return self.adapter.quote_as_configured( - self.config, value, quote_key - ) + return self.adapter.quote_as_configured(value, quote_key) def _clean_files(self): if os.path.exists(DBT_PROFILES): @@ -298,7 +296,8 @@ def _clean_files(self): def tearDown(self): self._clean_files() - self.adapter = get_adapter(self.config) + if not hasattr(self, 'adapter'): + self.adapter = get_adapter(self.config) self._drop_schema() @@ -307,10 +306,11 @@ def tearDown(self): self.handle.close() self.adapter.cleanup_connections() + reset_adapters() def _create_schema(self): if self.adapter_type == 'bigquery': - self.adapter.create_schema(self.config, self.unique_schema(), '__test') + self.adapter.create_schema(self.unique_schema(), '__test') else: schema = self.quote_as_configured(self.unique_schema(), 'schema') self.run_sql('CREATE SCHEMA {}'.format(schema)) @@ -318,7 +318,7 @@ def _create_schema(self): def _drop_schema(self): if self.adapter_type == 'bigquery': - self.adapter.drop_schema(self.config, self.unique_schema(), '__test') + self.adapter.drop_schema(self.unique_schema(), '__test') else: had_existing = False try: @@ -342,6 +342,8 @@ def profile_config(self): return {} def run_dbt(self, args=None, expect_pass=True, strict=True): + # clear the adapter cache + reset_adapters() if args is None: args = ["run"] @@ -385,9 +387,8 @@ def run_sql_bigquery(self, sql, fetch): """Run an SQL query on a bigquery adapter. No cursors, transactions, etc. 
to worry about""" - adapter = get_adapter(self.config) do_fetch = fetch != 'None' - _, res = adapter.execute(self.config, sql, fetch=do_fetch) + _, res = self.adapter.execute(sql, fetch=do_fetch) # convert dataframe to matrix-ish repr if fetch == 'one': @@ -449,11 +450,8 @@ def filter_many_columns(self, column): def get_table_columns(self, table, schema=None): schema = self.unique_schema() if schema is None else schema - columns = self.adapter.get_columns_in_table( - self.config, - schema, - table - ) + columns = self.adapter.get_columns_in_table(schema, table) + return sorted(((c.name, c.dtype, c.char_size) for c in columns), key=lambda x: x[0]) diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index d1f0f402153..f00283e539f 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -11,38 +11,56 @@ fake_conn = {"handle": None, "state": "open", "type": "bigquery"} +from .utils import config_from_parts_or_dicts + class TestBigQueryAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True + self.raw_profile = { + 'outputs': { + 'oauth': { + 'type': 'bigquery', + 'method': 'oauth', + 'project': 'dbt-unit-000000', + 'schema': 'dummy_schema', + 'threads': 1, + }, + 'service_account': { + 'type': 'bigquery', + 'method': 'service-account', + 'project': 'dbt-unit-000000', + 'schema': 'dummy_schema', + 'keyfile': '/tmp/dummy-service-account.json', + 'threads': 1, + }, + }, + 'target': 'oauth', + } - self.oauth_credentials = BigQueryCredentials( - method='oauth', - project='dbt-unit-000000', - schema='dummy_schema' - ) - self.oauth_profile = MagicMock( - credentials=self.oauth_credentials, - threads=1 - ) + self.project_cfg = { + 'name': 'X', + 'version': '0.1', + 'project-root': '/tmp/dbt/does-not-exist', + } - self.service_account_credentials = BigQueryCredentials( - method='service-account', - project='dbt-unit-000000', - schema='dummy_schema', - keyfile='/tmp/dummy-service-account.json' - ) - self.service_account_profile = MagicMock( - credentials=self.service_account_credentials, - threads=1 + def get_adapter(self, profile): + project = self.project_cfg.copy() + project['profile'] = profile + + config = config_from_parts_or_dicts( + project=project, + profile=self.raw_profile, ) + return BigQueryAdapter(config) @patch('dbt.adapters.bigquery.BigQueryAdapter.open_connection', return_value=fake_conn) def test_acquire_connection_oauth_validations(self, mock_open_connection): + adapter = self.get_adapter('oauth') try: - connection = BigQueryAdapter.acquire_connection(self.oauth_profile, 'dummy') + connection = adapter.acquire_connection('dummy') self.assertEquals(connection.get('type'), 'bigquery') except dbt.exceptions.ValidationException as e: @@ -56,8 +74,9 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection): @patch('dbt.adapters.bigquery.BigQueryAdapter.open_connection', return_value=fake_conn) def test_acquire_connection_service_account_validations(self, mock_open_connection): + adapter = self.get_adapter('service_account') try: - connection = BigQueryAdapter.acquire_connection(self.service_account_profile, 'dummy') + connection = adapter.acquire_connection('dummy') self.assertEquals(connection.get('type'), 'bigquery') except dbt.exceptions.ValidationException as e: diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 4c0d3d3a819..67b7cc26d2b 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -724,7 +724,11 @@ def test_from_parts(self): 
expected_project = project.to_project_config() self.assertEqual(expected_project['quoting'], {}) - expected_project['quoting'] = {'identifier': True, 'schema': True} + expected_project['quoting'] = { + 'database': True, + 'identifier': True, + 'schema': True, + } self.assertEqual(config.to_project_config(), expected_project) def test_str(self): @@ -771,7 +775,7 @@ def test_from_args(self): self.assertEqual(config.clean_targets, ['target']) self.assertEqual(config.log_path, 'logs') self.assertEqual(config.modules_path, 'dbt_modules') - self.assertEqual(config.quoting, {'identifier': True, 'schema': True}) + self.assertEqual(config.quoting, {'database': True, 'identifier': True, 'schema': True}) self.assertEqual(config.models, {}) self.assertEqual(config.on_run_start, []) self.assertEqual(config.on_run_end, []) diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index a3ac793411d..7dea6a40a3c 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -40,10 +40,13 @@ def setUp(self): self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) + @property + def adapter(self): + return PostgresAdapter(self.config) + def test_acquire_connection_validations(self): try: - connection = PostgresAdapter.acquire_connection(self.config, - 'dummy') + connection = self.adapter.acquire_connection('dummy') self.assertEquals(connection.type, 'postgres') except ValidationException as e: self.fail('got ValidationException: {}'.format(str(e))) @@ -52,14 +55,14 @@ def test_acquire_connection_validations(self): .format(str(e))) def test_acquire_connection(self): - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + connection = self.adapter.acquire_connection('dummy') self.assertEquals(connection.state, 'open') self.assertNotEquals(connection.handle, None) @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_default_keepalive(self, psycopg2): - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -71,9 +74,11 @@ def test_default_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_changed_keepalive(self, psycopg2): - credentials = self.config.credentials.incorporate(keepalives_idle=256) - self.config.credentials = credentials - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + credentials = self.adapter.config.credentials.incorporate( + keepalives_idle=256 + ) + self.adapter.config.credentials = credentials + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -88,7 +93,7 @@ def test_changed_keepalive(self, psycopg2): def test_set_zero_keepalive(self, psycopg2): credentials = self.config.credentials.incorporate(keepalives_idle=0) self.config.credentials = credentials - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -121,7 +126,7 @@ def test_get_catalog_various_schemas(self, mock_run): # give manifest the dict it wants mock_manifest = mock.MagicMock(spec_set=['nodes'], nodes=nodes) - catalog = PostgresAdapter.get_catalog(mock.MagicMock(), mock_manifest) + catalog = self.adapter.get_catalog(mock_manifest) self.assertEqual( set(map(tuple, catalog)), {('foo', 'bar'), ('FOO', 'baz'), ('quux', 'bar')} @@ -166,26 +171,23 @@ 
def setUp(self): self.psycopg2 = self.patcher.start() self.psycopg2.connect.return_value = self.handle - conn = PostgresAdapter.get_connection(self.config) + self.adapter = PostgresAdapter(self.config) + self.adapter.get_connection() def tearDown(self): # we want a unique self.handle every time. - PostgresAdapter.cleanup_connections() + self.adapter.cleanup_connections() self.patcher.stop() def test_quoting_on_drop_schema(self): - PostgresAdapter.drop_schema( - config=self.config, - schema='test_schema' - ) + self.adapter.drop_schema(schema='test_schema') self.mock_execute.assert_has_calls([ mock.call('drop schema if exists "test_schema" cascade', None) ]) def test_quoting_on_drop(self): - PostgresAdapter.drop( - config=self.config, + self.adapter.drop( schema='test_schema', relation='test_table', relation_type='table' @@ -195,8 +197,7 @@ def test_quoting_on_drop(self): ]) def test_quoting_on_truncate(self): - PostgresAdapter.truncate( - config=self.config, + self.adapter.truncate( schema='test_schema', table='test_table' ) @@ -205,8 +206,7 @@ def test_quoting_on_truncate(self): ]) def test_quoting_on_rename(self): - PostgresAdapter.rename( - config=self.config, + self.adapter.rename( schema='test_schema', from_name='table_a', to_name='table_b' diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index 4b713737d3d..d1b6559178d 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -8,7 +8,8 @@ from dbt.adapters.redshift import RedshiftAdapter from dbt.exceptions import ValidationException, FailedToConnectException from dbt.logger import GLOBAL_LOGGER as logger # noqa -from dbt.config import Profile + +from .utils import config_from_parts_or_dicts @classmethod @@ -23,7 +24,7 @@ class TestRedshiftAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True - self.profile = Profile.from_raw_profile_info({ + profile_cfg = { 'outputs': { 'test': { 'type': 'redshift', @@ -36,55 +37,72 @@ def setUp(self): } }, 'target': 'test' - }, 'test') + } + + project_cfg = { + 'name': 'X', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/tmp/dbt/does-not-exist', + 'quoting': { + 'identifier': False, + 'schema': True, + } + } + + self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) + + @property + def adapter(self): + return RedshiftAdapter(self.config) def test_implicit_database_conn(self): - creds = RedshiftAdapter.get_credentials(self.profile.credentials) - self.assertEquals(creds, self.profile.credentials) + creds = RedshiftAdapter.get_credentials(self.config.credentials) + self.assertEquals(creds, self.config.credentials) def test_explicit_database_conn(self): - self.profile.method = 'database' + self.config.method = 'database' - creds = RedshiftAdapter.get_credentials(self.profile.credentials) - self.assertEquals(creds, self.profile.credentials) + creds = RedshiftAdapter.get_credentials(self.config.credentials) + self.assertEquals(creds, self.config.credentials) def test_explicit_iam_conn(self): - self.profile.credentials = self.profile.credentials.incorporate( + self.config.credentials = self.config.credentials.incorporate( method='iam', cluster_id='my_redshift', iam_duration_seconds=1200 ) with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - creds = RedshiftAdapter.get_credentials(self.profile.credentials) + creds = RedshiftAdapter.get_credentials(self.config.credentials) - expected_creds = 
self.profile.credentials.incorporate(password='tmp_password') + expected_creds = self.config.credentials.incorporate(password='tmp_password') self.assertEquals(creds, expected_creds) def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate - self.profile.credentials._contents['method'] = 'badmethod' + self.config.credentials._contents['method'] = 'badmethod' with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.get_credentials(self.profile.credentials) + RedshiftAdapter.get_credentials(self.config.credentials) self.assertTrue('badmethod' in context.exception.msg) def test_invalid_iam_no_cluster_id(self): - self.profile.credentials = self.profile.credentials.incorporate( + self.config.credentials = self.config.credentials.incorporate( method='iam' ) with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.get_credentials(self.profile.credentials) + RedshiftAdapter.get_credentials(self.config.credentials) self.assertTrue("'cluster_id' must be provided" in context.exception.msg) @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_default_keepalive(self, psycopg2): - connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='redshift', @@ -97,10 +115,10 @@ def test_default_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_changed_keepalive(self, psycopg2): - self.profile.credentials = self.profile.credentials.incorporate( + self.config.credentials = self.config.credentials.incorporate( keepalives_idle=256 ) - connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='redshift', @@ -113,10 +131,10 @@ def test_changed_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_set_zero_keepalive(self, psycopg2): - self.profile.credentials = self.profile.credentials.incorporate( + self.config.credentials = self.config.credentials.incorporate( keepalives_idle=0 ) - connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='redshift', diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 9aa290c7f57..d85086c04ac 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -49,16 +49,16 @@ def setUp(self): self.snowflake = self.patcher.start() self.snowflake.return_value = self.handle - conn = SnowflakeAdapter.get_connection(self.config) + self.adapter = SnowflakeAdapter(self.config) + self.adapter.get_connection() def tearDown(self): # we want a unique self.handle every time. 
- SnowflakeAdapter.cleanup_connections() + self.adapter.cleanup_connections() self.patcher.stop() def test_quoting_on_drop_schema(self): - SnowflakeAdapter.drop_schema( - config=self.config, + self.adapter.drop_schema( schema='test_schema' ) @@ -67,8 +67,7 @@ def test_quoting_on_drop_schema(self): ]) def test_quoting_on_drop(self): - SnowflakeAdapter.drop( - config=self.config, + self.adapter.drop( schema='test_schema', relation='test_table', relation_type='table' @@ -78,8 +77,7 @@ def test_quoting_on_drop(self): ]) def test_quoting_on_truncate(self): - SnowflakeAdapter.truncate( - config=self.config, + self.adapter.truncate( schema='test_schema', table='test_table' ) @@ -88,8 +86,7 @@ def test_quoting_on_truncate(self): ]) def test_quoting_on_rename(self): - SnowflakeAdapter.rename( - config=self.config, + self.adapter.rename( schema='test_schema', from_name='table_a', to_name='table_b'