Skip to content

Commit

Permalink
Properly handle timestamp prefixing of unkown window types.
Browse files Browse the repository at this point in the history
This was exposed by apache#28972 when the set of "known" coders was
inadvertently reduced.
  • Loading branch information
robertwb authored and hjtran committed Apr 4, 2024
1 parent b586ed2 commit 20398fe
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 5 deletions.
63 changes: 59 additions & 4 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,10 +1641,65 @@ def decode_from_stream(self, stream, nested):
return self._window_coder_impl.decode_from_stream(stream, nested)

def estimate_size(self, value: Any, nested: bool = False) -> int:
estimated_size = 0
estimated_size += TimestampCoderImpl().estimate_size(value)
estimated_size += self._window_coder_impl.estimate_size(value, nested)
return estimated_size
return (
TimestampCoderImpl().estimate_size(value.max_timestamp()) +
self._window_coder_impl.estimate_size(value, nested))


_OpaqueWindow = None


def _create_opaque_window(end, encoded_window):
# This is lazy to avoid circular import issues.
global _OpaqueWindow
if _OpaqueWindow is None:
from apache_beam.transforms.window import BoundedWindow

class _OpaqueWindow(BoundedWindow):
def __init__(self, end, encoded_window):
super().__init__(end)
self.encoded_window = encoded_window

def __repr__(self):
return 'OpaqueWindow(%s, %s)' % (self.end, self.encoded_window)

def __hash__(self):
return hash(self.encoded_window)

def __eq__(self, other):
return (
type(self) == type(other) and self.end == other.end and
self.encoded_window == other.encoded_window)

return _OpaqueWindow(end, encoded_window)


class TimestampPrefixingOpaqueWindowCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees.
A coder for unknown window types, which prefix required max_timestamp to
encoded original window.
The coder encodes and decodes custom window types with following format:
window's max_timestamp()
length prefixed encoded window
"""
def __init__(self) -> None:
pass

def encode_to_stream(self, value, stream, nested):
TimestampCoderImpl().encode_to_stream(value.max_timestamp(), stream, True)
stream.write(value.encoded_window, True)

def decode_from_stream(self, stream, nested):
max_timestamp = TimestampCoderImpl().decode_from_stream(stream, True)
return _create_opaque_window(
max_timestamp.successor(), stream.read_all(True))

def estimate_size(self, value: Any, nested: bool = False) -> int:
return (
TimestampCoderImpl().estimate_size(value.max_timestamp()) +
len(value.encoded_window))


row_coders_registered = False
Expand Down
28 changes: 28 additions & 0 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,34 @@ def __hash__(self):
common_urns.coders.CUSTOM_WINDOW.urn, TimestampPrefixingWindowCoder)


class TimestampPrefixingOpaqueWindowCoder(FastCoder):
"""For internal use only; no backwards-compatibility guarantees.
Coder which decodes windows as bytes."""
def __init__(self) -> None:
pass

def _create_impl(self):
return coder_impl.TimestampPrefixingOpaqueWindowCoderImpl()

def is_deterministic(self) -> bool:
return True

def __repr__(self):
return 'TimestampPrefixingOpaqueWindowCoder'

def __eq__(self, other):
return type(self) == type(other)

def __hash__(self):
return hash((type(self)))


Coder.register_structured_urn(
python_urns.TIMESTAMP_PREFIXED_OPAQUE_WINDOW_CODER,
TimestampPrefixingOpaqueWindowCoder)


class BigIntegerCoder(FastCoder):
def _create_impl(self):
return coder_impl.BigIntegerCoderImpl()
Expand Down
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/portability/python_urns.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
# Components: The coders for the tuple elements, in order.
TUPLE_CODER = "beam:coder:tuple:v1"

# This allows us to decode TimestampedPrefixed(LengthPrefixed(AnyWindowCoder)).
TIMESTAMP_PREFIXED_OPAQUE_WINDOW_CODER = (
"beam:timestamp_prefixed_opaque_window_coder:v1")

# Invoke UserFns in process, via direct function calls.
# Payload: None.
EMBEDDED_PYTHON = "beam:env:embedded_python:v1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,19 @@ def test_custom_merging_window(self):
from apache_beam.runners.portability.fn_api_runner.execution import GenericMergingWindowFn
self.assertEqual(GenericMergingWindowFn._HANDLES, {})

def test_custom_window_type(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create([1, 2, 100, 101, 102])
| beam.Map(lambda t: window.TimestampedValue(('k', t), t))
| beam.WindowInto(EvenOddWindows())
| beam.GroupByKey()
| beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
assert_that(
res,
equal_to([('k', [1]), ('k', [2]), ('k', [101]), ('k', [100, 102])]))

@unittest.skip('BEAM-9119: test is flaky')
def test_large_elements(self):
with self.create_pipeline() as p:
Expand Down Expand Up @@ -2379,6 +2392,47 @@ def get_window_coder(self):
return coders.IntervalWindowCoder()


class ColoredFixedWindow(window.BoundedWindow):
def __init__(self, end, color):
super().__init__(end)
self.color = color

def __hash__(self):
return hash((self.end, self.color))

def __eq__(self, other):
return (
type(self) == type(other) and self.end == other.end and
self.color == other.color)


class ColoredFixedWindowCoder(beam.coders.Coder):
kv_coder = beam.coders.TupleCoder(
[beam.coders.TimestampCoder(), beam.coders.StrUtf8Coder()])

def encode(self, colored_window):
return self.kv_coder.encode((colored_window.end, colored_window.color))

def decode(self, encoded_window):
return ColoredFixedWindow(*self.kv_coder.decode(encoded_window))

def is_deterministic(self):
return True


class EvenOddWindows(window.NonMergingWindowFn):
def assign(self, context):
timestamp = context.timestamp
return [
ColoredFixedWindow(
timestamp - timestamp % 10 + 10,
'red' if timestamp.micros // 1000000 % 2 else 'black')
]

def get_window_coder(self):
return ColoredFixedWindowCoder()


class ExpectingSideInputsFn(beam.DoFn):
def __init__(self, name):
self._name = name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def __init__(
self._known_coder_urns = set.union(
# Those which are required.
self._REQUIRED_CODER_URNS,
# Those common coders which are understood by all environments.
# Those common coders which are understood by many environments.
self._COMMON_CODER_URNS.intersection(
*(
set(env.capabilities)
Expand Down Expand Up @@ -515,8 +515,40 @@ def maybe_length_prefixed_and_safe_coder(self, coder_id):
# type: (str) -> Tuple[str, str]
coder = self.components.coders[coder_id]
if coder.spec.urn == common_urns.coders.LENGTH_PREFIX.urn:
# If the coder is already length prefixed, we can use it as is, and
# have the runner treat it as opaque bytes.
return coder_id, self.bytes_coder_id
elif (coder.spec.urn == common_urns.coders.WINDOWED_VALUE.urn and
self.components.coders[coder.component_coder_ids[1]].spec.urn not in
self._known_coder_urns):
# A WindowedValue coder with an unknown window type.
# This needs to be encoded in such a way that we still have access to its
# timestmap.
lp_elem_coder = self.maybe_length_prefixed_coder(
coder.component_coder_ids[0])
tp_window_coder = self.timestamped_prefixed_window_coder(
coder.component_coder_ids[1])
new_coder_id = unique_name(
self.components.coders, coder_id + '_timestamp_prefixed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.WINDOWED_VALUE.urn),
component_coder_ids=[lp_elem_coder, tp_window_coder]))
safe_coder_id = unique_name(
self.components.coders, coder_id + '_timestamp_prefixed_opaque')
self.components.coders[safe_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.WINDOWED_VALUE.urn),
component_coder_ids=[
self.safe_coders[lp_elem_coder],
self.safe_coders[tp_window_coder]
]))
return new_coder_id, safe_coder_id
elif coder.spec.urn in self._known_coder_urns:
# A known coder type, but its components may still need to be length
# prefixed.
new_component_ids = [
self.maybe_length_prefixed_coder(c) for c in coder.component_coder_ids
]
Expand All @@ -538,6 +570,7 @@ def maybe_length_prefixed_and_safe_coder(self, coder_id):
spec=coder.spec, component_coder_ids=safe_component_ids))
return new_coder_id, safe_coder_id
else:
# A completely unkown coder. Wrap the entire thing in a length prefix.
new_coder_id = unique_name(
self.components.coders, coder_id + '_length_prefixed')
self.components.coders[new_coder_id].CopyFrom(
Expand All @@ -547,6 +580,25 @@ def maybe_length_prefixed_and_safe_coder(self, coder_id):
component_coder_ids=[coder_id]))
return new_coder_id, self.bytes_coder_id

@memoize_on_instance
def timestamped_prefixed_window_coder(self, coder_id):
length_prefixed = self.maybe_length_prefixed_coder(coder_id)
new_coder_id = unique_name(
self.components.coders, coder_id + '_timestamp_prefixed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.CUSTOM_WINDOW.urn),
component_coder_ids=[length_prefixed]))
safe_coder_id = unique_name(
self.components.coders, coder_id + '_timestamp_prefixed_opaque')
self.components.coders[safe_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=python_urns.TIMESTAMP_PREFIXED_OPAQUE_WINDOW_CODER)))
self.safe_coders[new_coder_id] = safe_coder_id
return new_coder_id

def length_prefix_pcoll_coders(self, pcoll_id):
# type: (str) -> None
self.components.pcollections[pcoll_id].coder_id = (
Expand Down
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/utils/timestamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ def predecessor(self):
"""Returns the largest timestamp smaller than self."""
return Timestamp(micros=self.micros - 1)

def successor(self):
# type: () -> Timestamp

"""Returns the smallest timestamp larger than self."""
return Timestamp(micros=self.micros + 1)

def __repr__(self):
# type: () -> str
micros = self.micros
Expand Down

0 comments on commit 20398fe

Please sign in to comment.