Skip to content
/ beam Public
forked from apache/beam

Commit

Permalink
Merge pull request apache#32872 Modernize python type hints for apach…
Browse files Browse the repository at this point in the history
…e_beam.
  • Loading branch information
robertwb authored Nov 19, 2024
2 parents c57553c + 2d69dde commit 75fd964
Show file tree
Hide file tree
Showing 16 changed files with 146 additions and 182 deletions.
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/dataframe/doctests.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class _InMemoryResultRecorder(object):
"""

# Class-level value to survive pickling.
_ALL_RESULTS = {} # type: dict[str, list[Any]]
_ALL_RESULTS: dict[str, list[Any]] = {}

def __init__(self):
self._id = id(self)
Expand Down
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/dataframe/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class Session(object):
def __init__(self, bindings=None):
self._bindings = dict(bindings or {})

def evaluate(self, expr): # type: (Expression) -> Any
def evaluate(self, expr: 'Expression') -> Any:
if expr not in self._bindings:
self._bindings[expr] = expr.evaluate_at(self)
return self._bindings[expr]

def lookup(self, expr): # type: (Expression) -> Any
def lookup(self, expr: 'Expression') -> Any:
return self._bindings[expr]


Expand Down
7 changes: 4 additions & 3 deletions sdks/python/apache_beam/dataframe/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import apache_beam as beam
from apache_beam import transforms
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frame_base
from apache_beam.dataframe import frames # pylint: disable=unused-import
from apache_beam.dataframe import partitionings
from apache_beam.pvalue import PCollection
Expand Down Expand Up @@ -101,15 +102,15 @@ def expand(self, input_pcolls):
from apache_beam.dataframe import convert

# Convert inputs to a flat dict.
input_dict = _flatten(input_pcolls) # type: dict[Any, PCollection]
input_dict: dict[Any, PCollection] = _flatten(input_pcolls)
proxies = _flatten(self._proxy) if self._proxy is not None else {
tag: None
for tag in input_dict
}
input_frames = {
input_frames: dict[Any, frame_base.DeferredFrame] = {
k: convert.to_dataframe(pc, proxies[k])
for k, pc in input_dict.items()
} # type: dict[Any, DeferredFrame] # noqa: F821
} # noqa: F821

# Apply the function.
frames_input = _substitute(input_pcolls, input_frames)
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/io/avroio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

class AvroBase(object):

_temp_files = [] # type: List[str]
_temp_files: List[str] = []

def __init__(self, methodName='runTest'):
super().__init__(methodName)
Expand Down
37 changes: 12 additions & 25 deletions sdks/python/apache_beam/io/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
import uuid
from collections import namedtuple
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import BinaryIO # pylint: disable=unused-import
from typing import Callable
Expand All @@ -115,15 +114,13 @@
from apache_beam.options.value_provider import ValueProvider
from apache_beam.transforms.periodicsequence import PeriodicImpulse
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.window import BoundedWindow
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import IntervalWindow
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.transforms.window import BoundedWindow

__all__ = [
'EmptyMatchTreatment',
'MatchFiles',
Expand Down Expand Up @@ -382,8 +379,7 @@ def create_metadata(
mime_type="application/octet-stream",
compression_type=CompressionTypes.AUTO)

def open(self, fh):
# type: (BinaryIO) -> None
def open(self, fh: BinaryIO) -> None:
raise NotImplementedError

def write(self, record):
Expand Down Expand Up @@ -575,8 +571,7 @@ class signature or an instance of FileSink to this parameter. If none is
self._max_num_writers_per_bundle = max_writers_per_bundle

@staticmethod
def _get_sink_fn(input_sink):
# type: (...) -> Callable[[Any], FileSink]
def _get_sink_fn(input_sink) -> Callable[[Any], FileSink]:
if isinstance(input_sink, type) and issubclass(input_sink, FileSink):
return lambda x: input_sink()
elif isinstance(input_sink, FileSink):
Expand All @@ -588,8 +583,7 @@ def _get_sink_fn(input_sink):
return lambda x: TextSink()

@staticmethod
def _get_destination_fn(destination):
# type: (...) -> Callable[[Any], str]
def _get_destination_fn(destination) -> Callable[[Any], str]:
if isinstance(destination, ValueProvider):
return lambda elm: destination.get()
elif callable(destination):
Expand Down Expand Up @@ -757,12 +751,8 @@ def _check_orphaned_files(self, writer_key):


class _WriteShardedRecordsFn(beam.DoFn):

def __init__(self,
base_path,
sink_fn, # type: Callable[[Any], FileSink]
shards # type: int
):
def __init__(
self, base_path, sink_fn: Callable[[Any], FileSink], shards: int):
self.base_path = base_path
self.sink_fn = sink_fn
self.shards = shards
Expand Down Expand Up @@ -805,17 +795,13 @@ def process(


class _AppendShardedDestination(beam.DoFn):
def __init__(
self,
destination, # type: Callable[[Any], str]
shards # type: int
):
def __init__(self, destination: Callable[[Any], str], shards: int):
self.destination_fn = destination
self.shards = shards

# We start the shards for a single destination at an arbitrary point.
self._shard_counter = collections.defaultdict(
lambda: random.randrange(self.shards)) # type: DefaultDict[str, int]
self._shard_counter: DefaultDict[str, int] = collections.defaultdict(
lambda: random.randrange(self.shards))

def _next_shard_for_destination(self, destination):
self._shard_counter[destination] = ((self._shard_counter[destination] + 1) %
Expand All @@ -835,8 +821,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn):
SPILLED_RECORDS = 'spilled_records'
WRITTEN_FILES = 'written_files'

_writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]]
_file_names = None # type: Dict[Tuple[str, BoundedWindow], str]
_writers_and_sinks: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO,
FileSink]] = None
_file_names: Dict[Tuple[str, BoundedWindow], str] = None

def __init__(
self,
Expand Down
10 changes: 6 additions & 4 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def _parse_location_from_exc(content, job_id):

def _start_job(
self,
request, # type: bigquery.BigqueryJobsInsertRequest
request: 'bigquery.BigqueryJobsInsertRequest',
stream=None,
):
"""Inserts a BigQuery job.
Expand Down Expand Up @@ -1802,9 +1802,11 @@ def generate_bq_job_name(job_name, step_id, job_type, random=None):


