diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index cccc73662ce8..e44c2535156e 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -1641,10 +1641,65 @@ def decode_from_stream(self, stream, nested): return self._window_coder_impl.decode_from_stream(stream, nested) def estimate_size(self, value: Any, nested: bool = False) -> int: - estimated_size = 0 - estimated_size += TimestampCoderImpl().estimate_size(value) - estimated_size += self._window_coder_impl.estimate_size(value, nested) - return estimated_size + return ( + TimestampCoderImpl().estimate_size(value.max_timestamp()) + + self._window_coder_impl.estimate_size(value, nested)) + + +_OpaqueWindow = None + + +def _create_opaque_window(end, encoded_window): + # This is lazy to avoid circular import issues. + global _OpaqueWindow + if _OpaqueWindow is None: + from apache_beam.transforms.window import BoundedWindow + + class _OpaqueWindow(BoundedWindow): + def __init__(self, end, encoded_window): + super().__init__(end) + self.encoded_window = encoded_window + + def __repr__(self): + return 'OpaqueWindow(%s, %s)' % (self.end, self.encoded_window) + + def __hash__(self): + return hash(self.encoded_window) + + def __eq__(self, other): + return ( + type(self) == type(other) and self.end == other.end and + self.encoded_window == other.encoded_window) + + return _OpaqueWindow(end, encoded_window) + + +class TimestampPrefixingOpaqueWindowCoderImpl(StreamCoderImpl): + """For internal use only; no backwards-compatibility guarantees. + + A coder for unknown window types, which prefix required max_timestamp to + encoded original window. + + The coder encodes and decodes custom window types with following format: + window's max_timestamp() + length prefixed encoded window + """ + def __init__(self) -> None: + pass + + def encode_to_stream(self, value, stream, nested): + TimestampCoderImpl().encode_to_stream(value.max_timestamp(), stream, True) + stream.write(value.encoded_window, True) + + def decode_from_stream(self, stream, nested): + max_timestamp = TimestampCoderImpl().decode_from_stream(stream, True) + return _create_opaque_window( + max_timestamp.successor(), stream.read_all(True)) + + def estimate_size(self, value: Any, nested: bool = False) -> int: + return ( + TimestampCoderImpl().estimate_size(value.max_timestamp()) + + len(value.encoded_window)) row_coders_registered = False diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 7c5c8e09303d..d5e717d7b9cc 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -1628,6 +1628,34 @@ def __hash__(self): common_urns.coders.CUSTOM_WINDOW.urn, TimestampPrefixingWindowCoder) +class TimestampPrefixingOpaqueWindowCoder(FastCoder): + """For internal use only; no backwards-compatibility guarantees. + + Coder which decodes windows as bytes.""" + def __init__(self) -> None: + pass + + def _create_impl(self): + return coder_impl.TimestampPrefixingOpaqueWindowCoderImpl() + + def is_deterministic(self) -> bool: + return True + + def __repr__(self): + return 'TimestampPrefixingOpaqueWindowCoder' + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash((type(self))) + + +Coder.register_structured_urn( + python_urns.TIMESTAMP_PREFIXED_OPAQUE_WINDOW_CODER, + TimestampPrefixingOpaqueWindowCoder) + + class BigIntegerCoder(FastCoder): def _create_impl(self): return coder_impl.BigIntegerCoderImpl() diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 70582e7992a6..7dcfae83f10e 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -164,6 +164,7 @@ def tearDownClass(cls): coders.SinglePrecisionFloatCoder, coders.ToBytesCoder, coders.BigIntegerCoder, # tested in DecimalCoder + coders.TimestampPrefixingOpaqueWindowCoder, ]) cls.seen_nested -= set( [coders.ProtoCoder, coders.ProtoPlusCoder, CustomCoder]) @@ -739,6 +740,15 @@ def test_timestamp_prefixing_window_coder(self): coders.IntervalWindowCoder()), )), (window.IntervalWindow(0, 10), )) + def test_timestamp_prefixing_opaque_window_coder(self): + sdk_coder = coders.TimestampPrefixingWindowCoder( + coders.LengthPrefixCoder(coders.PickleCoder())) + safe_coder = coders.TimestampPrefixingOpaqueWindowCoder() + for w in [window.IntervalWindow(1, 123), window.GlobalWindow()]: + round_trip = sdk_coder.decode( + safe_coder.encode(safe_coder.decode(sdk_coder.encode(w)))) + self.assertEqual(w, round_trip) + def test_decimal_coder(self): test_coder = coders.DecimalCoder() diff --git a/sdks/python/apache_beam/portability/python_urns.py b/sdks/python/apache_beam/portability/python_urns.py index ed9ec6a07258..b96ab8c65be8 100644 --- a/sdks/python/apache_beam/portability/python_urns.py +++ b/sdks/python/apache_beam/portability/python_urns.py @@ -40,6 +40,10 @@ # Components: The coders for the tuple elements, in order. TUPLE_CODER = "beam:coder:tuple:v1" +# This allows us to decode TimestampedPrefixed(LengthPrefixed(AnyWindowCoder)). +TIMESTAMP_PREFIXED_OPAQUE_WINDOW_CODER = ( + "beam:timestamp_prefixed_opaque_window_coder:v1") + # Invoke UserFns in process, via direct function calls. # Payload: None. EMBEDDED_PYTHON = "beam:env:embedded_python:v1" diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py index 66c5be544e7f..f69ee1c24c4e 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py @@ -314,6 +314,9 @@ def test_register_finalizations(self): def test_custom_merging_window(self): raise unittest.SkipTest("https://github.com/apache/beam/issues/20641") + def test_custom_window_type(self): + raise unittest.SkipTest("https://github.com/apache/beam/issues/20641") + # Inherits all other tests. diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 9eeeaa4bb24e..4a35da8dd274 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -1117,6 +1117,19 @@ def test_custom_merging_window(self): from apache_beam.runners.portability.fn_api_runner.execution import GenericMergingWindowFn self.assertEqual(GenericMergingWindowFn._HANDLES, {}) + def test_custom_window_type(self): + with self.create_pipeline() as p: + res = ( + p + | beam.Create([1, 2, 100, 101, 102]) + | beam.Map(lambda t: window.TimestampedValue(('k', t), t)) + | beam.WindowInto(EvenOddWindows()) + | beam.GroupByKey() + | beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1])))) + assert_that( + res, + equal_to([('k', [1]), ('k', [2]), ('k', [101]), ('k', [100, 102])])) + @unittest.skip('BEAM-9119: test is flaky') def test_large_elements(self): with self.create_pipeline() as p: @@ -2379,6 +2392,47 @@ def get_window_coder(self): return coders.IntervalWindowCoder() +class ColoredFixedWindow(window.BoundedWindow): + def __init__(self, end, color): + super().__init__(end) + self.color = color + + def __hash__(self): + return hash((self.end, self.color)) + + def __eq__(self, other): + return ( + type(self) == type(other) and self.end == other.end and + self.color == other.color) + + +class ColoredFixedWindowCoder(beam.coders.Coder): + kv_coder = beam.coders.TupleCoder( + [beam.coders.TimestampCoder(), beam.coders.StrUtf8Coder()]) + + def encode(self, colored_window): + return self.kv_coder.encode((colored_window.end, colored_window.color)) + + def decode(self, encoded_window): + return ColoredFixedWindow(*self.kv_coder.decode(encoded_window)) + + def is_deterministic(self): + return True + + +class EvenOddWindows(window.NonMergingWindowFn): + def assign(self, context): + timestamp = context.timestamp + return [ + ColoredFixedWindow( + timestamp - timestamp % 10 + 10, + 'red' if timestamp.micros // 1000000 % 2 else 'black') + ] + + def get_window_coder(self): + return ColoredFixedWindowCoder() + + class ExpectingSideInputsFn(beam.DoFn): def __init__(self, name): self._name = name diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py index 864f57807da1..71f1400e783b 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py @@ -428,7 +428,7 @@ def __init__( self._known_coder_urns = set.union( # Those which are required. self._REQUIRED_CODER_URNS, - # Those common coders which are understood by all environments. + # Those common coders which are understood by many environments. self._COMMON_CODER_URNS.intersection( *( set(env.capabilities) @@ -515,8 +515,40 @@ def maybe_length_prefixed_and_safe_coder(self, coder_id): # type: (str) -> Tuple[str, str] coder = self.components.coders[coder_id] if coder.spec.urn == common_urns.coders.LENGTH_PREFIX.urn: + # If the coder is already length prefixed, we can use it as is, and + # have the runner treat it as opaque bytes. return coder_id, self.bytes_coder_id + elif (coder.spec.urn == common_urns.coders.WINDOWED_VALUE.urn and + self.components.coders[coder.component_coder_ids[1]].spec.urn not in + self._known_coder_urns): + # A WindowedValue coder with an unknown window type. + # This needs to be encoded in such a way that we still have access to its + # timestmap. + lp_elem_coder = self.maybe_length_prefixed_coder( + coder.component_coder_ids[0]) + tp_window_coder = self.timestamped_prefixed_window_coder( + coder.component_coder_ids[1]) + new_coder_id = unique_name( + self.components.coders, coder_id + '_timestamp_prefixed') + self.components.coders[new_coder_id].CopyFrom( + beam_runner_api_pb2.Coder( + spec=beam_runner_api_pb2.FunctionSpec( + urn=common_urns.coders.WINDOWED_VALUE.urn), + component_coder_ids=[lp_elem_coder, tp_window_coder])) + safe_coder_id = unique_name( + self.components.coders, coder_id + '_timestamp_prefixed_opaque') + self.components.coders[safe_coder_id].CopyFrom( + beam_runner_api_pb2.Coder( + spec=beam_runner_api_pb2.FunctionSpec( + urn=common_urns.coders.WINDOWED_VALUE.urn), + component_coder_ids=[ + self.safe_coders[lp_elem_coder], + self.safe_coders[tp_window_coder] + ])) + return new_coder_id, safe_coder_id elif coder.spec.urn in self._known_coder_urns: + # A known coder type, but its components may still need to be length + # prefixed. new_component_ids = [ self.maybe_length_prefixed_coder(c) for c in coder.component_coder_ids ] @@ -538,6 +570,7 @@ def maybe_length_prefixed_and_safe_coder(self, coder_id): spec=coder.spec, component_coder_ids=safe_component_ids)) return new_coder_id, safe_coder_id else: + # A completely unkown coder. Wrap the entire thing in a length prefix. new_coder_id = unique_name( self.components.coders, coder_id + '_length_prefixed') self.components.coders[new_coder_id].CopyFrom( @@ -547,6 +580,25 @@ def maybe_length_prefixed_and_safe_coder(self, coder_id): component_coder_ids=[coder_id])) return new_coder_id, self.bytes_coder_id + @memoize_on_instance + def timestamped_prefixed_window_coder(self, coder_id): + length_prefixed = self.maybe_length_prefixed_coder(coder_id) + new_coder_id = unique_name( + self.components.coders, coder_id + '_timestamp_prefixed') + self.components.coders[new_coder_id].CopyFrom( + beam_runner_api_pb2.Coder( + spec=beam_runner_api_pb2.FunctionSpec( + urn=common_urns.coders.CUSTOM_WINDOW.urn), + component_coder_ids=[length_prefixed])) + safe_coder_id = unique_name( + self.components.coders, coder_id + '_timestamp_prefixed_opaque') + self.components.coders[safe_coder_id].CopyFrom( + beam_runner_api_pb2.Coder( + spec=beam_runner_api_pb2.FunctionSpec( + urn=python_urns.TIMESTAMP_PREFIXED_OPAQUE_WINDOW_CODER))) + self.safe_coders[new_coder_id] = safe_coder_id + return new_coder_id + def length_prefix_pcoll_coders(self, pcoll_id): # type: (str) -> None self.components.pcollections[pcoll_id].coder_id = ( diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index 7700b05efec1..c54b5bf44e5c 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -146,6 +146,12 @@ def predecessor(self): """Returns the largest timestamp smaller than self.""" return Timestamp(micros=self.micros - 1) + def successor(self): + # type: () -> Timestamp + + """Returns the smallest timestamp larger than self.""" + return Timestamp(micros=self.micros + 1) + def __repr__(self): # type: () -> str micros = self.micros