Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster default coder for unknown windows. #33382

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions sdks/python/apache_beam/coders/coder_impl.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ cdef libc.stdint.int64_t MIN_TIMESTAMP_micros
cdef libc.stdint.int64_t MAX_TIMESTAMP_micros


# Cython declarations for the tag-byte union coder implemented in
# coder_impl.py.
cdef class _OrderedUnionCoderImpl(StreamCoderImpl):
  # The candidate types and their coder impls, stored as two tuples that are
  # indexed by the same position.
  cdef tuple _types
  cdef tuple _coder_impls
  # Coder impl used when no candidate type matches; the Python implementation
  # checks it against None before use.
  cdef CoderImpl _fallback_coder_impl

  # `ix` and `c` are the names of locals in the Python method bodies; they are
  # given C-level types here.
  @cython.locals(ix=int, c=CoderImpl)
  cpdef encode_to_stream(self, value, OutputStream stream, bint nested)

  @cython.locals(ix=int, c=CoderImpl)
  cpdef decode_from_stream(self, InputStream stream, bint nested)


cdef class WindowedValueCoderImpl(StreamCoderImpl):
"""A coder for windowed values."""
cdef CoderImpl _value_coder
Expand Down
31 changes: 31 additions & 0 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,37 @@ def estimate_size(self, value, nested=False):
return size


class _OrderedUnionCoderImpl(StreamCoderImpl):
  """Encodes a value with the first coder whose registered type matches.

  The encoding is a single tag byte followed by whatever the selected coder
  writes. The tag is the position of the matching type in
  ``coder_impl_types``; the tag ``0xFF`` marks a payload written by the
  fallback coder.
  """
  def __init__(self, coder_impl_types, fallback_coder_impl):
    """Initializes the union coder.

    Args:
      coder_impl_types: a sized sequence of ``(type, coder_impl)`` pairs,
        tried in order when encoding. May be empty.
      fallback_coder_impl: coder impl used for values whose type matches no
        pair, or None to reject such values.
    """
    # The bound is 128 rather than 255 (0xFF is the fallback tag) to avoid
    # any signedness issues with the tag byte and to leave headroom for
    # extending the protocol in the future.
    assert len(coder_impl_types) < 128
    if coder_impl_types:
      self._types, self._coder_impls = zip(*coder_impl_types)
    else:
      # zip(*[]) produces nothing to unpack; a union that only has a
      # fallback coder is still usable.
      self._types, self._coder_impls = (), ()
    self._fallback_coder_impl = fallback_coder_impl

  def encode_to_stream(self, value, out, nested):
    """Writes the tag byte and then the payload for ``value`` to ``out``.

    Raises:
      ValueError: if no registered type matches and there is no fallback.
    """
    value_t = type(value)
    for (ix, t) in enumerate(self._types):
      # Exact type identity, not issubclass: a coder for T is unlikely to
      # faithfully round-trip every subclass of T.
      if value_t is t:
        out.write_byte(ix)
        c = self._coder_impls[ix]  # for typing
        c.encode_to_stream(value, out, nested)
        break
    else:
      # No registered type matched.
      if self._fallback_coder_impl is None:
        raise ValueError("No fallback.")
      out.write_byte(0xFF)
      self._fallback_coder_impl.encode_to_stream(value, out, nested)

  def decode_from_stream(self, in_stream, nested):
    """Reads the tag byte and decodes the payload with the coder it selects.

    Raises:
      ValueError: if the tag is 0xFF and there is no fallback coder.
    """
    ix = in_stream.read_byte()
    if ix == 0xFF:
      if self._fallback_coder_impl is None:
        raise ValueError("No fallback.")
      return self._fallback_coder_impl.decode_from_stream(in_stream, nested)
    else:
      c = self._coder_impls[ix]  # for typing
      return c.decode_from_stream(in_stream, nested)


class WindowedValueCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees.

Expand Down
38 changes: 37 additions & 1 deletion sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,12 +1350,48 @@ def __hash__(self):
common_urns.coders.INTERVAL_WINDOW.urn, IntervalWindowCoder)


class _OrderedUnionCoder(FastCoder):
  """Coder built from an ordered list of ``(type, coder)`` pairs.

  The pairs and the optional fallback coder are handed, as coder impls, to
  ``coder_impl._OrderedUnionCoderImpl``, which does the actual encoding.
  """
  def __init__(
      self, *coder_types: Tuple[type, Coder], fallback_coder: Optional[Coder]):
    """Initializes the coder.

    Args:
      *coder_types: ``(type, coder)`` pairs, kept in the order given.
      fallback_coder: keyword-only; a coder for values not covered by
        ``coder_types``, or None.
    """
    self._coder_types = coder_types
    self._fallback_coder = fallback_coder

  def _create_impl(self):
    return coder_impl._OrderedUnionCoderImpl(
        [(t, c.get_impl()) for t, c in self._coder_types],
        fallback_coder_impl=self._fallback_coder.get_impl()
        if self._fallback_coder else None)

  def is_deterministic(self) -> bool:
    """Returns True only if every member coder and the fallback are."""
    # is_deterministic must be *called*: the bound method object itself is
    # always truthy, which would make this check pass unconditionally.
    return (
        all(c.is_deterministic() for _, c in self._coder_types) and (
            self._fallback_coder is None or
            self._fallback_coder.is_deterministic()))

  def to_type_hint(self):
    return Any

  def __eq__(self, other):
    return (
        type(self) == type(other) and
        self._coder_types == other._coder_types and
        self._fallback_coder == other._fallback_coder)

  def __hash__(self):
    return hash((type(self), tuple(self._coder_types), self._fallback_coder))


class WindowedValueCoder(FastCoder):
"""Coder for windowed values."""
def __init__(self, wrapped_value_coder, window_coder=None):
# type: (Coder, Optional[Coder]) -> None
if not window_coder:
window_coder = PickleCoder()
# Avoid circular imports.
from apache_beam.transforms import window
window_coder = _OrderedUnionCoder(
(window.GlobalWindow, GlobalWindowCoder()),
(window.IntervalWindow, IntervalWindowCoder()),
fallback_coder=PickleCoder())
self.wrapped_value_coder = wrapped_value_coder
self.timestamp_coder = TimestampCoder()
self.window_coder = window_coder
Expand Down
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,14 @@ def test_decimal_coder(self):
test_encodings[idx],
base64.b64encode(test_coder.encode(value)).decode().rstrip("="))

def test_OrderedUnionCoder(self):
  """Checks a str member, an int member and a float sent to the fallback."""
  union_coder = coders._OrderedUnionCoder(
      (str, coders.StrUtf8Coder()),
      (int, coders.VarIntCoder()),
      fallback_coder=coders.FloatCoder())
  # One value per route: first pair, second pair, then the fallback coder.
  for sample in ('s', 123, 1.5):
    self.check_coder(union_coder, sample)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
Loading