Skip to content

Commit

Permalink
D401 Support - Macros to Operators (Inclusive) (#33337)
Browse files Browse the repository at this point in the history
* D401 Support - airflow/macros thru airflow/operators

* fix static checks
  • Loading branch information
ferruzzi authored Aug 12, 2023
1 parent d69ffaf commit 2efb3a6
Show file tree
Hide file tree
Showing 29 changed files with 165 additions and 161 deletions.
18 changes: 9 additions & 9 deletions airflow/metrics/otel_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _generate_key_name(name: str, attributes: Attributes = None):

def name_is_otel_safe(prefix: str, name: str) -> bool:
"""
Returns True if the provided name and prefix would result in a name that meets the OpenTelemetry standard.
Return True if the provided name and prefix would result in a name that meets the OpenTelemetry standard.
Legal names are defined here:
https://opentelemetry.io/docs/reference/specification/metrics/api/#instrument-name-syntax
Expand All @@ -110,7 +110,7 @@ def _type_as_str(obj: Instrument) -> str:

def _get_otel_safe_name(name: str) -> str:
"""
Verifies that the provided name does not exceed OpenTelemetry's maximum length for metric names.
Verify that the provided name does not exceed OpenTelemetry's maximum length for metric names.
:param name: The original metric name
:returns: The name, truncated to an OTel-acceptable length if required.
Expand Down Expand Up @@ -290,7 +290,7 @@ def clear(self) -> None:
self.map.clear()

def _create_counter(self, name):
"""Creates a new counter or up_down_counter for the provided name."""
"""Create a new counter or up_down_counter for the provided name."""
otel_safe_name = _get_otel_safe_name(name)

if _is_up_down_counter(name):
Expand All @@ -303,7 +303,7 @@ def _create_counter(self, name):

def get_counter(self, name: str, attributes: Attributes = None):
"""
Returns the counter; creates a new one if it did not exist.
Return the counter; creates a new one if it did not exist.
:param name: The name of the counter to fetch or create.
:param attributes: Counter attributes, used to generate a unique key to store the counter.
Expand All @@ -315,7 +315,7 @@ def get_counter(self, name: str, attributes: Attributes = None):

def del_counter(self, name: str, attributes: Attributes = None) -> None:
"""
Deletes a counter.
Delete a counter.
:param name: The name of the counter to delete.
:param attributes: Counter attributes which were used to generate a unique key to store the counter.
Expand All @@ -326,7 +326,7 @@ def del_counter(self, name: str, attributes: Attributes = None) -> None:

def set_gauge_value(self, name: str, value: float | None, delta: bool, tags: Attributes):
"""
Overrides the last reading for a Gauge with a new value.
Override the last reading for a Gauge with a new value.
:param name: The name of the gauge to record.
:param value: The new reading to record.
Expand All @@ -344,7 +344,7 @@ def set_gauge_value(self, name: str, value: float | None, delta: bool, tags: Att

def _create_gauge(self, name: str, attributes: Attributes = None):
"""
Creates a new Observable Gauge with the provided name and the default value.
Create a new Observable Gauge with the provided name and the default value.
:param name: The name of the gauge to fetch or create.
:param attributes: Gauge attributes, used to generate a unique key to store the gauge.
Expand All @@ -361,12 +361,12 @@ def _create_gauge(self, name: str, attributes: Attributes = None):
return gauge

def read_gauge(self, key: str, *args) -> Iterable[Observation]:
"""Callback for the Observable Gauges, returns the Observation for the provided key."""
"""Return the Observation for the provided key; callback for the Observable Gauges."""
yield self.map[key]

def poke_gauge(self, name: str, attributes: Attributes = None) -> GaugeValues:
"""
Returns the value of the gauge; creates a new one with the default value if it did not exist.
Return the value of the gauge; creates a new one with the default value if it did not exist.
:param name: The name of the gauge to fetch or create.
:param attributes: Gauge attributes, used to generate a unique key to store the gauge.
Expand Down
2 changes: 1 addition & 1 deletion airflow/metrics/statsd_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def timer(


def get_statsd_logger(cls) -> SafeStatsdLogger:
"""Returns logger for StatsD."""
"""Return logger for StatsD."""
# no need to check for the scheduler/statsd_on -> this method is only called when it is set
# and previously it would crash with None is callable if it was called without it.
from statsd import StatsClient
Expand Down
2 changes: 1 addition & 1 deletion airflow/metrics/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def stat_name_otel_handler(
max_length: int = OTEL_NAME_MAX_LENGTH,
) -> str:
"""
Verifies that a proposed prefix and name combination will meet OpenTelemetry naming standards.
Verify that a proposed prefix and name combination will meet OpenTelemetry naming standards.
See: https://opentelemetry.io/docs/reference/specification/metrics/api/#instrument-name-syntax
Expand Down
2 changes: 1 addition & 1 deletion airflow/migrations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]:
"""
Returns the primary and unique constraint along with column name.
Return the primary and unique constraint along with column name.
Some tables like `task_instance` are missing the primary key constraint
name and the name is auto-generated by the SQL server, so this function
Expand Down
6 changes: 4 additions & 2 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:

@cache
def get_parse_time_mapped_ti_count(self) -> int:
"""Number of mapped task instances that can be created on DAG run creation.
"""
Return the number of mapped task instances that can be created on DAG run creation.
This only considers literal mapped arguments, and would return *None*
when any non-literal values are used for mapping.
Expand All @@ -479,7 +480,8 @@ def get_parse_time_mapped_ti_count(self) -> int:
return group.get_parse_time_mapped_ti_count()

def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
"""Number of mapped TaskInstances that can be created at run time.
"""
Return the number of mapped TaskInstances that can be created at run time.
This considers both literal and non-literal mapped arguments, and the
result is therefore available when all depended tasks have finished. The
Expand Down
32 changes: 16 additions & 16 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def __hash__(self):
# including lineage information
def __or__(self, other):
"""
Called for [This Operator] | [Operator].
Return [This Operator] | [Operator].
The inlets of other will be set to pick up the outlets from this operator.
Other will be set as a downstream task of this operator.
Expand All @@ -1010,7 +1010,7 @@ def __or__(self, other):