def check_schema_equal(
left, right, *, ignore_descriptions=False, ignore_field_order=False):
# type: (Union[bigquery.TableSchema, bigquery.TableFieldSchema], Union[bigquery.TableSchema, bigquery.TableFieldSchema], bool, bool) -> bool

left: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
right: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
*,
ignore_descriptions: bool = False,
ignore_field_order: bool = False) -> bool:
"""Check whether schemas are equivalent.
This comparison function differs from using == to compare TableSchema
Expand Down
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/io/gcp/gcsio.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ def create_storage_client(pipeline_options, use_credentials=True):

class GcsIO(object):
"""Google Cloud Storage I/O client."""
def __init__(self, storage_client=None, pipeline_options=None):
# type: (Optional[storage.Client], Optional[Union[dict, PipelineOptions]]) -> None
def __init__(
self,
storage_client: Optional[storage.Client] = None,
pipeline_options: Optional[Union[dict, PipelineOptions]] = None) -> None:
if pipeline_options is None:
pipeline_options = PipelineOptions()
elif isinstance(pipeline_options, dict):
Expand Down
54 changes: 26 additions & 28 deletions sdks/python/apache_beam/metrics/monitoring_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,8 @@ def create_labels(ptransform=None, namespace=None, name=None, pcollection=None):
return labels


def int64_user_counter(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_counter(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the counter monitoring info for the specifed URN, metric and labels.
Args:
Expand All @@ -199,9 +198,12 @@ def int64_user_counter(namespace, name, metric, ptransform=None):
USER_COUNTER_URN, SUM_INT64_TYPE, metric, labels)


def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_counter(
urn,
metric,
ptransform=None,
pcollection=None,
labels=None) -> metrics_pb2.MonitoringInfo:
"""Return the counter monitoring info for the specifed URN, metric and labels.
Args:
Expand All @@ -217,9 +219,8 @@ def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None):
return create_monitoring_info(urn, SUM_INT64_TYPE, metric, labels)


