[AIRFLOW-2867] Refactor Code to conform standards (apache#3714)
- Dictionary creation is now written with dictionary literals
- Python’s default arguments are evaluated once when the function is defined, not each time the function is called (as they are in, say, Ruby). This means that if you use a mutable default argument and mutate it, you will have mutated that object for all future calls to the function as well (see the sketch after this list).
- Calls to set() that can be expressed as set literals are now written as set literals
- Redundant list constructions are likewise replaced with list literals
- Methods that never use instance state are now marked as static
- Remove redundant parentheses
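
To illustrate the default-argument point above, here is a minimal, hypothetical sketch (the helper names are made up for this example and do not appear in this commit) of the pitfall and of the None-sentinel idiom that the diffs below apply:

def append_bad(item, bucket=[]):
    # The default list is built once, at definition time, so every call
    # that omits `bucket` mutates the very same object.
    bucket.append(item)
    return bucket

def append_good(item, bucket=None):
    # A None sentinel defers creation to call time, so each call gets a
    # fresh list unless the caller passes one in.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_bad('a'))   # ['a']
print(append_bad('b'))   # ['a', 'b']  <- state leaked from the previous call
print(append_good('a'))  # ['a']
print(append_good('b'))  # ['b']

The @staticmethod, literal, and parentheses changes in the diffs below are mechanical and need no such guard.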
kaxil authored and Alice Berard committed Jan 3, 2019
1 parent e730c7a commit 545f220
Showing 61 changed files with 304 additions and 226 deletions.
32 changes: 21 additions & 11 deletions airflow/contrib/hooks/bigquery_hook.py
@@ -206,7 +206,7 @@ def create_empty_table(self,
dataset_id,
table_id,
schema_fields=None,
- time_partitioning={},
+ time_partitioning=None,
labels=None
):
"""
@@ -238,6 +238,8 @@ def create_empty_table(self,
:return:
"""
+ if time_partitioning is None:
+ time_partitioning = dict()
project_id = project_id if project_id is not None else self.project_id

table_resource = {
@@ -286,7 +288,7 @@ def create_external_table(self,
quote_character=None,
allow_quoted_newlines=False,
allow_jagged_rows=False,
- src_fmt_configs={},
+ src_fmt_configs=None,
labels=None
):
"""
@@ -352,6 +354,8 @@ def create_external_table(self,
:type labels: dict
"""

+ if src_fmt_configs is None:
+ src_fmt_configs = {}
project_id, dataset_id, external_table_id = \
_split_tablename(table_input=external_project_dataset_table,
default_project_id=self.project_id,
@@ -482,7 +486,7 @@ def run_query(self,
labels=None,
schema_update_options=(),
priority='INTERACTIVE',
- time_partitioning={}):
+ time_partitioning=None):
"""
Executes a BigQuery SQL query. Optionally persists results in a BigQuery
table. See here:
@@ -548,6 +552,8 @@ def run_query(self,
"""

# TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
+ if time_partitioning is None:
+ time_partitioning = {}
sql = bql if sql is None else sql

if bql:
@@ -808,8 +814,8 @@ def run_load(self,
allow_quoted_newlines=False,
allow_jagged_rows=False,
schema_update_options=(),
- src_fmt_configs={},
- time_partitioning={}):
+ src_fmt_configs=None,
+ time_partitioning=None):
"""
Executes a BigQuery load command to load data from Google Cloud Storage
to BigQuery. See here:
@@ -880,6 +886,10 @@ def run_load(self,
# if it's not, we raise a ValueError
# Refer to this link for more details:
# https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat
+ if src_fmt_configs is None:
+ src_fmt_configs = {}
+ if time_partitioning is None:
+ time_partitioning = {}
source_format = source_format.upper()
allowed_formats = [
"CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
@@ -1011,12 +1021,12 @@ def run_with_configuration(self, configuration):

# Wait for query to finish.
keep_polling_job = True
- while (keep_polling_job):
+ while keep_polling_job:
try:
job = jobs.get(
projectId=self.project_id,
jobId=self.running_job_id).execute()
- if (job['status']['state'] == 'DONE'):
+ if job['status']['state'] == 'DONE':
keep_polling_job = False
# Check if job had errors.
if 'errorResult' in job['status']:
@@ -1045,7 +1055,7 @@ def poll_job_complete(self, job_id):
jobs = self.service.jobs()
try:
job = jobs.get(projectId=self.project_id, jobId=job_id).execute()
- if (job['status']['state'] == 'DONE'):
+ if job['status']['state'] == 'DONE':
return True
except HttpError as err:
if err.resp.status in [500, 503]:
@@ -1079,13 +1089,13 @@ def cancel_query(self):
polling_attempts = 0

job_complete = False
- while (polling_attempts < max_polling_attempts and not job_complete):
+ while polling_attempts < max_polling_attempts and not job_complete:
polling_attempts = polling_attempts + 1
job_complete = self.poll_job_complete(self.running_job_id)
- if (job_complete):
+ if job_complete:
self.log.info('Job successfully canceled: %s, %s',
self.project_id, self.running_job_id)
- elif (polling_attempts == max_polling_attempts):
+ elif polling_attempts == max_polling_attempts:
self.log.info(
"Stopping polling due to timeout. Job with id %s "
"has not completed cancel and may or may not finish.",
3 changes: 2 additions & 1 deletion airflow/contrib/hooks/databricks_hook.py
@@ -65,7 +65,8 @@ def __init__(
raise ValueError('Retry limit must be greater than equal to 1')
self.retry_limit = retry_limit

- def _parse_host(self, host):
+ @staticmethod
+ def _parse_host(host):
"""
The purpose of this function is to be robust to improper connections
settings provided by users, specifically in the host field.
4 changes: 2 additions & 2 deletions airflow/contrib/hooks/datastore_hook.py
@@ -172,7 +172,7 @@ def export_to_storage_bucket(self, bucket, namespace=None,
"""
Export entities from Cloud Datastore to Cloud Storage for backup
"""
- output_uri_prefix = 'gs://' + ('/').join(filter(None, [bucket, namespace]))
+ output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace]))
if not entity_filter:
entity_filter = {}
if not labels:
@@ -191,7 +191,7 @@ def import_from_storage_bucket(self, bucket, file,
"""
Import a backup from Cloud Storage to Cloud Datastore
"""
- input_url = 'gs://' + ('/').join(filter(None, [bucket, namespace, file]))
+ input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file]))
if not entity_filter:
entity_filter = {}
if not labels:
6 changes: 4 additions & 2 deletions airflow/contrib/hooks/gcp_container_hook.py
@@ -44,7 +44,8 @@ def __init__(self, project_id, location):
client_info = ClientInfo(client_library_version='airflow_v' + version.version)
self.client = container_v1.ClusterManagerClient(client_info=client_info)

- def _dict_to_proto(self, py_dict, proto):
+ @staticmethod
+ def _dict_to_proto(py_dict, proto):
"""
Converts a python dictionary to the proto supplied
:param py_dict: The dictionary to convert
@@ -90,7 +91,8 @@ def get_operation(self, operation_name):
zone=self.location,
operation_id=operation_name)

- def _append_label(self, cluster_proto, key, val):
+ @staticmethod
+ def _append_label(cluster_proto, key, val):
"""
Append labels to provided Cluster Protobuf
2 changes: 1 addition & 1 deletion airflow/contrib/hooks/gcs_hook.py
@@ -306,7 +306,7 @@ def list(self, bucket, versions=None, maxResults=None, prefix=None, delimiter=No

ids = list()
pageToken = None
- while(True):
+ while True:
response = service.objects().list(
bucket=bucket,
versions=versions,
3 changes: 2 additions & 1 deletion airflow/contrib/hooks/salesforce_hook.py
@@ -135,7 +135,8 @@ def get_available_fields(self, obj):

return [f['name'] for f in desc['fields']]

- def _build_field_list(self, fields):
+ @staticmethod
+ def _build_field_list(fields):
# join all of the fields in a comma separated list
return ",".join(fields)

2 changes: 1 addition & 1 deletion airflow/contrib/hooks/vertica_hook.py
@@ -41,9 +41,9 @@ def get_conn(self):
"user": conn.login,
"password": conn.password or '',
"database": conn.schema,
"host": conn.host or 'localhost'
}

conn_config["host"] = conn.host or 'localhost'
if not conn.port:
conn_config["port"] = 5433
else:
2 changes: 1 addition & 1 deletion airflow/contrib/kubernetes/pod_launcher.py
@@ -104,7 +104,7 @@ def _monitor_pod(self, pod, get_logs):
while self.pod_is_running(pod):
self.log.info('Pod %s has state %s', pod.name, State.RUNNING)
time.sleep(2)
- return (self._task_status(self.read_pod(pod)), result)
+ return self._task_status(self.read_pod(pod)), result

def _task_status(self, event):
self.log.info(
2 changes: 1 addition & 1 deletion airflow/contrib/operators/dataflow_operator.py
@@ -331,7 +331,7 @@ def execute(self, context):
self.py_file, self.py_options)


- class GoogleCloudBucketHelper():
+ class GoogleCloudBucketHelper(object):
"""GoogleCloudStorageHook helper class to download GCS object."""
GCS_PREFIX_LENGTH = 5

6 changes: 4 additions & 2 deletions airflow/contrib/operators/dataproc_operator.py
@@ -491,7 +491,8 @@ def _build_scale_cluster_data(self):
}
return scale_data

- def _get_graceful_decommission_timeout(self, timeout):
+ @staticmethod
+ def _get_graceful_decommission_timeout(timeout):
match = re.match(r"^(\d+)(s|m|h|d)$", timeout)
if match:
if match.group(2) == "s":
@@ -575,7 +576,8 @@ def __init__(self,
self.project_id = project_id
self.region = region

- def _wait_for_done(self, service, operation_name):
+ @staticmethod
+ def _wait_for_done(service, operation_name):
time.sleep(15)
while True:
response = service.projects().regions().operations().get(
4 changes: 3 additions & 1 deletion airflow/contrib/operators/gcp_container_operator.py
@@ -99,7 +99,7 @@ class GKEClusterCreateOperator(BaseOperator):
def __init__(self,
project_id,
location,
- body={},
+ body=None,
gcp_conn_id='google_cloud_default',
api_version='v2',
*args,
@@ -148,6 +148,8 @@ def __init__(self,
"""
super(GKEClusterCreateOperator, self).__init__(*args, **kwargs)

+ if body is None:
+ body = {}
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.location = location
8 changes: 6 additions & 2 deletions airflow/contrib/operators/gcs_to_bq.py
@@ -143,14 +143,18 @@ def __init__(self,
google_cloud_storage_conn_id='google_cloud_default',
delegate_to=None,
schema_update_options=(),
- src_fmt_configs={},
+ src_fmt_configs=None,
external_table=False,
- time_partitioning={},
+ time_partitioning=None,
*args, **kwargs):

super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)

# GCS config
+ if src_fmt_configs is None:
+ src_fmt_configs = {}
+ if time_partitioning is None:
+ time_partitioning = {}
self.bucket = bucket
self.source_objects = source_objects
self.schema_object = schema_object
6 changes: 4 additions & 2 deletions airflow/contrib/operators/mlengine_prediction_summary.py
@@ -102,10 +102,12 @@ def metric_fn(inst):


class JsonCoder(object):
- def encode(self, x):
+ @staticmethod
+ def encode(x):
return json.dumps(x)

- def decode(self, x):
+ @staticmethod
+ def decode(x):
return json.loads(x)


6 changes: 4 additions & 2 deletions airflow/contrib/operators/mongo_to_s3.py
@@ -96,7 +96,8 @@ def execute(self, context):

return True

- def _stringify(self, iterable, joinable='\n'):
+ @staticmethod
+ def _stringify(iterable, joinable='\n'):
"""
Takes an iterable (pymongo Cursor or Array) containing dictionaries and
returns a stringified version using python join
@@ -105,7 +106,8 @@ def _stringify(self, iterable, joinable='\n'):
[json.dumps(doc, default=json_util.default) for doc in iterable]
)

- def transform(self, docs):
+ @staticmethod
+ def transform(docs):
"""
Processes pyMongo cursor and returns an iterable with each element being
a JSON serializable dictionary
3 changes: 2 additions & 1 deletion airflow/contrib/operators/mysql_to_gcs.py
@@ -218,7 +218,8 @@ def _upload_to_gcs(self, files_to_upload):
for object, tmp_file_handle in files_to_upload.items():
hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json')

- def _convert_types(self, schema, col_type_dict, row):
+ @staticmethod
+ def _convert_types(schema, col_type_dict, row):
"""
Takes a value from MySQLdb, and converts it to a value that's safe for
JSON/Google cloud storage/BigQuery. Dates are converted to UTC seconds.