def __gt__(self, other):
"""
Called for [Operator] > [Outlet].
Return [Operator] > [Outlet].
If other is an attr annotated object it is set as an outlet of this Operator.
"""
Expand All @@ -1026,7 +1026,7 @@ def __gt__(self, other):

def __lt__(self, other):
"""
Called for [Inlet] > [Operator] or [Operator] < [Inlet].
Return [Inlet] > [Operator] or [Operator] < [Inlet].
If other is an attr annotated object it is set as an inlet to this operator.
"""
Expand Down Expand Up @@ -1054,22 +1054,22 @@ def __setattr__(self, key, value):
self.set_xcomargs_dependencies()

def add_inlets(self, inlets: Iterable[Any]):
"""Sets inlets to this operator."""
"""Set inlets to this operator."""
self.inlets.extend(inlets)

def add_outlets(self, outlets: Iterable[Any]):
"""Defines the outlets of this operator."""
"""Define the outlets of this operator."""
self.outlets.extend(outlets)

def get_inlet_defs(self):
"""Gets inlet definitions on this task.
"""Get inlet definitions on this task.
:meta private:
"""
return self.inlets

def get_outlet_defs(self):
"""Gets outlet definitions on this task.
"""Get outlet definitions on this task.
:meta private:
"""
Expand Down Expand Up @@ -1109,7 +1109,7 @@ def dag(self, dag: DAG | None):
self._dag = dag

def has_dag(self):
"""Returns True if the Operator has been assigned to a DAG."""
"""Return True if the Operator has been assigned to a DAG."""
return self._dag is not None

