D401 support in amazon provider (#37275)
guptarohit authored Feb 10, 2024
1 parent 28f94f8 commit 8fac799
Showing 11 changed files with 25 additions and 35 deletions.
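
Every docstring change in this commit follows ruff's D401 rule ("First line of docstring should be in imperative mood"), which the Airflow repository enforces through pyproject.toml. A minimal illustration of the convention, using made-up function names rather than code from this diff:

def get_client():
    """Return a configured client."""  # imperative mood: satisfies D401
    ...


def get_client_legacy():
    """Returns a configured client."""  # indicative mood: flagged by D401
    ...

Once a module's docstrings comply, its per-file D401 ignore can be dropped, which is what the pyproject.toml change at the end of this diff does for the amazon provider files.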
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py
@@ -80,7 +80,7 @@ def update_schema(args):
 
 
 def _get_client():
-    """Returns Amazon Verified Permissions client."""
+    """Return Amazon Verified Permissions client."""
     region_name = conf.get(CONF_SECTION_NAME, CONF_REGION_NAME_KEY)
     return boto3.client("verifiedpermissions", region_name=region_name)
 
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/auth_manager/views/auth.py
@@ -71,7 +71,7 @@ def logout(self):
     @expose("/login_callback", methods=("GET", "POST"))
     def login_callback(self):
         """
-        Callback where the user is redirected to after successful login.
+        Redirect the user to this callback after successful login.
 
         CSRF protection needs to be disabled otherwise the callback won't work.
         """
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/executors/ecs/boto_schema.py
@@ -61,7 +61,7 @@ class BotoTaskSchema(Schema):
 
     @post_load
     def make_task(self, data, **kwargs):
-        """Overwrites marshmallow load() to return an instance of EcsExecutorTask instead of a dictionary."""
+        """Overwrite marshmallow load() to return an EcsExecutorTask instance instead of a dictionary."""
         # Imported here to avoid circular import.
         from airflow.providers.amazon.aws.executors.ecs.utils import EcsExecutorTask
 
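
For context on the boto_schema change above: make_task relies on marshmallow's post_load hook so that Schema.load() returns a domain object instead of the default dict. A rough, self-contained sketch of that pattern (the Task dataclass and field names below are invented for illustration; they are not the provider's real EcsExecutorTask or schema):

from dataclasses import dataclass

from marshmallow import Schema, fields, post_load


@dataclass
class Task:
    task_arn: str
    last_status: str


class TaskSchema(Schema):
    task_arn = fields.String()
    last_status = fields.String()

    @post_load
    def make_task(self, data, **kwargs):
        # Return a Task instance instead of the plain dict marshmallow would produce.
        return Task(**data)


task = TaskSchema().load({"task_arn": "arn:aws:ecs:example", "last_status": "RUNNING"})
print(task)  # Task(task_arn='arn:aws:ecs:example', last_status='RUNNING')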
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -112,7 +112,7 @@ def __init__(self, *args, **kwargs):
         self.run_task_kwargs = self._load_run_kwargs()
 
     def start(self):
-        """This is called by the scheduler when the Executor is being run for the first time."""
+        """Call this when the Executor is run for the first time by the scheduler."""
         check_health = conf.getboolean(
             CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
         )
@@ -217,7 +217,7 @@ def sync(self):
             self.log.exception("Failed to sync %s", self.__class__.__name__)
 
     def sync_running_tasks(self):
-        """Checks and update state on all running tasks."""
+        """Check and update state on all running tasks."""
         all_task_arns = self.active_workers.get_all_arns()
         if not all_task_arns:
             self.log.debug("No active Airflow tasks, skipping sync.")
@@ -324,7 +324,7 @@ def __handle_failed_task(self, task_arn: str, reason: str):
 
     def attempt_task_runs(self):
         """
-        Takes tasks from the pending_tasks queue, and attempts to find an instance to run it on.
+        Take tasks from the pending_tasks queue, and attempt to find an instance to run them on.
 
         If the launch type is EC2, this will attempt to place tasks on empty EC2 instances. If
         there are no EC2 instances available, no task is placed and this function will be
@@ -418,7 +418,7 @@ def _run_task_kwargs(
         self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
     ) -> dict:
         """
-        Overrides the Airflow command to update the container overrides so kwargs are specific to this task.
+        Update the Airflow command by modifying container overrides for task-specific kwargs.
 
         One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
         """
@@ -443,7 +443,7 @@ def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None,
         )
 
     def end(self, heartbeat_interval=10):
-        """Waits for all currently running tasks to end, and doesn't launch any tasks."""
+        """Wait for all currently running tasks to end, and don't launch any tasks."""
         try:
             while True:
                 self.sync()
@@ -483,7 +483,7 @@ def _load_run_kwargs(self) -> dict:
         return ecs_executor_run_task_kwargs
 
     def get_container(self, container_list):
-        """Searches task list for core Airflow container."""
+        """Search task list for core Airflow container."""
         for container in container_list:
             try:
                 if container["name"] == self.container_name:
14 changes: 7 additions & 7 deletions airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -127,9 +127,9 @@ def __init__(
 
     def get_task_state(self) -> str:
         """
-        This is the primary logic that handles state in an ECS task.
+        Determine the state of an ECS task based on its status and other relevant attributes.
 
-        It will determine if a status is:
+        It can return one of the following statuses:
         QUEUED - Task is being provisioned.
         RUNNING - Task is launched on ECS.
         REMOVED - Task provisioning has failed for some reason. See `stopped_reason`.
@@ -173,7 +173,7 @@ def add_task(
         exec_config: ExecutorConfigType,
         attempt_number: int,
     ):
-        """Adds a task to the collection."""
+        """Add a task to the collection."""
         arn = task.task_arn
         self.tasks[arn] = task
         self.key_to_arn[airflow_task_key] = arn
@@ -182,7 +182,7 @@ def add_task(
         self.key_to_failure_counts[airflow_task_key] = attempt_number
 
     def update_task(self, task: EcsExecutorTask):
-        """Updates the state of the given task based on task ARN."""
+        """Update the state of the given task based on task ARN."""
         self.tasks[task.task_arn] = task
 
     def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
@@ -195,7 +195,7 @@ def task_by_arn(self, arn) -> EcsExecutorTask:
         return self.tasks[arn]
 
     def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
-        """Deletes task from collection based off of Airflow Task Instance Key."""
+        """Delete task from collection based off of Airflow Task Instance Key."""
         arn = self.key_to_arn[task_key]
         task = self.tasks[arn]
         del self.key_to_arn[task_key]
@@ -227,11 +227,11 @@ def info_by_key(self, task_key: TaskInstanceKey) -> EcsTaskInfo:
         return self.key_to_task_info[task_key]
 
     def __getitem__(self, value):
-        """Gets a task by AWS ARN."""
+        """Get a task by AWS ARN."""
        return self.task_by_arn(value)
 
     def __len__(self):
-        """Determines the number of tasks in collection."""
+        """Determine the number of tasks in collection."""
         return len(self.tasks)
 
 
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/logs.py
@@ -193,7 +193,7 @@ async def get_log_events_async(
         skip: int = 0,
         start_from_head: bool = True,
     ) -> AsyncGenerator[Any, dict[str, Any]]:
-        """A generator for log items in a single stream. This will yield all the items that are available.
+        """Yield all the available items in a single log stream.
 
         :param log_group: The name of the log group.
         :param log_stream_name: The name of the specific stream.
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -200,7 +200,7 @@ def get_conn(self) -> RedshiftConnection:
         return redshift_connector.connect(**conn_kwargs)
 
     def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
-        """Returns Redshift specific information for OpenLineage."""
+        """Return Redshift specific information for OpenLineage."""
         from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
         authority = self._get_openlineage_redshift_authority_part(connection)
@@ -252,9 +252,9 @@ def _get_identifier_from_hostname(self, hostname: str) -> str:
         return hostname
 
     def get_openlineage_database_dialect(self, connection: Connection) -> str:
-        """Returns redshift dialect."""
+        """Return redshift dialect."""
         return "redshift"
 
     def get_openlineage_default_schema(self) -> str | None:
-        """Returns current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
+        """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
         return self.get_first("SELECT CURRENT_SCHEMA();")[0]
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -290,7 +290,7 @@ def __init__(
         self.poke_interval = poke_interval
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes RedshiftClusterTrigger arguments and classpath."""
+        """Serialize RedshiftClusterTrigger arguments and classpath."""
         return (
             "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
             {
@@ -302,7 +302,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
         )
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
-        """Simple async function run until the cluster status match the target status."""
+        """Run async until the cluster status matches the target status."""
         try:
             hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
             while True:
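
The two redshift_cluster methods above follow Airflow's trigger contract: serialize() returns the classpath and constructor kwargs needed to recreate the trigger on the triggerer, and run() polls asynchronously until it can yield a TriggerEvent. A rough sketch of that contract under invented names (PollStatusTrigger and its get_status helper are hypothetical; only BaseTrigger and TriggerEvent come from Airflow):

import asyncio
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent


class PollStatusTrigger(BaseTrigger):
    def __init__(self, target_status: str, poke_interval: float = 5.0):
        super().__init__()
        self.target_status = target_status
        self.poke_interval = poke_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # Classpath plus constructor kwargs so the triggerer can rebuild this object.
        return (
            "example_module.PollStatusTrigger",
            {"target_status": self.target_status, "poke_interval": self.poke_interval},
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        # Sleep between polls until the watched resource reaches the target status.
        while await self.get_status() != self.target_status:
            await asyncio.sleep(self.poke_interval)
        yield TriggerEvent({"status": "success"})

    async def get_status(self) -> str:
        # Placeholder; a real trigger would query an async hook here.
        return self.target_status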
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -63,7 +63,7 @@ def __init__(
         self.botocore_config = botocore_config
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes RedshiftDataTrigger arguments and classpath."""
+        """Serialize RedshiftDataTrigger arguments and classpath."""
         return (
             "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
             {
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -220,7 +220,7 @@ def __init__(
         self.aws_conn_id = aws_conn_id
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes SageMakerTrainingPrintLogTrigger arguments and classpath."""
+        """Serialize SageMakerTrainingPrintLogTrigger arguments and classpath."""
         return (
             "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger",
             {
@@ -235,7 +235,7 @@ def hook(self) -> SageMakerHook:
         return SageMakerHook(aws_conn_id=self.aws_conn_id)
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
-        """Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator."""
+        """Make async connection to sagemaker async hook and get job status for a job submitted by the operator."""
         stream_names: list[str] = []  # The list of log streams
         positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
 
10 changes: 0 additions & 10 deletions pyproject.toml
@@ -1354,16 +1354,6 @@ combine-as-imports = true
 "airflow/providers/airbyte/operators/airbyte.py" = ["D401"]
 "airflow/providers/airbyte/sensors/airbyte.py" = ["D401"]
 "airflow/providers/airbyte/triggers/airbyte.py" = ["D401"]
-"airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py" = ["D401"]
-"airflow/providers/amazon/aws/auth_manager/views/auth.py" = ["D401"]
-"airflow/providers/amazon/aws/executors/ecs/boto_schema.py" = ["D401"]
-"airflow/providers/amazon/aws/executors/ecs/ecs_executor.py" = ["D401"]
-"airflow/providers/amazon/aws/executors/ecs/utils.py" = ["D401"]
-"airflow/providers/amazon/aws/hooks/redshift_sql.py" = ["D401"]
-"airflow/providers/amazon/aws/hooks/logs.py" = ["D401"]
-"airflow/providers/amazon/aws/triggers/redshift_cluster.py" = ["D401"]
-"airflow/providers/amazon/aws/triggers/redshift_data.py" = ["D401"]
-"airflow/providers/amazon/aws/triggers/sagemaker.py" = ["D401"]
 "airflow/providers/cncf/kubernetes/callbacks.py" = ["D401"]
 "airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py" = ["D401"]
 "airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py" = ["D401"]
