From 545f220d5d3038e1ab33ea38e39ef949069864d4 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Wed, 8 Aug 2018 00:18:42 +0100
Subject: [PATCH] [AIRFLOW-2867] Refactor Code to conform standards (#3714)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Dictionaries built up key by key are now written as dictionary literals
- Python’s default arguments are evaluated once when the function is defined, not each time the function is called (unlike in, say, Ruby). This means that if you use a mutable default argument and mutate it, you will have mutated that object for all future calls to the function as well, so mutable defaults are replaced with a None sentinel (see the short sketch after the patch)
- Calls that can be replaced by set literals now use set literals
- Lists built with repeated append() calls are replaced by list literals
- Methods that never use instance state are now declared as static methods
- Remove redundant parentheses

---
 airflow/contrib/hooks/bigquery_hook.py | 32 ++- airflow/contrib/hooks/databricks_hook.py | 3 +- airflow/contrib/hooks/datastore_hook.py | 4 +- airflow/contrib/hooks/gcp_container_hook.py | 6 +- airflow/contrib/hooks/gcs_hook.py | 2 +- airflow/contrib/hooks/salesforce_hook.py | 3 +- airflow/contrib/hooks/vertica_hook.py | 2 +- airflow/contrib/kubernetes/pod_launcher.py | 2 +- .../contrib/operators/dataflow_operator.py | 2 +- .../contrib/operators/dataproc_operator.py | 6 +- .../operators/gcp_container_operator.py | 4 +- airflow/contrib/operators/gcs_to_bq.py | 8 +- .../operators/mlengine_prediction_summary.py | 6 +- airflow/contrib/operators/mongo_to_s3.py | 6 +- airflow/contrib/operators/mysql_to_gcs.py | 3 +- .../oracle_to_azure_data_lake_transfer.py | 228 +++++++++--------- .../operators/oracle_to_oracle_transfer.py | 4 +- .../contrib/operators/s3_to_gcs_operator.py | 3 +- .../contrib/sensors/emr_job_flow_sensor.py | 3 +- airflow/contrib/sensors/emr_step_sensor.py | 3 +- .../contrib/task_runner/cgroup_task_runner.py | 3 +- airflow/executors/dask_executor.py | 2 +- airflow/hooks/S3_hook.py | 10 +- airflow/hooks/druid_hook.py | 2 +- airflow/hooks/hive_hooks.py | 3 +- airflow/hooks/presto_hook.py | 3 +- airflow/jobs.py | 2 +- airflow/operators/check_operator.py | 2 +- airflow/operators/hive_stats_operator.py | 3 +- airflow/operators/python_operator.py | 3 +- airflow/operators/s3_to_hive_operator.py | 4 +- airflow/sensors/hdfs_sensor.py | 4 +- airflow/utils/cli.py | 6 +- airflow/utils/helpers.py | 2 +- airflow/utils/log/gcs_task_handler.py | 3 +- airflow/www/views.py | 2 +- airflow/www_rbac/forms.py | 2 +- airflow/www_rbac/views.py | 2 +- dev/airflow-pr | 2 +- scripts/perf/scheduler_ops_metrics.py | 2 +- tests/cli/test_cli.py | 4 +- .../executors/test_kubernetes_executor.py | 6 +- tests/contrib/hooks/test_aws_lambda_hook.py | 3 +- tests/contrib/hooks/test_gcp_mlengine_hook.py | 3 +- tests/contrib/hooks/test_mongo_hook.py | 4 +- tests/contrib/hooks/test_redshift_hook.py | 3 +- .../minikube/test_kubernetes_executor.py | 3 +- .../minikube/test_kubernetes_pod_operator.py | 18 +- .../operators/test_dataproc_operator.py | 24 +- .../test_hive_to_dynamodb_operator.py | 3 +- .../operators/test_mysql_to_gcs_operator.py | 11 +- .../test_oracle_to_oracle_transfer.py | 3 +- .../test_aws_redshift_cluster_sensor.py | 3 +- tests/contrib/sensors/test_emr_base_sensor.py | 24 +- tests/core.py | 2 +- tests/models.py | 6 +- tests/operators/docker_operator.py | 3 +- tests/operators/s3_to_hive_operator.py | 8 +- tests/operators/test_virtualenv_operator.py | 3 +- .../experimental/test_kerberos_endpoints.py | 2 +- .../experimental/test_kerberos_endpoints.py | 2 +- 61 files
changed, 304 insertions(+), 226 deletions(-) diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index aa8fc382a6a67..2a94580f509d1 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -206,7 +206,7 @@ def create_empty_table(self, dataset_id, table_id, schema_fields=None, - time_partitioning={}, + time_partitioning=None, labels=None ): """ @@ -238,6 +238,8 @@ def create_empty_table(self, :return: """ + if time_partitioning is None: + time_partitioning = dict() project_id = project_id if project_id is not None else self.project_id table_resource = { @@ -286,7 +288,7 @@ def create_external_table(self, quote_character=None, allow_quoted_newlines=False, allow_jagged_rows=False, - src_fmt_configs={}, + src_fmt_configs=None, labels=None ): """ @@ -352,6 +354,8 @@ def create_external_table(self, :type labels: dict """ + if src_fmt_configs is None: + src_fmt_configs = {} project_id, dataset_id, external_table_id = \ _split_tablename(table_input=external_project_dataset_table, default_project_id=self.project_id, @@ -482,7 +486,7 @@ def run_query(self, labels=None, schema_update_options=(), priority='INTERACTIVE', - time_partitioning={}): + time_partitioning=None): """ Executes a BigQuery SQL query. Optionally persists results in a BigQuery table. See here: @@ -548,6 +552,8 @@ def run_query(self, """ # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513] + if time_partitioning is None: + time_partitioning = {} sql = bql if sql is None else sql if bql: @@ -808,8 +814,8 @@ def run_load(self, allow_quoted_newlines=False, allow_jagged_rows=False, schema_update_options=(), - src_fmt_configs={}, - time_partitioning={}): + src_fmt_configs=None, + time_partitioning=None): """ Executes a BigQuery load command to load data from Google Cloud Storage to BigQuery. See here: @@ -880,6 +886,10 @@ def run_load(self, # if it's not, we raise a ValueError # Refer to this link for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat + if src_fmt_configs is None: + src_fmt_configs = {} + if time_partitioning is None: + time_partitioning = {} source_format = source_format.upper() allowed_formats = [ "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", @@ -1011,12 +1021,12 @@ def run_with_configuration(self, configuration): # Wait for query to finish. keep_polling_job = True - while (keep_polling_job): + while keep_polling_job: try: job = jobs.get( projectId=self.project_id, jobId=self.running_job_id).execute() - if (job['status']['state'] == 'DONE'): + if job['status']['state'] == 'DONE': keep_polling_job = False # Check if job had errors. 
if 'errorResult' in job['status']: @@ -1045,7 +1055,7 @@ def poll_job_complete(self, job_id): jobs = self.service.jobs() try: job = jobs.get(projectId=self.project_id, jobId=job_id).execute() - if (job['status']['state'] == 'DONE'): + if job['status']['state'] == 'DONE': return True except HttpError as err: if err.resp.status in [500, 503]: @@ -1079,13 +1089,13 @@ def cancel_query(self): polling_attempts = 0 job_complete = False - while (polling_attempts < max_polling_attempts and not job_complete): + while polling_attempts < max_polling_attempts and not job_complete: polling_attempts = polling_attempts + 1 job_complete = self.poll_job_complete(self.running_job_id) - if (job_complete): + if job_complete: self.log.info('Job successfully canceled: %s, %s', self.project_id, self.running_job_id) - elif (polling_attempts == max_polling_attempts): + elif polling_attempts == max_polling_attempts: self.log.info( "Stopping polling due to timeout. Job with id %s " "has not completed cancel and may or may not finish.", diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 2e5f1399b4765..54f00e00907c0 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -65,7 +65,8 @@ def __init__( raise ValueError('Retry limit must be greater than equal to 1') self.retry_limit = retry_limit - def _parse_host(self, host): + @staticmethod + def _parse_host(host): """ The purpose of this function is to be robust to improper connections settings provided by users, specifically in the host field. diff --git a/airflow/contrib/hooks/datastore_hook.py b/airflow/contrib/hooks/datastore_hook.py index 5e54cf2a65384..b8c3ca00a0f46 100644 --- a/airflow/contrib/hooks/datastore_hook.py +++ b/airflow/contrib/hooks/datastore_hook.py @@ -172,7 +172,7 @@ def export_to_storage_bucket(self, bucket, namespace=None, """ Export entities from Cloud Datastore to Cloud Storage for backup """ - output_uri_prefix = 'gs://' + ('/').join(filter(None, [bucket, namespace])) + output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace])) if not entity_filter: entity_filter = {} if not labels: @@ -191,7 +191,7 @@ def import_from_storage_bucket(self, bucket, file, """ Import a backup from Cloud Storage to Cloud Datastore """ - input_url = 'gs://' + ('/').join(filter(None, [bucket, namespace, file])) + input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file])) if not entity_filter: entity_filter = {} if not labels: diff --git a/airflow/contrib/hooks/gcp_container_hook.py b/airflow/contrib/hooks/gcp_container_hook.py index 227cd3d12474e..e5fbda138e0fa 100644 --- a/airflow/contrib/hooks/gcp_container_hook.py +++ b/airflow/contrib/hooks/gcp_container_hook.py @@ -44,7 +44,8 @@ def __init__(self, project_id, location): client_info = ClientInfo(client_library_version='airflow_v' + version.version) self.client = container_v1.ClusterManagerClient(client_info=client_info) - def _dict_to_proto(self, py_dict, proto): + @staticmethod + def _dict_to_proto(py_dict, proto): """ Converts a python dictionary to the proto supplied :param py_dict: The dictionary to convert @@ -90,7 +91,8 @@ def get_operation(self, operation_name): zone=self.location, operation_id=operation_name) - def _append_label(self, cluster_proto, key, val): + @staticmethod + def _append_label(cluster_proto, key, val): """ Append labels to provided Cluster Protobuf diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py index 
08d44ce7faf92..3d42ec4426dee 100644 --- a/airflow/contrib/hooks/gcs_hook.py +++ b/airflow/contrib/hooks/gcs_hook.py @@ -306,7 +306,7 @@ def list(self, bucket, versions=None, maxResults=None, prefix=None, delimiter=No ids = list() pageToken = None - while(True): + while True: response = service.objects().list( bucket=bucket, versions=versions, diff --git a/airflow/contrib/hooks/salesforce_hook.py b/airflow/contrib/hooks/salesforce_hook.py index ee18b353d2e82..24b67f49fc402 100644 --- a/airflow/contrib/hooks/salesforce_hook.py +++ b/airflow/contrib/hooks/salesforce_hook.py @@ -135,7 +135,8 @@ def get_available_fields(self, obj): return [f['name'] for f in desc['fields']] - def _build_field_list(self, fields): + @staticmethod + def _build_field_list(fields): # join all of the fields in a comma separated list return ",".join(fields) diff --git a/airflow/contrib/hooks/vertica_hook.py b/airflow/contrib/hooks/vertica_hook.py index f3411de994d7f..e6b36b51d5937 100644 --- a/airflow/contrib/hooks/vertica_hook.py +++ b/airflow/contrib/hooks/vertica_hook.py @@ -41,9 +41,9 @@ def get_conn(self): "user": conn.login, "password": conn.password or '', "database": conn.schema, + "host": conn.host or 'localhost' } - conn_config["host"] = conn.host or 'localhost' if not conn.port: conn_config["port"] = 5433 else: diff --git a/airflow/contrib/kubernetes/pod_launcher.py b/airflow/contrib/kubernetes/pod_launcher.py index 8ac5108507345..42f2bfea8adec 100644 --- a/airflow/contrib/kubernetes/pod_launcher.py +++ b/airflow/contrib/kubernetes/pod_launcher.py @@ -104,7 +104,7 @@ def _monitor_pod(self, pod, get_logs): while self.pod_is_running(pod): self.log.info('Pod %s has state %s', pod.name, State.RUNNING) time.sleep(2) - return (self._task_status(self.read_pod(pod)), result) + return self._task_status(self.read_pod(pod)), result def _task_status(self, event): self.log.info( diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py index e3c8c1fff1572..fdbb841066a34 100644 --- a/airflow/contrib/operators/dataflow_operator.py +++ b/airflow/contrib/operators/dataflow_operator.py @@ -331,7 +331,7 @@ def execute(self, context): self.py_file, self.py_options) -class GoogleCloudBucketHelper(): +class GoogleCloudBucketHelper(object): """GoogleCloudStorageHook helper class to download GCS object.""" GCS_PREFIX_LENGTH = 5 diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py index 01d137f954900..6dfa2da095e38 100644 --- a/airflow/contrib/operators/dataproc_operator.py +++ b/airflow/contrib/operators/dataproc_operator.py @@ -491,7 +491,8 @@ def _build_scale_cluster_data(self): } return scale_data - def _get_graceful_decommission_timeout(self, timeout): + @staticmethod + def _get_graceful_decommission_timeout(timeout): match = re.match(r"^(\d+)(s|m|h|d)$", timeout) if match: if match.group(2) == "s": @@ -575,7 +576,8 @@ def __init__(self, self.project_id = project_id self.region = region - def _wait_for_done(self, service, operation_name): + @staticmethod + def _wait_for_done(service, operation_name): time.sleep(15) while True: response = service.projects().regions().operations().get( diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py index 615eac8a0f8f6..c99f2a93f2c59 100644 --- a/airflow/contrib/operators/gcp_container_operator.py +++ b/airflow/contrib/operators/gcp_container_operator.py @@ -99,7 +99,7 @@ class GKEClusterCreateOperator(BaseOperator): def 
__init__(self, project_id, location, - body={}, + body=None, gcp_conn_id='google_cloud_default', api_version='v2', *args, @@ -148,6 +148,8 @@ def __init__(self, """ super(GKEClusterCreateOperator, self).__init__(*args, **kwargs) + if body is None: + body = {} self.project_id = project_id self.gcp_conn_id = gcp_conn_id self.location = location diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py index 3a7798030cf49..533cf01de5666 100644 --- a/airflow/contrib/operators/gcs_to_bq.py +++ b/airflow/contrib/operators/gcs_to_bq.py @@ -143,14 +143,18 @@ def __init__(self, google_cloud_storage_conn_id='google_cloud_default', delegate_to=None, schema_update_options=(), - src_fmt_configs={}, + src_fmt_configs=None, external_table=False, - time_partitioning={}, + time_partitioning=None, *args, **kwargs): super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs) # GCS config + if src_fmt_configs is None: + src_fmt_configs = {} + if time_partitioning is None: + time_partitioning = {} self.bucket = bucket self.source_objects = source_objects self.schema_object = schema_object diff --git a/airflow/contrib/operators/mlengine_prediction_summary.py b/airflow/contrib/operators/mlengine_prediction_summary.py index 4efe81e64151f..5dac0a44a9dcb 100644 --- a/airflow/contrib/operators/mlengine_prediction_summary.py +++ b/airflow/contrib/operators/mlengine_prediction_summary.py @@ -102,10 +102,12 @@ def metric_fn(inst): class JsonCoder(object): - def encode(self, x): + @staticmethod + def encode(x): return json.dumps(x) - def decode(self, x): + @staticmethod + def decode(x): return json.loads(x) diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py index 43b5d8b6c357a..8bfa7a52f80bb 100644 --- a/airflow/contrib/operators/mongo_to_s3.py +++ b/airflow/contrib/operators/mongo_to_s3.py @@ -96,7 +96,8 @@ def execute(self, context): return True - def _stringify(self, iterable, joinable='\n'): + @staticmethod + def _stringify(iterable, joinable='\n'): """ Takes an iterable (pymongo Cursor or Array) containing dictionaries and returns a stringified version using python join @@ -105,7 +106,8 @@ def _stringify(self, iterable, joinable='\n'): [json.dumps(doc, default=json_util.default) for doc in iterable] ) - def transform(self, docs): + @staticmethod + def transform(docs): """ Processes pyMongo cursor and returns an iterable with each element being a JSON serializable dictionary diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py index 4d1bb7b329e6c..eb47465e8e8ec 100644 --- a/airflow/contrib/operators/mysql_to_gcs.py +++ b/airflow/contrib/operators/mysql_to_gcs.py @@ -218,7 +218,8 @@ def _upload_to_gcs(self, files_to_upload): for object, tmp_file_handle in files_to_upload.items(): hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json') - def _convert_types(self, schema, col_type_dict, row): + @staticmethod + def _convert_types(schema, col_type_dict, row): """ Takes a value from MySQLdb, and converts it to a value that's safe for JSON/Google cloud storage/BigQuery. Dates are converted to UTC seconds. 
diff --git a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py index 06a3998defa49..80cec8f462d9a 100644 --- a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py +++ b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py @@ -1,113 +1,115 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from airflow.hooks.oracle_hook import OracleHook -from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults -from airflow.utils.file import TemporaryDirectory - -import unicodecsv as csv -import os - - -class OracleToAzureDataLakeTransfer(BaseOperator): - """ - Moves data from Oracle to Azure Data Lake. The operator runs the query against - Oracle and stores the file locally before loading it into Azure Data Lake. - - - :param filename: file name to be used by the csv file. - :type filename: str - :param azure_data_lake_conn_id: destination azure data lake connection. - :type azure_data_lake_conn_id: str - :param azure_data_lake_path: destination path in azure data lake to put the file. - :type azure_data_lake_path: str - :param oracle_conn_id: source Oracle connection. - :type oracle_conn_id: str - :param sql: SQL query to execute against the Oracle database. (templated) - :type sql: str - :param sql_params: Parameters to use in sql query. (templated) - :type sql_params: str - :param delimiter: field delimiter in the file. - :type delimiter: str - :param encoding: enconding type for the file. - :type encoding: str - :param quotechar: Character to use in quoting. - :type quotechar: str - :param quoting: Quoting strategy. See unicodecsv quoting for more information. 
- :type quoting: str - """ - - template_fields = ('filename', 'sql', 'sql_params') - ui_color = '#e08c8c' - - @apply_defaults - def __init__( - self, - filename, - azure_data_lake_conn_id, - azure_data_lake_path, - oracle_conn_id, - sql, - sql_params={}, - delimiter=",", - encoding="utf-8", - quotechar='"', - quoting=csv.QUOTE_MINIMAL, - *args, **kwargs): - super(OracleToAzureDataLakeTransfer, self).__init__(*args, **kwargs) - self.filename = filename - self.oracle_conn_id = oracle_conn_id - self.sql = sql - self.sql_params = sql_params - self.azure_data_lake_conn_id = azure_data_lake_conn_id - self.azure_data_lake_path = azure_data_lake_path - self.delimiter = delimiter - self.encoding = encoding - self.quotechar = quotechar - self.quoting = quoting - - def _write_temp_file(self, cursor, path_to_save): - with open(path_to_save, 'wb') as csvfile: - csv_writer = csv.writer(csvfile, delimiter=self.delimiter, - encoding=self.encoding, quotechar=self.quotechar, - quoting=self.quoting) - csv_writer.writerow(map(lambda field: field[0], cursor.description)) - csv_writer.writerows(cursor) - csvfile.flush() - - def execute(self, context): - oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id) - azure_data_lake_hook = AzureDataLakeHook( - azure_data_lake_conn_id=self.azure_data_lake_conn_id) - - self.log.info("Dumping Oracle query results to local file") - conn = oracle_hook.get_conn() - cursor = conn.cursor() - cursor.execute(self.sql, self.sql_params) - - with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp: - self._write_temp_file(cursor, os.path.join(temp, self.filename)) - self.log.info("Uploading local file to Azure Data Lake") - azure_data_lake_hook.upload_file(os.path.join(temp, self.filename), - os.path.join(self.azure_data_lake_path, - self.filename)) - cursor.close() - conn.close() +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.hooks.oracle_hook import OracleHook +from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.utils.file import TemporaryDirectory + +import unicodecsv as csv +import os + + +class OracleToAzureDataLakeTransfer(BaseOperator): + """ + Moves data from Oracle to Azure Data Lake. The operator runs the query against + Oracle and stores the file locally before loading it into Azure Data Lake. + + + :param filename: file name to be used by the csv file. + :type filename: str + :param azure_data_lake_conn_id: destination azure data lake connection. + :type azure_data_lake_conn_id: str + :param azure_data_lake_path: destination path in azure data lake to put the file. + :type azure_data_lake_path: str + :param oracle_conn_id: source Oracle connection. 
+ :type oracle_conn_id: str + :param sql: SQL query to execute against the Oracle database. (templated) + :type sql: str + :param sql_params: Parameters to use in sql query. (templated) + :type sql_params: str + :param delimiter: field delimiter in the file. + :type delimiter: str + :param encoding: enconding type for the file. + :type encoding: str + :param quotechar: Character to use in quoting. + :type quotechar: str + :param quoting: Quoting strategy. See unicodecsv quoting for more information. + :type quoting: str + """ + + template_fields = ('filename', 'sql', 'sql_params') + ui_color = '#e08c8c' + + @apply_defaults + def __init__( + self, + filename, + azure_data_lake_conn_id, + azure_data_lake_path, + oracle_conn_id, + sql, + sql_params=None, + delimiter=",", + encoding="utf-8", + quotechar='"', + quoting=csv.QUOTE_MINIMAL, + *args, **kwargs): + super(OracleToAzureDataLakeTransfer, self).__init__(*args, **kwargs) + if sql_params is None: + sql_params = {} + self.filename = filename + self.oracle_conn_id = oracle_conn_id + self.sql = sql + self.sql_params = sql_params + self.azure_data_lake_conn_id = azure_data_lake_conn_id + self.azure_data_lake_path = azure_data_lake_path + self.delimiter = delimiter + self.encoding = encoding + self.quotechar = quotechar + self.quoting = quoting + + def _write_temp_file(self, cursor, path_to_save): + with open(path_to_save, 'wb') as csvfile: + csv_writer = csv.writer(csvfile, delimiter=self.delimiter, + encoding=self.encoding, quotechar=self.quotechar, + quoting=self.quoting) + csv_writer.writerow(map(lambda field: field[0], cursor.description)) + csv_writer.writerows(cursor) + csvfile.flush() + + def execute(self, context): + oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id) + azure_data_lake_hook = AzureDataLakeHook( + azure_data_lake_conn_id=self.azure_data_lake_conn_id) + + self.log.info("Dumping Oracle query results to local file") + conn = oracle_hook.get_conn() + cursor = conn.cursor() + cursor.execute(self.sql, self.sql_params) + + with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp: + self._write_temp_file(cursor, os.path.join(temp, self.filename)) + self.log.info("Uploading local file to Azure Data Lake") + azure_data_lake_hook.upload_file(os.path.join(temp, self.filename), + os.path.join(self.azure_data_lake_path, + self.filename)) + cursor.close() + conn.close() diff --git a/airflow/contrib/operators/oracle_to_oracle_transfer.py b/airflow/contrib/operators/oracle_to_oracle_transfer.py index 31eb89b7dded2..1db95f7520bb1 100644 --- a/airflow/contrib/operators/oracle_to_oracle_transfer.py +++ b/airflow/contrib/operators/oracle_to_oracle_transfer.py @@ -52,10 +52,12 @@ def __init__( destination_table, oracle_source_conn_id, source_sql, - source_sql_params={}, + source_sql_params=None, rows_chunk=5000, *args, **kwargs): super(OracleToOracleTransfer, self).__init__(*args, **kwargs) + if source_sql_params is None: + source_sql_params = {} self.oracle_destination_conn_id = oracle_destination_conn_id self.destination_table = destination_table self.oracle_source_conn_id = oracle_source_conn_id diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py index 2898af1071773..64d7dc7cab976 100644 --- a/airflow/contrib/operators/s3_to_gcs_operator.py +++ b/airflow/contrib/operators/s3_to_gcs_operator.py @@ -184,7 +184,8 @@ def execute(self, context): # Following functionality may be better suited in # airflow/contrib/hooks/gcs_hook.py - def 
_gcs_object_is_directory(self, object): + @staticmethod + def _gcs_object_is_directory(object): bucket, blob = _parse_gcs_url(object) return len(blob) == 0 or blob.endswith('/') diff --git a/airflow/contrib/sensors/emr_job_flow_sensor.py b/airflow/contrib/sensors/emr_job_flow_sensor.py index 806b63bda39f4..5a17a012d42c4 100644 --- a/airflow/contrib/sensors/emr_job_flow_sensor.py +++ b/airflow/contrib/sensors/emr_job_flow_sensor.py @@ -50,5 +50,6 @@ def get_emr_response(self): self.log.info('Poking cluster %s', self.job_flow_id) return emr.describe_cluster(ClusterId=self.job_flow_id) - def state_from_response(self, response): + @staticmethod + def state_from_response(response): return response['Cluster']['Status']['State'] diff --git a/airflow/contrib/sensors/emr_step_sensor.py b/airflow/contrib/sensors/emr_step_sensor.py index afdcbcdc8ff49..6e79f6353ff73 100644 --- a/airflow/contrib/sensors/emr_step_sensor.py +++ b/airflow/contrib/sensors/emr_step_sensor.py @@ -53,5 +53,6 @@ def get_emr_response(self): self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id) return emr.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id) - def state_from_response(self, response): + @staticmethod + def state_from_response(response): return response['Step']['Status']['State'] diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py index a97eac2af88ef..faa2407f09a97 100644 --- a/airflow/contrib/task_runner/cgroup_task_runner.py +++ b/airflow/contrib/task_runner/cgroup_task_runner.py @@ -193,7 +193,8 @@ def on_finish(self): if self._created_cpu_cgroup: self._delete_cgroup(self.cpu_cgroup_name) - def _get_cgroup_names(self): + @staticmethod + def _get_cgroup_names(): """ :return: a mapping between the subsystem name to the cgroup name :rtype: dict[str, str] diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py index a6ba677f8bd7c..23c0b636b6e5c 100644 --- a/airflow/executors/dask_executor.py +++ b/airflow/executors/dask_executor.py @@ -43,7 +43,7 @@ def __init__(self, cluster_address=None): super(DaskExecutor, self).__init__(parallelism=0) def start(self): - if (self.tls_ca) or (self.tls_key) or (self.tls_cert): + if self.tls_ca or self.tls_key or self.tls_cert: from distributed.security import Security security = Security( tls_client_key=self.tls_key, diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py index 5e505ca37c24b..2d64b315342fb 100644 --- a/airflow/hooks/S3_hook.py +++ b/airflow/hooks/S3_hook.py @@ -43,7 +43,7 @@ def parse_s3_url(s3url): else: bucket_name = parsed_url.netloc key = parsed_url.path.strip('/') - return (bucket_name, key) + return bucket_name, key def check_for_bucket(self, bucket_name): """ @@ -206,8 +206,8 @@ def read_key(self, key, bucket_name=None): def select_key(self, key, bucket_name=None, expression='SELECT * FROM S3Object', expression_type='SQL', - input_serialization={'CSV': {}}, - output_serialization={'CSV': {}}): + input_serialization=None, + output_serialization=None): """ Reads a key with S3 Select. 
@@ -230,6 +230,10 @@ def select_key(self, key, bucket_name=None, For more details about S3 Select parameters: http://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.select_object_content """ + if input_serialization is None: + input_serialization = {'CSV': {}} + if output_serialization is None: + output_serialization = {'CSV': {}} if not bucket_name: (bucket_name, key) = self.parse_s3_url(key) diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py index 81bcaa0c89068..380a3b3314b12 100644 --- a/airflow/hooks/druid_hook.py +++ b/airflow/hooks/druid_hook.py @@ -69,7 +69,7 @@ def submit_indexing_job(self, json_index_spec): url = self.get_conn_url() req_index = requests.post(url, json=json_index_spec, headers=self.header) - if (req_index.status_code != 200): + if req_index.status_code != 200: raise AirflowException('Did not get 200 when ' 'submitting the Druid job to {}'.format(url)) diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py index 8df96c464fd88..bd6a01a1e6c6b 100644 --- a/airflow/hooks/hive_hooks.py +++ b/airflow/hooks/hive_hooks.py @@ -148,7 +148,8 @@ def _prepare_cli_cmd(self): return [hive_bin] + cmd_extra + hive_params_list - def _prepare_hiveconf(self, d): + @staticmethod + def _prepare_hiveconf(d): """ This function prepares a list of hiveconf params from a dictionary of key value pairs. diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py index d6b5293fc0f97..c7ebd8f5dcb70 100644 --- a/airflow/hooks/presto_hook.py +++ b/airflow/hooks/presto_hook.py @@ -56,7 +56,8 @@ def get_conn(self): def _strip_sql(sql): return sql.strip().rstrip(';') - def _get_pretty_exception_message(self, e): + @staticmethod + def _get_pretty_exception_message(e): """ Parses some DatabaseError to provide a better error message """ diff --git a/airflow/jobs.py b/airflow/jobs.py index cc26feee5371b..4f0bdc6a4f4d5 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -1195,7 +1195,7 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None): 'task_concurrency') if task_concurrency is not None: num_running = task_concurrency_map[ - ((task_instance.dag_id, task_instance.task_id)) + (task_instance.dag_id, task_instance.task_id) ] if num_running >= task_concurrency: diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index 5a31737fd5a6c..a0d213cf6619c 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -147,7 +147,7 @@ def execute(self, context=None): is_numeric_value_check = isinstance(pass_value_conv, float) tolerance_pct_str = None - if (self.tol is not None): + if self.tol is not None: tolerance_pct_str = str(self.tol * 100) + '%' except_temp = ("Test failed.\nPass value:{pass_value_conv}\n" diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py index fe83284193757..b0bb874956a92 100644 --- a/airflow/operators/hive_stats_operator.py +++ b/airflow/operators/hive_stats_operator.py @@ -91,8 +91,7 @@ def __init__( def get_default_exprs(self, col, col_type): if col in self.col_blacklist: return {} - d = {} - d[(col, 'non_null')] = "COUNT({col})" + d = {(col, 'non_null'): "COUNT({col})"} if col_type in ['double', 'int', 'bigint', 'float', 'double']: d[(col, 'sum')] = 'SUM({col})' d[(col, 'min')] = 'MIN({col})' diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 678a3deb7745d..2817f663a8d8d 100644 --- a/airflow/operators/python_operator.py +++ 
b/airflow/operators/python_operator.py @@ -363,7 +363,8 @@ def _generate_pip_install_cmd(self, tmp_dir): cmd = ['{}/bin/pip'.format(tmp_dir), 'install'] return cmd + self.requirements - def _generate_python_cmd(self, tmp_dir, script_filename, + @staticmethod + def _generate_python_cmd(tmp_dir, script_filename, input_filename, output_filename, string_args_filename): # direct path alleviates need to activate return ['{}/bin/python'.format(tmp_dir), script_filename, diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py index 5faaf916b7417..b82ebce6fa295 100644 --- a/airflow/operators/s3_to_hive_operator.py +++ b/airflow/operators/s3_to_hive_operator.py @@ -261,8 +261,8 @@ def _match_headers(self, header_list): else: return True + @staticmethod def _delete_top_row_and_compress( - self, input_file_name, output_file_ext, dest_dir): @@ -275,7 +275,7 @@ def _delete_top_row_and_compress( os_fh_output, fn_output = \ tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir) - with open(input_file_name, 'rb') as f_in,\ + with open(input_file_name, 'rb') as f_in, \ open_fn(fn_output, 'wb') as f_out: f_in.seek(0) next(f_in) diff --git a/airflow/sensors/hdfs_sensor.py b/airflow/sensors/hdfs_sensor.py index c9bac08ecbfe7..d05adef71c276 100644 --- a/airflow/sensors/hdfs_sensor.py +++ b/airflow/sensors/hdfs_sensor.py @@ -39,13 +39,15 @@ class HdfsSensor(BaseSensorOperator): def __init__(self, filepath, hdfs_conn_id='hdfs_default', - ignored_ext=['_COPYING_'], + ignored_ext=None, ignore_copying=True, file_size=None, hook=HDFSHook, *args, **kwargs): super(HdfsSensor, self).__init__(*args, **kwargs) + if ignored_ext is None: + ignored_ext = ['_COPYING_'] self.filepath = filepath self.hdfs_conn_id = hdfs_conn_id self.file_size = file_size diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index 4a1e57a062b4b..32303cd90bd8d 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -94,10 +94,8 @@ def _build_metrics(func_name, namespace): :return: dict with metrics """ - metrics = {'sub_command': func_name} - metrics['start_datetime'] = datetime.utcnow() - metrics['full_command'] = '{}'.format(list(sys.argv)) - metrics['user'] = getpass.getuser() + metrics = {'sub_command': func_name, 'start_datetime': datetime.utcnow(), + 'full_command': '{}'.format(list(sys.argv)), 'user': getpass.getuser()} assert isinstance(namespace, Namespace) tmp_dic = vars(namespace) diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index db58e650d8ce9..45d0217e230ae 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -127,7 +127,7 @@ def chunks(items, chunk_size): """ Yield successive chunks of a given size from a list of items """ - if (chunk_size <= 0): + if chunk_size <= 0: raise ValueError('Chunk size must be a positive integer') for i in range(0, len(items), chunk_size): yield items[i:i + chunk_size] diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py index 8c34792bb2138..e768882ac5df3 100644 --- a/airflow/utils/log/gcs_task_handler.py +++ b/airflow/utils/log/gcs_task_handler.py @@ -164,7 +164,8 @@ def gcs_write(self, log, remote_log_location, append=True): except Exception as e: self.log.error('Could not write logs to %s: %s', remote_log_location, e) - def parse_gcs_url(self, gsurl): + @staticmethod + def parse_gcs_url(gsurl): """ Given a Google Cloud Storage URL (gs:///), returns a tuple containing the corresponding bucket and blob. 
diff --git a/airflow/www/views.py b/airflow/www/views.py index 4cc0c2a9e771f..3e41d2d02a1b9 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1972,7 +1972,7 @@ def task_instances(self, session=None): if dttm: dttm = pendulum.parse(dttm) else: - return ("Error: Invalid execution_date") + return "Error: Invalid execution_date" task_instances = { ti.task_id: alchemy_to_dict(ti) diff --git a/airflow/www_rbac/forms.py b/airflow/www_rbac/forms.py index da9d12c7adfd9..61c34888e3568 100644 --- a/airflow/www_rbac/forms.py +++ b/airflow/www_rbac/forms.py @@ -93,7 +93,7 @@ class ConnectionForm(DynamicForm): widget=BS3TextFieldWidget()) conn_type = SelectField( lazy_gettext('Conn Type'), - choices=(models.Connection._types), + choices=models.Connection._types, widget=Select2Widget()) host = StringField( lazy_gettext('Host'), diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index 4673def3a76c3..a9947ae096be3 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -1718,7 +1718,7 @@ def task_instances(self, session=None): if dttm: dttm = pendulum.parse(dttm) else: - return ("Error: Invalid execution_date") + return "Error: Invalid execution_date" task_instances = { ti.task_id: alchemy_to_dict(ti) diff --git a/dev/airflow-pr b/dev/airflow-pr index 4caa520d1f1aa..28fc300939d78 100755 --- a/dev/airflow-pr +++ b/dev/airflow-pr @@ -758,7 +758,7 @@ def standardize_jira_ref(text, only_jira=False): # Cleanup any remaining symbols: pattern = re.compile(r'^\W+(.*)', re.IGNORECASE) - if (pattern.search(text) is not None): + if pattern.search(text) is not None: text = pattern.search(text).groups()[0] def unique(seq): diff --git a/scripts/perf/scheduler_ops_metrics.py b/scripts/perf/scheduler_ops_metrics.py index 7928649977f8f..2658288856489 100644 --- a/scripts/perf/scheduler_ops_metrics.py +++ b/scripts/perf/scheduler_ops_metrics.py @@ -126,7 +126,7 @@ def heartbeat(self): if (len(successful_tis) == num_task_instances or (timezone.utcnow() - self.start_date).total_seconds() > MAX_RUNTIME_SECS): - if (len(successful_tis) == num_task_instances): + if len(successful_tis) == num_task_instances: self.log.info("All tasks processed! Printing stats.") else: self.log.info("Test timeout reached. 
" diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 34c82bcf9b683..616b9a0f16da4 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -67,12 +67,14 @@ def create_mock_args( ignore_dependencies=False, force=False, run_as_user=None, - executor_config={}, + executor_config=None, cfg_path=None, pickle=None, raw=None, interactive=None, ): + if executor_config is None: + executor_config = {} args = MagicMock(spec=Namespace) args.task_id = task_id args.dag_id = dag_id diff --git a/tests/contrib/executors/test_kubernetes_executor.py b/tests/contrib/executors/test_kubernetes_executor.py index d9da48ce3b0ce..c203e18d5cf8e 100644 --- a/tests/contrib/executors/test_kubernetes_executor.py +++ b/tests/contrib/executors/test_kubernetes_executor.py @@ -28,7 +28,8 @@ class TestAirflowKubernetesScheduler(unittest.TestCase): - def _gen_random_string(self, str_len): + @staticmethod + def _gen_random_string(str_len): return ''.join([random.choice(string.printable) for _ in range(str_len)]) def _cases(self): @@ -47,7 +48,8 @@ def _cases(self): return cases - def _is_valid_name(self, name): + @staticmethod + def _is_valid_name(name): regex = "^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$" return ( len(name) <= 253 and diff --git a/tests/contrib/hooks/test_aws_lambda_hook.py b/tests/contrib/hooks/test_aws_lambda_hook.py index a35a8fcd1e269..0b9744cd22285 100644 --- a/tests/contrib/hooks/test_aws_lambda_hook.py +++ b/tests/contrib/hooks/test_aws_lambda_hook.py @@ -42,7 +42,8 @@ def test_get_conn_returns_a_boto3_connection(self): function_name="test_function", region_name="us-east-1") self.assertIsNotNone(hook.get_conn()) - def lambda_function(self): + @staticmethod + def lambda_function(): code = textwrap.dedent(""" def lambda_handler(event, context): return event diff --git a/tests/contrib/hooks/test_gcp_mlengine_hook.py b/tests/contrib/hooks/test_gcp_mlengine_hook.py index bb3c5b62586db..c3bc7a9c0da0a 100644 --- a/tests/contrib/hooks/test_gcp_mlengine_hook.py +++ b/tests/contrib/hooks/test_gcp_mlengine_hook.py @@ -61,7 +61,8 @@ def __init__(self, test_cls, responses, expected_requests): for x in expected_requests] self._actual_requests = [] - def _normalize_requests_for_comparison(self, uri, http_method, body): + @staticmethod + def _normalize_requests_for_comparison(uri, http_method, body): parts = urlparse(uri) return ( parts._replace(query=set(parse_qsl(parts.query))), diff --git a/tests/contrib/hooks/test_mongo_hook.py b/tests/contrib/hooks/test_mongo_hook.py index 00fe0f0ef6e08..3b705f1a45fba 100644 --- a/tests/contrib/hooks/test_mongo_hook.py +++ b/tests/contrib/hooks/test_mongo_hook.py @@ -23,10 +23,10 @@ class MongoHookTest(MongoHook): - ''' + """ Extending hook so that a mockmongo collection object can be passed in to get_collection() - ''' + """ def __init__(self, conn_id='mongo_default', *args, **kwargs): super(MongoHookTest, self).__init__(conn_id=conn_id, *args, **kwargs) diff --git a/tests/contrib/hooks/test_redshift_hook.py b/tests/contrib/hooks/test_redshift_hook.py index c69ed8a9dca2a..029dfd38016c5 100644 --- a/tests/contrib/hooks/test_redshift_hook.py +++ b/tests/contrib/hooks/test_redshift_hook.py @@ -35,7 +35,8 @@ class TestRedshiftHook(unittest.TestCase): def setUp(self): configuration.load_test_config() - def _create_clusters(self): + @staticmethod + def _create_clusters(): client = boto3.client('redshift', region_name='us-east-1') client.create_cluster( ClusterIdentifier='test_cluster', diff --git 
a/tests/contrib/minikube/test_kubernetes_executor.py b/tests/contrib/minikube/test_kubernetes_executor.py index 769baae00fc49..45d4124d07973 100644 --- a/tests/contrib/minikube/test_kubernetes_executor.py +++ b/tests/contrib/minikube/test_kubernetes_executor.py @@ -42,7 +42,8 @@ def get_minikube_host(): class KubernetesExecutorTest(unittest.TestCase): - def _delete_airflow_pod(self): + @staticmethod + def _delete_airflow_pod(): air_pod = check_output(['kubectl', 'get', 'pods']).decode() air_pod = air_pod.split('\n') names = [re.compile('\s+').split(x)[0] for x in air_pod if 'airflow' in x] diff --git a/tests/contrib/minikube/test_kubernetes_pod_operator.py b/tests/contrib/minikube/test_kubernetes_pod_operator.py index 531343e674da9..5cb02d1ff1ba6 100644 --- a/tests/contrib/minikube/test_kubernetes_pod_operator.py +++ b/tests/contrib/minikube/test_kubernetes_pod_operator.py @@ -38,7 +38,8 @@ class KubernetesPodOperatorTest(unittest.TestCase): - def test_config_path_move(self): + @staticmethod + def test_config_path_move(): new_config_path = '/tmp/kube_config' old_config_path = os.path.expanduser('~/.kube/config') shutil.copy(old_config_path, new_config_path) @@ -79,7 +80,8 @@ def test_config_path(self, client_mock, launcher_mock): cluster_context='default', config_file=file_path) - def test_working_pod(self): + @staticmethod + def test_working_pod(): k = KubernetesPodOperator( namespace='default', image="ubuntu:16.04", @@ -91,7 +93,8 @@ def test_working_pod(self): ) k.execute(None) - def test_pod_node_selectors(self): + @staticmethod + def test_pod_node_selectors(): node_selectors = { 'beta.kubernetes.io/os': 'linux' } @@ -108,7 +111,8 @@ def test_pod_node_selectors(self): ) k.execute(None) - def test_pod_affinity(self): + @staticmethod + def test_pod_affinity(): affinity = { 'nodeAffinity': { 'requiredDuringSchedulingIgnoredDuringExecution': { @@ -139,7 +143,8 @@ def test_pod_affinity(self): ) k.execute(None) - def test_logging(self): + @staticmethod + def test_logging(): with mock.patch.object(PodLauncher, 'log') as mock_logger: k = KubernetesPodOperator( namespace='default', @@ -154,7 +159,8 @@ def test_logging(self): k.execute(None) mock_logger.info.assert_any_call(b"+ echo 10\n") - def test_volume_mount(self): + @staticmethod + def test_volume_mount(): with mock.patch.object(PodLauncher, 'log') as mock_logger: volume_mount = VolumeMount('test-volume', mount_path='/root/mount_file', diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py index 038cf6142eebd..e5cc770321b52 100644 --- a/tests/contrib/operators/test_dataproc_operator.py +++ b/tests/contrib/operators/test_dataproc_operator.py @@ -454,7 +454,8 @@ def test_cluster_name_log_sub(self): class DataProcHadoopOperatorTest(unittest.TestCase): # Unit test for the DataProcHadoopOperator - def test_hook_correct_region(self): + @staticmethod + def test_hook_correct_region(): with patch(HOOK) as mock_hook: dataproc_task = DataProcHadoopOperator( task_id=TASK_ID, @@ -465,7 +466,8 @@ def test_hook_correct_region(self): mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION) - def test_dataproc_job_id_is_set(self): + @staticmethod + def test_dataproc_job_id_is_set(): with patch(HOOK) as mock_hook: dataproc_task = DataProcHadoopOperator( task_id=TASK_ID @@ -476,7 +478,8 @@ def test_dataproc_job_id_is_set(self): class DataProcHiveOperatorTest(unittest.TestCase): # Unit test for the DataProcHiveOperator - def test_hook_correct_region(self): + @staticmethod + 
def test_hook_correct_region(): with patch(HOOK) as mock_hook: dataproc_task = DataProcHiveOperator( task_id=TASK_ID, @@ -487,7 +490,8 @@ def test_hook_correct_region(self): mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION) - def test_dataproc_job_id_is_set(self): + @staticmethod + def test_dataproc_job_id_is_set(): with patch(HOOK) as mock_hook: dataproc_task = DataProcHiveOperator( task_id=TASK_ID @@ -498,7 +502,8 @@ def test_dataproc_job_id_is_set(self): class DataProcPySparkOperatorTest(unittest.TestCase): # Unit test for the DataProcPySparkOperator - def test_hook_correct_region(self): + @staticmethod + def test_hook_correct_region(): with patch(HOOK) as mock_hook: dataproc_task = DataProcPySparkOperator( task_id=TASK_ID, @@ -510,7 +515,8 @@ def test_hook_correct_region(self): mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION) - def test_dataproc_job_id_is_set(self): + @staticmethod + def test_dataproc_job_id_is_set(): with patch(HOOK) as mock_hook: dataproc_task = DataProcPySparkOperator( task_id=TASK_ID, @@ -522,7 +528,8 @@ def test_dataproc_job_id_is_set(self): class DataProcSparkOperatorTest(unittest.TestCase): # Unit test for the DataProcSparkOperator - def test_hook_correct_region(self): + @staticmethod + def test_hook_correct_region(): with patch(HOOK) as mock_hook: dataproc_task = DataProcSparkOperator( task_id=TASK_ID, @@ -533,7 +540,8 @@ def test_hook_correct_region(self): mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY, REGION) - def test_dataproc_job_id_is_set(self): + @staticmethod + def test_dataproc_job_id_is_set(): with patch(HOOK) as mock_hook: dataproc_task = DataProcSparkOperator( task_id=TASK_ID diff --git a/tests/contrib/operators/test_hive_to_dynamodb_operator.py b/tests/contrib/operators/test_hive_to_dynamodb_operator.py index d1f75b50c40eb..e5b2c3e65a257 100644 --- a/tests/contrib/operators/test_hive_to_dynamodb_operator.py +++ b/tests/contrib/operators/test_hive_to_dynamodb_operator.py @@ -52,7 +52,8 @@ def setUp(self): self.hook = AwsDynamoDBHook( aws_conn_id='aws_default', region_name='us-east-1') - def process_data(self, data, *args, **kwargs): + @staticmethod + def process_data(data, *args, **kwargs): return json.loads(data.to_json(orient='records')) @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') diff --git a/tests/contrib/operators/test_mysql_to_gcs_operator.py b/tests/contrib/operators/test_mysql_to_gcs_operator.py index 6e2e3f90dfb95..f0eb709500ca0 100644 --- a/tests/contrib/operators/test_mysql_to_gcs_operator.py +++ b/tests/contrib/operators/test_mysql_to_gcs_operator.py @@ -27,7 +27,8 @@ class MySqlToGoogleCloudStorageOperatorTest(unittest.TestCase): - def test_write_local_data_files(self): + @staticmethod + def test_write_local_data_files(): # Configure task_id = "some_test_id" @@ -35,17 +36,15 @@ def test_write_local_data_files(self): bucket = "some_bucket" filename = "some_filename" row_iter = [[1, b'byte_str_1'], [2, b'byte_str_2']] - schema = [] - schema.append({ + schema = [{ 'name': 'location', 'type': 'STRING', 'mode': 'nullable', - }) - schema.append({ + }, { 'name': 'uuid', 'type': 'BYTES', 'mode': 'nullable', - }) + }] schema_str = json.dumps(schema) op = MySqlToGoogleCloudStorageOperator( diff --git a/tests/contrib/operators/test_oracle_to_oracle_transfer.py b/tests/contrib/operators/test_oracle_to_oracle_transfer.py index 83d25e05a30f0..9c738dc8f76da 100644 --- a/tests/contrib/operators/test_oracle_to_oracle_transfer.py 
+++ b/tests/contrib/operators/test_oracle_to_oracle_transfer.py @@ -34,7 +34,8 @@ class OracleToOracleTransferTest(unittest.TestCase): - def test_execute(self): + @staticmethod + def test_execute(): oracle_destination_conn_id = 'oracle_destination_conn_id' destination_table = 'destination_table' oracle_source_conn_id = 'oracle_source_conn_id' diff --git a/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py b/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py index 921b2dac1fae9..95064c2655ac0 100644 --- a/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py +++ b/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py @@ -35,7 +35,8 @@ class TestAwsRedshiftClusterSensor(unittest.TestCase): def setUp(self): configuration.load_test_config() - def _create_cluster(self): + @staticmethod + def _create_cluster(): client = boto3.client('redshift', region_name='us-east-1') client.create_cluster( ClusterIdentifier='test_cluster', diff --git a/tests/contrib/sensors/test_emr_base_sensor.py b/tests/contrib/sensors/test_emr_base_sensor.py index 8d00db713fb2d..2215edd09e0f6 100644 --- a/tests/contrib/sensors/test_emr_base_sensor.py +++ b/tests/contrib/sensors/test_emr_base_sensor.py @@ -33,13 +33,15 @@ class EmrBaseSensorSubclass(EmrBaseSensor): NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE'] FAILED_STATE = ['FAILED'] - def get_emr_response(self): + @staticmethod + def get_emr_response(): return { 'SomeKey': {'State': 'COMPLETED'}, 'ResponseMetadata': {'HTTPStatusCode': 200} } - def state_from_response(self, response): + @staticmethod + def state_from_response(response): return response['SomeKey']['State'] operator = EmrBaseSensorSubclass( @@ -56,13 +58,15 @@ class EmrBaseSensorSubclass(EmrBaseSensor): NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE'] FAILED_STATE = ['FAILED'] - def get_emr_response(self): + @staticmethod + def get_emr_response(): return { 'SomeKey': {'State': 'PENDING'}, 'ResponseMetadata': {'HTTPStatusCode': 200} } - def state_from_response(self, response): + @staticmethod + def state_from_response(response): return response['SomeKey']['State'] operator = EmrBaseSensorSubclass( @@ -79,13 +83,15 @@ class EmrBaseSensorSubclass(EmrBaseSensor): NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE'] FAILED_STATE = ['FAILED'] - def get_emr_response(self): + @staticmethod + def get_emr_response(): return { 'SomeKey': {'State': 'COMPLETED'}, 'ResponseMetadata': {'HTTPStatusCode': 400} } - def state_from_response(self, response): + @staticmethod + def state_from_response(response): return response['SomeKey']['State'] operator = EmrBaseSensorSubclass( @@ -102,13 +108,15 @@ class EmrBaseSensorSubclass(EmrBaseSensor): NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE'] FAILED_STATE = ['FAILED'] - def get_emr_response(self): + @staticmethod + def get_emr_response(): return { 'SomeKey': {'State': 'FAILED'}, 'ResponseMetadata': {'HTTPStatusCode': 200} } - def state_from_response(self, response): + @staticmethod + def state_from_response(response): return response['SomeKey']['State'] operator = EmrBaseSensorSubclass( diff --git a/tests/core.py b/tests/core.py index b0471bc807d25..384eddcff1a7d 100644 --- a/tests/core.py +++ b/tests/core.py @@ -2463,7 +2463,7 @@ def test_init_proxy_user(self): class HDFSHookTest(unittest.TestCase): def setUp(self): configuration.load_test_config() - os.environ['AIRFLOW_CONN_HDFS_DEFAULT'] = ('hdfs://localhost:8020') + os.environ['AIRFLOW_CONN_HDFS_DEFAULT'] = 'hdfs://localhost:8020' def test_get_client(self): client = 
HDFSHook(proxy_user='foo').get_conn() diff --git a/tests/models.py b/tests/models.py index 914b8bc6c4cc0..473cecb10080b 100644 --- a/tests/models.py +++ b/tests/models.py @@ -274,7 +274,7 @@ def test_dag_task_priority_weight_total(self): match = pattern.match(task.task_id) task_depth = int(match.group(1)) # the sum of each stages after this task + itself - correct_weight = ((task_depth) * width + 1) * weight + correct_weight = (task_depth * width + 1) * weight calculated_weight = task.priority_weight_total self.assertEquals(calculated_weight, correct_weight) @@ -1105,7 +1105,7 @@ def process_dag(self, create_dag): dagbag = models.DagBag(include_examples=False) found_dags = dagbag.process_file(f.name) - return (dagbag, found_dags, f.name) + return dagbag, found_dags, f.name def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True): @@ -2411,7 +2411,7 @@ def test_xcom_enable_pickle_type(self): def test_xcom_disable_pickle_type_fail_on_non_json(self): class PickleRce(object): def __reduce__(self): - return (os.system, ("ls -alt",)) + return os.system, ("ls -alt",) configuration.set("core", "xcom_enable_pickling", "False") diff --git a/tests/operators/docker_operator.py b/tests/operators/docker_operator.py index 78a920c30cc13..59d6d5841642d 100644 --- a/tests/operators/docker_operator.py +++ b/tests/operators/docker_operator.py @@ -149,7 +149,8 @@ def test_execute_container_fails(self, client_class_mock): with self.assertRaises(AirflowException): operator.execute(None) - def test_on_kill(self): + @staticmethod + def test_on_kill(): client_mock = mock.Mock(spec=APIClient) operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest') diff --git a/tests/operators/s3_to_hive_operator.py b/tests/operators/s3_to_hive_operator.py index 6ca6274a2c173..3e41454395fa3 100644 --- a/tests/operators/s3_to_hive_operator.py +++ b/tests/operators/s3_to_hive_operator.py @@ -145,15 +145,17 @@ def _get_fn(self, ext, header): key = self._get_key(ext, header) return self.fn[key] - def _get_key(self, ext, header): + @staticmethod + def _get_key(ext, header): key = ext + "_" + ('h' if header else 'nh') return key - def _check_file_equality(self, fn_1, fn_2, ext): + @staticmethod + def _check_file_equality(fn_1, fn_2, ext): # gz files contain mtime and filename in the header that # causes filecmp to return False even if contents are identical # Hence decompress to test for equality - if(ext.lower() == '.gz'): + if ext.lower() == '.gz': with gzip.GzipFile(fn_1, 'rb') as f_1,\ NamedTemporaryFile(mode='wb') as f_txt_1,\ gzip.GzipFile(fn_2, 'rb') as f_2,\ diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py index 82abe9c8098de..8196d636f4c1f 100644 --- a/tests/operators/test_virtualenv_operator.py +++ b/tests/operators/test_virtualenv_operator.py @@ -140,7 +140,8 @@ def f(): raise Exception self._run_as_operator(f, python_version=3, use_dill=False, requirements=['dill']) - def _invert_python_major_version(self): + @staticmethod + def _invert_python_major_version(): if sys.version_info[0] == 2: return 3 else: diff --git a/tests/www/api/experimental/test_kerberos_endpoints.py b/tests/www/api/experimental/test_kerberos_endpoints.py index 9179cdecc027a..1cf30635fa407 100644 --- a/tests/www/api/experimental/test_kerberos_endpoints.py +++ b/tests/www/api/experimental/test_kerberos_endpoints.py @@ -65,7 +65,7 @@ def test_trigger_dag(self): response.url = 'http://{}'.format(get_hostname()) - class Request(): + class 
Request: headers = {} response.request = Request() diff --git a/tests/www_rbac/api/experimental/test_kerberos_endpoints.py b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py index 69a8117868f3e..54bbd865b3724 100644 --- a/tests/www_rbac/api/experimental/test_kerberos_endpoints.py +++ b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py @@ -64,7 +64,7 @@ def test_trigger_dag(self): response.url = 'http://{}'.format(socket.getfqdn()) - class Request(): + class Request(object): headers = {} response.request = Request()
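
The patch above repeatedly swaps mutable default arguments (e.g. time_partitioning={}, src_fmt_configs={}, sql_params={}) for a None sentinel plus an explicit check inside the function body. For readers unfamiliar with the pitfall the commit message describes, here is a minimal, self-contained sketch; the function names are illustrative only and do not come from the patch.

# Standalone sketch (not part of the patch): the mutable-default-argument
# pitfall that motivates the time_partitioning=None / src_fmt_configs=None
# changes above. Function names are made up for illustration.

def append_bad(item, bucket=[]):
    # The default list is created once, when the function is defined,
    # and is then shared by every call that relies on the default.
    bucket.append(item)
    return bucket

def append_good(item, bucket=None):
    # The pattern this commit applies: a None sentinel plus an explicit
    # check, so each call that omits `bucket` gets a fresh list.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_bad('a'))   # ['a']
print(append_bad('b'))   # ['a', 'b']  <- same list object reused across calls
print(append_good('a'))  # ['a']
print(append_good('b'))  # ['b']       <- each call starts from an empty list

Using None as the sentinel keeps the signature honest: the real default is an empty dict or list created anew on every call, which is exactly what callers of these hooks and operators expect.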