def int64_user_distribution(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_distribution(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the distribution monitoring info for the URN, metric and labels.
Args:
Expand All @@ -234,9 +235,11 @@ def int64_user_distribution(namespace, name, metric, ptransform=None):
USER_DISTRIBUTION_URN, DISTRIBUTION_INT64_TYPE, payload, labels)


def int64_distribution(urn, metric, ptransform=None, pcollection=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_distribution(
urn,
metric,
ptransform=None,
pcollection=None) -> metrics_pb2.MonitoringInfo:
"""Return a distribution monitoring info for the URN, metric and labels.
Args:
Expand All @@ -251,9 +254,8 @@ def int64_distribution(urn, metric, ptransform=None, pcollection=None):
return create_monitoring_info(urn, DISTRIBUTION_INT64_TYPE, payload, labels)


def int64_user_gauge(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_gauge(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, metric and labels.
Args:
Expand All @@ -276,9 +278,7 @@ def int64_user_gauge(namespace, name, metric, ptransform=None):
USER_GAUGE_URN, LATEST_INT64_TYPE, payload, labels)


def int64_gauge(urn, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_gauge(urn, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, metric and labels.
Args:
Expand Down Expand Up @@ -320,9 +320,8 @@ def user_set_string(namespace, name, metric, ptransform=None):
USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels)


def create_monitoring_info(urn, type_urn, payload, labels=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def create_monitoring_info(
urn, type_urn, payload, labels=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, type, metric and labels.
Args:
Expand Down Expand Up @@ -366,9 +365,9 @@ def is_user_monitoring_info(monitoring_info_proto):
return monitoring_info_proto.urn in USER_METRIC_URNS


def extract_metric_result_map_value(monitoring_info_proto):
# type: (...) -> Union[None, int, DistributionResult, GaugeResult, set]

def extract_metric_result_map_value(
monitoring_info_proto
) -> Union[None, int, DistributionResult, GaugeResult, set]:
"""Returns the relevant GaugeResult, DistributionResult or int value for
counter metric, set for StringSet metric.
Expand Down Expand Up @@ -408,14 +407,13 @@ def get_step_name(monitoring_info_proto):
return monitoring_info_proto.labels.get(PTRANSFORM_LABEL)


def to_key(monitoring_info_proto):
# type: (metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]

def to_key(
monitoring_info_proto: metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]:
"""Returns a key based on the URN and labels.
This is useful in maps to prevent reporting the same MonitoringInfo twice.
"""
key_items = list(monitoring_info_proto.labels.items()) # type: List[Hashable]
key_items: List[Hashable] = list(monitoring_info_proto.labels.items())
key_items.append(monitoring_info_proto.urn)
return frozenset(key_items)

Expand Down
26 changes: 9 additions & 17 deletions sdks/python/apache_beam/options/pipeline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,7 @@ def __getstate__(self):
return self.__dict__

@classmethod
def _add_argparse_args(cls, parser):
# type: (_BeamArgumentParser) -> None
def _add_argparse_args(cls, parser: _BeamArgumentParser) -> None:
# Override this in subclasses to provide options.
pass

Expand Down Expand Up @@ -317,11 +316,8 @@ def from_dictionary(cls, options):
def get_all_options(
self,
drop_default=False,
add_extra_args_fn=None, # type: Optional[Callable[[_BeamArgumentParser], None]]
retain_unknown_options=False
):
# type: (...) -> Dict[str, Any]

add_extra_args_fn: Optional[Callable[[_BeamArgumentParser], None]] = None,
retain_unknown_options=False) -> Dict[str, Any]:
"""Returns a dictionary of all defined arguments.
Returns a dictionary of all defined arguments (arguments that are defined in
Expand Down Expand Up @@ -446,9 +442,7 @@ def from_urn(key):
def display_data(self):
return self.get_all_options(drop_default=True, retain_unknown_options=True)

def view_as(self, cls):
# type: (Type[PipelineOptionsT]) -> PipelineOptionsT

def view_as(self, cls: Type[PipelineOptionsT]) -> PipelineOptionsT:
"""Returns a view of current object as provided PipelineOption subclass.
Example Usage::
Expand Down Expand Up @@ -487,13 +481,11 @@ def view_as(self, cls):
view._all_options = self._all_options
return view

def _visible_option_list(self):
# type: () -> List[str]
def _visible_option_list(self) -> List[str]:
return sorted(
option for option in dir(self._visible_options) if option[0] != '_')

def __dir__(self):
# type: () -> List[str]
def __dir__(self) -> List[str]:
return sorted(
dir(type(self)) + list(self.__dict__) + self._visible_option_list())

Expand Down Expand Up @@ -643,9 +635,9 @@ def additional_option_ptransform_fn():


# Optional type checks that aren't enabled by default.
additional_type_checks = {
additional_type_checks: Dict[str, Callable[[], None]] = {
'ptransform_fn': additional_option_ptransform_fn,
} # type: Dict[str, Callable[[], None]]
}


def enable_all_additional_type_checks():
Expand Down Expand Up @@ -1840,7 +1832,7 @@ class OptionsContext(object):
Can also be used as a decorator.
"""
overrides = [] # type: List[Dict[str, Any]]
overrides: List[Dict[str, Any]] = []

def __init__(self, **options):
self.options = options
Expand Down
Loading

0 comments on commit 75fd964

Please sign in to comment.