deps: frozenset[BaseTIDep] = frozenset(
Expand All @@ -1134,7 +1134,7 @@ def prepare_for_execution(self) -> BaseOperator:

def set_xcomargs_dependencies(self) -> None:
"""
Resolves upstream dependencies of a task.
Resolve upstream dependencies of a task.
In this way passing an ``XComArg`` as value for a template field
will result in creating upstream relation between two tasks.
Expand Down Expand Up @@ -1163,13 +1163,13 @@ def set_xcomargs_dependencies(self) -> None:

@prepare_lineage
def pre_execute(self, context: Any):
"""This hook is triggered right before self.execute() is called."""
"""Execute right before self.execute() is called."""
if self._pre_execute_hook is not None:
self._pre_execute_hook(context)

def execute(self, context: Context) -> Any:
"""
This is the main method to derive when creating an operator.
Derive when creating an operator.
Context is the same dictionary used as when rendering jinja templates.
Expand All @@ -1180,7 +1180,7 @@ def execute(self, context: Context) -> Any:
@apply_lineage
def post_execute(self, context: Any, result: Any = None):
"""
This hook is triggered right after self.execute() is called.
Execute right after self.execute() is called.
It is passed the execution context and any results returned by the operator.
"""
Expand Down Expand Up @@ -1252,7 +1252,7 @@ def clear(
downstream: bool = False,
session: Session = NEW_SESSION,
):
"""Clears the state of task instances associated with the task, following the parameters specified."""
"""Clear the state of task instances associated with the task, following the parameters specified."""
qry = select(TaskInstance).where(TaskInstance.dag_id == self.dag_id)

if start_date:
Expand Down Expand Up @@ -1355,7 +1355,7 @@ def run(
)

def dry_run(self) -> None:
"""Performs dry run for the operator - just render template fields."""
"""Perform dry run for the operator - just render template fields."""
self.log.info("Dry run")
for field in self.template_fields:
try:
Expand Down Expand Up @@ -1563,7 +1563,7 @@ def get_serialized_fields(cls):
return cls.__serialized_fields

def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Required by DAGNode."""
"""Serialize; required by DAGNode."""
return DagAttributeTypes.OP, self.task_id

@property
Expand Down Expand Up @@ -1837,7 +1837,7 @@ def cross_downstream(

def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
"""
Helper to simplify task dependency definition.
Simplify task dependency definition.
E.g.: suppose you want precedence like so::
Expand Down
16 changes: 8 additions & 8 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


def parse_netloc_to_hostname(*args, **kwargs):
"""This method is deprecated."""
"""Do not use, this method is deprecated."""
warnings.warn("This method is deprecated.", RemovedInAirflow3Warning)
return _parse_netloc_to_hostname(*args, **kwargs)

Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(
@staticmethod
def _validate_extra(extra, conn_id) -> None:
"""
Here we verify that ``extra`` is a JSON-encoded Python dict.
Verify that ``extra`` is a JSON-encoded Python dict.
From Airflow 3.0, we should no longer suppress these errors but raise instead.
"""
Expand Down Expand Up @@ -173,7 +173,7 @@ def on_db_load(self):
mask_secret(self.password)

def parse_from_uri(self, **uri):
"""This method is deprecated. Please use uri parameter in constructor."""
"""Use uri parameter in constructor, this method is deprecated."""
warnings.warn(
"This method is deprecated. Please use uri parameter in constructor.",
RemovedInAirflow3Warning,
Expand Down Expand Up @@ -219,7 +219,7 @@ def _parse_from_uri(self, uri: str):

@staticmethod
def _create_host(protocol, host) -> str | None:
"""Returns the connection host with the protocol."""
"""Return the connection host with the protocol."""
if not host:
return host
if protocol:
Expand Down Expand Up @@ -378,9 +378,9 @@ def __repr__(self):

def log_info(self):
"""
This method is deprecated.
Read each field individually or use the default representation (`__repr__`).
You can read each field individually or use the default representation (`__repr__`).
This method is deprecated.
"""
warnings.warn(
"This method is deprecated. You can read each field individually or "
Expand All @@ -396,9 +396,9 @@ def log_info(self):

def debug_info(self):
"""
This method is deprecated.
Read each field individually or use the default representation (`__repr__`).
You can read each field individually or use the default representation (`__repr__`).
This method is deprecated.
"""
warnings.warn(
"This method is deprecated. You can read each field individually or "
Expand Down
Loading

0 comments on commit 2efb3a6

Please sign in to comment.