From 53d4032b027055342c7321bef95e8a23e805c74c Mon Sep 17 00:00:00 2001 From: Alexei Frolov Date: Mon, 14 Oct 2024 20:57:46 +0000 Subject: [PATCH] pw_protobuf: Force use of callbacks for oneof MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pw_protobuf's message structures previously generated each of the members of a oneof group as a separate struct member, allowing multiple of them to be set to serialize semantically-invalid wire messages. This replaces oneof struct member generation with a single callback for the entire oneof group, allowing the user to encode/decode the desired oneof field using the wire stream encoder and decoder. Change-Id: I2e4f79c5f7d2ec99eb036e0f93a77675a29a625c Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/236592 Reviewed-by: Armando Montanez Lint: Lint 🤖 --- pw_protobuf/codegen_message_test.cc | 147 +++++++++++++++++- pw_protobuf/docs.rst | 112 ++++++++++++- pw_protobuf/encoder.cc | 31 +++- pw_protobuf/public/pw_protobuf/encoder.h | 6 + .../public/pw_protobuf/internal/codegen.h | 100 +++++++++++- .../pw_protobuf_test_protos/full_test.proto | 12 ++ .../pw_protobuf_test_protos/importer.proto | 8 +- pw_protobuf/py/pw_protobuf/codegen_pwpb.py | 140 ++++++++++++----- pw_protobuf/py/pw_protobuf/proto_tree.py | 51 +++++- .../size_report/oneof_codegen_comparison.cc | 41 ++++- pw_protobuf/size_report/proto_bloat.cc | 5 +- pw_protobuf/stream_decoder.cc | 10 +- 12 files changed, 587 insertions(+), 76 deletions(-) diff --git a/pw_protobuf/codegen_message_test.cc b/pw_protobuf/codegen_message_test.cc index 3beb5105a9..be59e3e49e 100644 --- a/pw_protobuf/codegen_message_test.cc +++ b/pw_protobuf/codegen_message_test.cc @@ -1301,7 +1301,7 @@ TEST(CodegenMessage, DISABLED_ReadDoesNotOverrun) { false, false, false, - false, + internal::CallbackType::kNone, 0, sizeof(KeyValuePair::Message) * 2, {}}, @@ -1903,7 +1903,7 @@ TEST(CodegenMessage, DISABLED_WriteDoesNotOverrun) { false, false, false, - false, + internal::CallbackType::kNone, 0, sizeof(KeyValuePair::Message) * 2, {}}, @@ -2079,5 +2079,148 @@ TEST(CodegenMessage, MaxSize) { EXPECT_EQ(count_message.enums.max_size(), RepeatedTest::kEnumsMaxSize); } +TEST(CodegenMessage, OneOf_Encode) { + OneOfTest::Message message; + + int invocations = 0; + message.type.SetEncoder([&invocations](OneOfTest::StreamEncoder& encoder) { + invocations++; + return encoder.WriteAnInt(32); + }); + + // clang-format off + constexpr uint8_t expected_proto[] = { + // type.an_int + 0x08, 0x20, + }; + // clang-format on + + std::array buffer; + OneOfTest::MemoryEncoder oneof_test(buffer); + + EXPECT_EQ(oneof_test.Write(message), OkStatus()); + EXPECT_EQ(invocations, 1); + + EXPECT_EQ(oneof_test.size(), sizeof(expected_proto)); + EXPECT_EQ( + std::memcmp(oneof_test.data(), expected_proto, sizeof(expected_proto)), + 0); +} + +TEST(CodegenMessage, OneOf_Encode_MultipleTimes) { + OneOfTest::Message message; + + int invocations = 0; + message.type.SetEncoder([&invocations](OneOfTest::StreamEncoder& encoder) { + invocations++; + return encoder.WriteAString("oneof"); + }); + + // clang-format off + constexpr uint8_t expected_proto[] = { + // type.a_string + 0x12, 0x05, 'o', 'n', 'e', 'o', 'f' + }; + // clang-format on + + // Write the same message struct to two different buffers. Even though its + // internal state is modified during the write, it should be logically const + // with both writes successfully producing the same output. + + std::array buffer_1; + std::array buffer_2; + OneOfTest::MemoryEncoder oneof_test_1(buffer_1); + OneOfTest::MemoryEncoder oneof_test_2(buffer_2); + + EXPECT_EQ(oneof_test_1.Write(message), OkStatus()); + EXPECT_EQ(invocations, 1); + EXPECT_EQ(oneof_test_1.size(), sizeof(expected_proto)); + EXPECT_EQ( + std::memcmp(oneof_test_1.data(), expected_proto, sizeof(expected_proto)), + 0); + + EXPECT_EQ(oneof_test_2.Write(message), OkStatus()); + EXPECT_EQ(invocations, 2); + EXPECT_EQ(oneof_test_2.size(), sizeof(expected_proto)); + EXPECT_EQ( + std::memcmp(oneof_test_2.data(), expected_proto, sizeof(expected_proto)), + 0); +} + +TEST(CodegenMessage, OneOf_Encode_UnsetEncoderFails) { + OneOfTest::Message message; + std::array buffer; + OneOfTest::MemoryEncoder oneof_test(buffer); + EXPECT_EQ(oneof_test.Write(message), Status::DataLoss()); +} + +TEST(CodegenMessage, OneOf_Decode) { + // clang-format off + constexpr uint8_t proto_data[] = { + // type.a_message + 0x1a, 0x02, 0x08, 0x01, + }; + // clang-format on + + stream::MemoryReader reader(as_bytes(span(proto_data))); + OneOfTest::StreamDecoder stream_decoder(reader); + + struct { + OneOfTest::Fields field; + OneOfTest::AMessage::Message submessage; + int invocations = 0; + } result; + + OneOfTest::Message message; + message.type.SetDecoder( + [&result](OneOfTest::Fields field, OneOfTest::StreamDecoder& decoder) { + result.field = field; + result.invocations++; + if (field == OneOfTest::Fields::kAMessage) { + return decoder.GetAMessageDecoder().Read(result.submessage); + } + return Status::InvalidArgument(); + }); + + EXPECT_EQ(stream_decoder.Read(message), OkStatus()); + EXPECT_EQ(result.field, OneOfTest::Fields::kAMessage); + EXPECT_EQ(result.invocations, 1); + EXPECT_EQ(result.submessage.a_bool, true); +} + +TEST(CodegenMessage, OneOf_Decode_MultipleOneOfFieldsFails) { + // clang-format off + constexpr uint8_t proto_data[] = { + // type.an_int + 0x08, 0x20, + // type.a_message + 0x1a, 0x02, 0x08, 0x01, + }; + // clang-format on + + stream::MemoryReader reader(as_bytes(span(proto_data))); + OneOfTest::StreamDecoder stream_decoder(reader); + + OneOfTest::Message message; + message.type.SetDecoder( + [](OneOfTest::Fields, OneOfTest::StreamDecoder&) { return OkStatus(); }); + + EXPECT_EQ(stream_decoder.Read(message), Status::DataLoss()); +} + +TEST(CodegenMessage, OneOf_Decode_UnsetDecoderFails) { + // clang-format off + constexpr uint8_t proto_data[] = { + // type.an_int + 0x08, 0x20, + }; + // clang-format on + + stream::MemoryReader reader(as_bytes(span(proto_data))); + OneOfTest::StreamDecoder stream_decoder(reader); + OneOfTest::Message message; + EXPECT_EQ(stream_decoder.Read(message), Status::DataLoss()); +} + } // namespace } // namespace pw::protobuf diff --git a/pw_protobuf/docs.rst b/pw_protobuf/docs.rst index 619b86b82d..a16ae005d5 100644 --- a/pw_protobuf/docs.rst +++ b/pw_protobuf/docs.rst @@ -67,6 +67,13 @@ Message Structures The highest level API is based around message structures created through C++ code generation, integrated with Pigweed's build system. +.. warning:: + + Message structures only support a subset of protobuf field types. Before + continuing, refer to :ref:`pw_protobuf-message-limitations` to understand + what types of protobuf messages can and cannot be represented, and whether + or not message structures are a suitable choice. + This results in the following generated structure: .. code-block:: c++ @@ -260,6 +267,31 @@ from one message type to another: encoder's ``status()`` call. Always check the status of calls or the encoder, as in the case of error, the encoded data will be invalid. +.. _pw_protobuf-message-limitations: + +Limitations +----------- +``pw_protobuf``'s message structure API is incomplete. Generally speaking, it is +reasonable to use for basic messages containing simple inline fields (scalar +types, bytes, and strings) and nested messages of similar form. Beyond this, +certain other types of protobuf specifiers can be used, but may be limited in +how and when they are supported. These cases are described below. + +If an object representation of protobuf messages is desired, the Pigweed team +recommends using `Nanopb`_, which is fully supported within Pigweed's build and +RPC systems. + +Message structures are eventually intended to be replaced with an alternative +object model. See `SEED-0103 `_ for additional +information about how message structures came to be and our future plans. + +``oneof`` fields +^^^^^^^^^^^^^^^^ +``oneof`` protobuf fields cannot be inlined within a message structure: they +must be encoded and decoded using callbacks. + +.. _pw_protobuf-per-field-apis: + Per-Field Writers and Readers ============================= The middle level API is based around typed methods to write and read each @@ -1039,7 +1071,7 @@ that can hold the set of values encoded by it, following these rules. message Store { Store nearest_store = 1; repeated int32 employee_numbers = 2; - string driections = 3; + string directions = 3; repeated string address = 4; repeated Employee employees = 5; } @@ -1061,6 +1093,75 @@ that can hold the set of values encoded by it, following these rules. A Callback object can be converted to a ``bool`` indicating whether a callback is set. +* Fields defined within a ``oneof`` group are represented by a ``OneOf`` + callback. + + .. code-block:: protobuf + + message OnlineOrder { + Product product = 1; + Customer customer = 2; + + oneof delivery { + Address shipping_address = 3; + Date pickup_date = 4; + } + } + + .. code-block:: + + // No options set. + + .. code-block:: c++ + + struct OnlineOrder::Message { + Product::Message product; + Customer::Message customer; + pw::protobuf::OneOf + delivery; + }; + + Encoding a ``oneof`` field is identical to using a regular field callback. + The encode callback will be invoked once when the message is written. Users + must ensure that only a single field is written to the encoder within the + callback. + + .. code-block:: c++ + + OnlineOrder::Message message; + message.delivery.SetEncoder( + [&pickup_date](OnlineOrder::StreamEncoder& encoder) { + return encoder.GetPickupDateEncoder().Write(pickup_date); + }); + + The ``OneOf`` decoder callback is invoked when reading a message structure + when a field within the ``oneof`` group is encountered. The found field is + passed to the callback. + + If multiple fields from the ``oneof`` group are encountered within a ``Read``, + it will fail with a ``DATA_LOSS`` status. + + .. code-block:: c++ + + OnlineOrder::Message message; + message.delivery.SetDecoder( + [this](OnlineOrder::Fields field, OnlineOrder::StreamDecoder& decoder) { + switch (field) { + case OnlineOrder::Fields::kShippingAddress: + PW_TRY(decoder.GetShippingAddressDecoder().Read(&this->shipping_address)); + break; + case OnlineOrder::Fields::kPickupDate: + PW_TRY(decoder.GetPickupDateDecoder().Read(&this->pickup_date)); + break; + default: + return pw::Status::DataLoss(); + } + + return pw::OkStatus(); + }); + Message structures can be copied, but doing so will clear any assigned callbacks. To preserve functions applied to callbacks, ensure that the message structure is moved. @@ -2313,10 +2414,9 @@ allocation, making it unsuitable for many embedded systems. nanopb ====== -`nanopb `_ is a commonly used embedded -protobuf library with very small code size and full code generation. It provides -both encoding/decoding functionality and in-memory C structs representing -protobuf messages. +`Nanopb`_ is a commonly used embedded protobuf library with very small code size +and full code generation. It provides both encoding/decoding functionality and +in-memory C structs representing protobuf messages. nanopb works well for many embedded products; however, using its generated code can run into RAM usage issues when processing nontrivial protobuf messages due @@ -2334,3 +2434,5 @@ intuitive user interface. Depending on the requirements of a project, either of these libraries could be suitable. + +.. _Nanopb: https://jpa.kapsi.fi/nanopb/ diff --git a/pw_protobuf/encoder.cc b/pw_protobuf/encoder.cc index e5b3c2d421..2831175e56 100644 --- a/pw_protobuf/encoder.cc +++ b/pw_protobuf/encoder.cc @@ -252,19 +252,25 @@ Status StreamEncoder::Write(span message, for (const auto& field : table) { // Calculate the span of bytes corresponding to the structure field to // read from. - const auto values = + ConstByteSpan values = message.subspan(field.field_offset(), field.field_size()); PW_CHECK(values.begin() >= message.begin() && values.end() <= message.end()); // If the field is using callbacks, interpret the input field accordingly // and allow the caller to provide custom handling. - if (field.use_callback()) { + if (field.callback_type() == internal::CallbackType::kSingleField) { const Callback* callback = reinterpret_cast*>( values.data()); PW_TRY(callback->Encode(*this)); continue; + } else if (field.callback_type() == internal::CallbackType::kOneOfGroup) { + const OneOf* callback = + reinterpret_cast*>( + values.data()); + PW_TRY(callback->Encode(*this)); + continue; } switch (field.wire_type()) { @@ -560,7 +566,28 @@ Status StreamEncoder::Write(span message, } } + ResetOneOfCallbacks(message, table); + return status_; } +void StreamEncoder::ResetOneOfCallbacks( + ConstByteSpan message, span table) { + for (const auto& field : table) { + // Calculate the span of bytes corresponding to the structure field to + // read from. + ConstByteSpan values = + message.subspan(field.field_offset(), field.field_size()); + PW_CHECK(values.begin() >= message.begin() && + values.end() <= message.end()); + + if (field.callback_type() == internal::CallbackType::kOneOfGroup) { + const OneOf* callback = + reinterpret_cast*>( + values.data()); + callback->invoked_ = false; + } + } +} + } // namespace pw::protobuf diff --git a/pw_protobuf/public/pw_protobuf/encoder.h b/pw_protobuf/public/pw_protobuf/encoder.h index bd6a95b33e..782afad444 100644 --- a/pw_protobuf/public/pw_protobuf/encoder.h +++ b/pw_protobuf/public/pw_protobuf/encoder.h @@ -789,6 +789,12 @@ class StreamEncoder { WireType type, size_t data_size); + // Callbacks for oneof fields set a flag to ensure they are only invoked once. + // To maintain logical constness of message structs passed to write, this + // resets each callback's invoked flag following a write operation. + void ResetOneOfCallbacks(ConstByteSpan message, + span table); + // The current encoder status. This status is only updated to reflect the // first error encountered. Any further write operations are blocked when the // encoder enters an error state. diff --git a/pw_protobuf/public/pw_protobuf/internal/codegen.h b/pw_protobuf/public/pw_protobuf/internal/codegen.h index d3e13c5784..b922533a0b 100644 --- a/pw_protobuf/public/pw_protobuf/internal/codegen.h +++ b/pw_protobuf/public/pw_protobuf/internal/codegen.h @@ -18,6 +18,7 @@ #include "pw_function/function.h" #include "pw_preprocessor/compiler.h" #include "pw_protobuf/wire_format.h" +#include "pw_result/result.h" #include "pw_span/span.h" #include "pw_status/status.h" @@ -38,6 +39,12 @@ enum class VarintType { kZigZag = 2, }; +enum class CallbackType { + kNone = 0, + kSingleField = 1, + kOneOfGroup = 2, +}; + // Represents a field in a code generated message struct that can be the target // for decoding or source of encoding. // @@ -72,7 +79,7 @@ class MessageField { bool is_fixed_size, bool is_repeated, bool is_optional, - bool use_callback, + CallbackType callback_type, size_t field_offset, size_t field_size, const span* nested_message_fields) @@ -84,7 +91,7 @@ class MessageField { static_cast(is_fixed_size) << kIsFixedSizeShift | static_cast(is_repeated) << kIsRepeatedShift | static_cast(is_optional) << kIsOptionalShift | - static_cast(use_callback) << kUseCallbackShift | + static_cast(callback_type) << kCallbackTypeShift | static_cast(field_size) << kFieldSizeShift), field_offset_(field_offset), nested_message_fields_(nested_message_fields) {} @@ -113,8 +120,9 @@ class MessageField { constexpr bool is_optional() const { return (field_info_ >> kIsOptionalShift) & 1; } - constexpr bool use_callback() const { - return (field_info_ >> kUseCallbackShift) & 1; + constexpr CallbackType callback_type() const { + return static_cast((field_info_ >> kCallbackTypeShift) & + kCallbackTypeMask); } constexpr size_t field_offset() const { return field_offset_; } constexpr size_t field_size() const { @@ -136,11 +144,11 @@ class MessageField { // is_string : 1 // is_fixed_size : 1 // is_repeated : 1 - // use_callback : 1 + // [unused space] : 1 // - // elem_size : 4 + // callback_type : 2 // is_optional : 1 - // [unused space] : 2 // - // field_size : 16 // @@ -156,9 +164,11 @@ class MessageField { static constexpr unsigned int kIsStringShift = 26u; static constexpr unsigned int kIsFixedSizeShift = 25u; static constexpr unsigned int kIsRepeatedShift = 24u; - static constexpr unsigned int kUseCallbackShift = 23u; + // Unused space: bit 23 (previously use_callback). static constexpr unsigned int kElemSizeShift = 19u; static constexpr unsigned int kElemSizeMask = (1u << 4) - 1; + static constexpr unsigned int kCallbackTypeShift = 17; + static constexpr unsigned int kCallbackTypeMask = (1u << 2) - 1; static constexpr unsigned int kIsOptionalShift = 16u; static constexpr unsigned int kFieldSizeShift = 0u; static constexpr unsigned int kFieldSizeMask = kMaxFieldSize; @@ -237,6 +247,82 @@ union Callback { Function decode_; }; +enum class NullFields : uint32_t {}; + +/// Callback for a oneof structure member. +/// A oneof callback will only be invoked once per struct member. +template +struct OneOf { + public: + constexpr OneOf() : invoked_(false), encode_() {} + ~OneOf() { encode_ = nullptr; } + + // Set the encoder callback. + void SetEncoder(Function&& encode) { + encode_ = std::move(encode); + } + + // Set the decoder callback. + void SetDecoder( + Function&& decode) { + decode_ = std::move(decode); + } + + // Allow moving of callbacks by moving the member. + constexpr OneOf(OneOf&& other) = default; + constexpr OneOf& operator=(OneOf&& other) = default; + + // Copying a callback does not copy the functions. + constexpr OneOf(const OneOf&) : encode_() {} + constexpr OneOf& operator=(const OneOf&) { + encode_ = nullptr; + return *this; + } + + // Evaluate to true if the encoder or decoder callback is set. + explicit operator bool() const { return encode_ || decode_; } + + private: + friend StreamDecoder; + friend StreamEncoder; + + constexpr void ResetForNewWrite() const { invoked_ = false; } + + Status Encode(StreamEncoder& encoder) const { + if (encode_) { + if (invoked_) { + // The oneof has already been encoded. + return OkStatus(); + } + + invoked_ = true; + return encode_(encoder); + } + return Status::DataLoss(); + } + + Status Decode(Fields field, StreamDecoder& decoder) const { + if (decode_) { + if (invoked_) { + // Multiple fields from the same oneof exist in the serialized message. + return Status::DataLoss(); + } + + invoked_ = true; + return decode_(field, decoder); + } + return Status::DataLoss(); + } + + mutable bool invoked_; + union { + Function encode_; + Function decode_; + }; +}; + template constexpr bool IsTriviallyComparable() { static_assert(internal::kInvalidMessageStruct, diff --git a/pw_protobuf/pw_protobuf_test_protos/full_test.proto b/pw_protobuf/pw_protobuf_test_protos/full_test.proto index fa144e3654..0c3403c4cf 100644 --- a/pw_protobuf/pw_protobuf_test_protos/full_test.proto +++ b/pw_protobuf/pw_protobuf_test_protos/full_test.proto @@ -211,3 +211,15 @@ message CornerCases { // Generates ReadUint32() and WriteUint32() that call the parent definition. uint32 _uint32 = 1; } + +message OneOfTest { + message AMessage { + bool a_bool = 1; + } + + oneof type { + int32 an_int = 1; + string a_string = 2; + AMessage a_message = 3; + } +} diff --git a/pw_protobuf/pw_protobuf_test_protos/importer.proto b/pw_protobuf/pw_protobuf_test_protos/importer.proto index 851c855ae6..e4c4895e62 100644 --- a/pw_protobuf/pw_protobuf_test_protos/importer.proto +++ b/pw_protobuf/pw_protobuf_test_protos/importer.proto @@ -34,9 +34,7 @@ message TestResult { } message TestMessage { - oneof level { - imported.Notice notice = 1; - imported.Debug debug = 2; - imported.PrefixDebug prefix_debug = 3; - } + imported.Notice notice = 1; + imported.Debug debug = 2; + imported.PrefixDebug prefix_debug = 3; } diff --git a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py index b856eb8dea..e98c4e94ee 100644 --- a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py +++ b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py @@ -89,6 +89,21 @@ def debug_print(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) +class _CallbackType(enum.Enum): + NONE = 0 + SINGLE_FIELD = 1 + ONEOF_GROUP = 2 + + def as_cpp(self) -> str: + match self: + case _CallbackType.NONE: + return 'kNone' + case _CallbackType.SINGLE_FIELD: + return 'kSingleField' + case _CallbackType.ONEOF_GROUP: + return 'kOneOfGroup' + + class ProtoMember(abc.ABC): """Base class for a C++ class member for a field in a protobuf message.""" @@ -113,6 +128,21 @@ def name(self) -> str: def should_appear(self) -> bool: # pylint: disable=no-self-use """Whether the member should be generated.""" + @abc.abstractmethod + def _use_callback(self) -> bool: + """Whether the member should be encoded and decoded with a callback.""" + + def callback_type(self) -> _CallbackType: + if self._field.oneof() is not None: + return _CallbackType.ONEOF_GROUP + + options = self._field.options() + assert options is not None + + if options.use_callback or self._use_callback(): + return _CallbackType.SINGLE_FIELD + return _CallbackType.NONE + def field_cast(self) -> str: return 'static_cast(Fields::{})'.format( self._field.enum_name() @@ -190,6 +220,9 @@ def should_appear(self) -> bool: # pylint: disable=no-self-use """Whether the method should be generated.""" return True + def _use_callback(self) -> bool: # pylint: disable=no-self-use + return False + def param_string(self) -> str: return ', '.join([f'{type} {name}' for type, name in self.params()]) @@ -394,6 +427,11 @@ def name(self) -> str: return self._field.field_name() def should_appear(self) -> bool: + # Oneof fields are not supported by the code generator. + oneof = self._field.oneof() + if oneof is not None: + return oneof.is_synthetic() + return True @abc.abstractmethod @@ -424,13 +462,9 @@ def repeated_field_container(type_name: str, max_size: str) -> str: """ return f'::pw::Vector<{type_name}, {max_size}>' - def use_callback(self) -> bool: # pylint: disable=no-self-use + def _use_callback(self) -> bool: # pylint: disable=no-self-use """Returns whether the decoder should use a callback.""" - options = self._field.options() - assert options is not None - return options.use_callback or ( - self._field.is_repeated() and self.max_size() == 0 - ) + return self._field.is_repeated() and self.max_size() == 0 def is_optional(self) -> bool: """Returns whether the decoder should use std::optional.""" @@ -466,7 +500,7 @@ def sub_table(self) -> str: # pylint: disable=no-self-use def struct_member_type(self, from_root: bool = False) -> str: """Returns the structure member type.""" - if self.use_callback(): + if self.callback_type() is _CallbackType.SINGLE_FIELD: return ( f'{PROTOBUF_NAMESPACE}::Callback' ) @@ -514,6 +548,13 @@ def _bool_attr(self, attr: str) -> str: def table_entry(self) -> list[str]: """Table entry.""" + + oneof = self._field.oneof() + if oneof is not None and not oneof.is_synthetic(): + struct_member = oneof.name + else: + struct_member = self.name() + return [ self.field_cast(), self._wire_type_table_entry(), @@ -523,9 +564,12 @@ def table_entry(self) -> list[str]: self._bool_attr('is_fixed_size'), self._bool_attr('is_repeated'), self._bool_attr('is_optional'), - self._bool_attr('use_callback'), - 'offsetof(Message, {})'.format(self.name()), - 'sizeof(Message::{})'.format(self.name()), + ( + f'{_INTERNAL_NAMESPACE}::CallbackType::' + + self.callback_type().as_cpp() + ), + 'offsetof(Message, {})'.format(struct_member), + 'sizeof(Message::{})'.format(struct_member), self.sub_table(), ] @@ -642,23 +686,17 @@ def _elem_size_table_entry(self) -> str: def type_name(self, from_root: bool = False) -> str: return '{}::Message'.format(self._relative_type_namespace(from_root)) - def use_callback(self) -> bool: + def _use_callback(self) -> bool: # Always use a callback for a message dependency removed to break a # cycle, and for repeated fields, since in both cases there's no way # to handle the size of nested field. - options = self._field.options() - assert options is not None - return ( - options.use_callback - or self._dependency_removed() - or self._field.is_repeated() - ) + return self._dependency_removed() or self._field.is_repeated() def wire_type(self) -> str: return 'kDelimited' def sub_table(self) -> str: - if self.use_callback(): + if self.callback_type() is not _CallbackType.NONE: return 'nullptr' return '&{}::kMessageFields'.format(self._relative_type_namespace()) @@ -671,7 +709,7 @@ def _size_fn(self) -> str: return 'SizeOfDelimitedFieldWithoutValue' def _size_length(self) -> str | None: - if self.use_callback(): + if self.callback_type() is not _CallbackType.NONE: return None return '{}::kMaxEncodedSizeBytes'.format( @@ -1984,7 +2022,7 @@ class BytesProperty(MessageProperty): def type_name(self, from_root: bool = False) -> str: return 'std::byte' - def use_callback(self) -> bool: + def _use_callback(self) -> bool: return self.max_size() == 0 def max_size(self) -> int: @@ -2013,7 +2051,7 @@ def _size_fn(self) -> str: return 'SizeOfDelimitedFieldWithoutValue' def _size_length(self) -> str | None: - if self.use_callback(): + if self.callback_type() is not _CallbackType.NONE: return None return self.max_size_constant_name() @@ -2115,7 +2153,7 @@ class StringProperty(MessageProperty): def type_name(self, from_root: bool = False) -> str: return 'char' - def use_callback(self) -> bool: + def _use_callback(self) -> bool: return self.max_size() == 0 def max_size(self) -> int: @@ -2146,7 +2184,7 @@ def _size_fn(self) -> str: return 'SizeOfDelimitedFieldWithoutValue' def _size_length(self) -> str | None: - if self.use_callback(): + if self.callback_type() is not _CallbackType.NONE: return None return self.max_size_constant_name() @@ -2578,14 +2616,18 @@ def _size_fn(self) -> str: def proto_message_field_props( message: ProtoMessage, root: ProtoNode, + include_hidden: bool = False, ) -> Iterable[MessageProperty]: """Yields a MessageProperty for each field in a ProtoMessage. - Only properties which should_appear() is True are returned. + Only properties which should_appear() is True are returned, unless + `include_hidden` is set. Args: message: The ProtoMessage whose fields are iterated. root: The root ProtoNode of the tree. + include_hidden: If True, also yield fields which shouldn't appear in the + struct. Yields: An appropriately-typed MessageProperty object for each field @@ -2594,7 +2636,7 @@ def proto_message_field_props( for field in message.fields(): property_class = PROTO_FIELD_PROPERTIES[field.type()] prop = property_class(field, message, root) - if prop.should_appear(): + if include_hidden or prop.should_appear(): yield prop @@ -2930,9 +2972,19 @@ def generate_struct_for_message( name = prop.name() output.write_line(f'{type_name} {name};') - if not prop.use_callback(): + if prop.callback_type() is _CallbackType.NONE: cmp.append(f'this->{name} == other.{name}') + for oneof in message.oneofs(): + if oneof.is_synthetic(): + continue + + fields = f'{message.cpp_namespace(root=root)}::Fields' + output.write_line( + f'{PROTOBUF_NAMESPACE}::OneOf' + f' {oneof.name};' + ) + # Equality operator output.write_line() output.write_line('bool operator==(const Message& other) const {') @@ -2986,7 +3038,10 @@ def generate_table_for_message( # # The kMessageFields span is generated whether the message has fields or # not. Only the span is referenced elsewhere. - if properties: + all_properties = list( + proto_message_field_props(message, root, include_hidden=True) + ) + if all_properties: output.write_line( f'inline constexpr {_INTERNAL_NAMESPACE}::MessageField ' ' _kMessageFields[] = {' @@ -2994,7 +3049,7 @@ def generate_table_for_message( # Generate members for each of the message's fields. with output.indent(): - for prop in properties: + for prop in all_properties: table = ', '.join(prop.table_entry()) output.write_line(f'{{{table}}},') @@ -3010,19 +3065,20 @@ def generate_table_for_message( [f'message.{prop.name()}' for prop in properties] ) - # Generate std::tuple for Message fields. - output.write_line( - 'inline constexpr auto ToTuple(const Message &message) {' - ) - output.write_line(f' return std::tie({member_list});') - output.write_line('}') + if properties: + # Generate std::tuple for main Message fields only. + output.write_line( + 'inline constexpr auto ToTuple(const Message &message) {' + ) + output.write_line(f' return std::tie({member_list});') + output.write_line('}') - # Generate mutable std::tuple for Message fields. - output.write_line( - 'inline constexpr auto ToMutableTuple(Message &message) {' - ) - output.write_line(f' return std::tie({member_list});') - output.write_line('}') + # Generate mutable std::tuple for Message fields. + output.write_line( + 'inline constexpr auto ToMutableTuple(Message &message) {' + ) + output.write_line(f' return std::tie({member_list});') + output.write_line('}') else: output.write_line( f'inline constexpr pw::span None: is_trivially_comparable = True for prop in proto_message_field_props(message, root): - if prop.use_callback(): + if prop.callback_type() is not _CallbackType.NONE: is_trivially_comparable = False break diff --git a/pw_protobuf/py/pw_protobuf/proto_tree.py b/pw_protobuf/py/pw_protobuf/proto_tree.py index bd60bf4f6c..22b5d4a82d 100644 --- a/pw_protobuf/py/pw_protobuf/proto_tree.py +++ b/pw_protobuf/py/pw_protobuf/proto_tree.py @@ -17,6 +17,7 @@ import abc import collections +import dataclasses import enum import itertools @@ -390,9 +391,27 @@ def _supports_child(self, child: ProtoNode) -> bool: class ProtoMessage(ProtoNode): """Representation of a message in a .proto file.""" + @dataclasses.dataclass + class OneOf: + name: str + fields: list[ProtoMessageField] = dataclasses.field( + default_factory=list + ) + + def is_synthetic(self) -> bool: + """Returns whether this is a synthetic oneof field.""" + # protoc expresses proto3 optional fields as a "synthetic" oneof + # containing only a single member. pw_protobuf does not support + # oneof in general, but has special handling for proto3 optional + # fields. This method exists to distinguish a real, user-defined + # oneof from a compiler-generated one. + # https://cs.opensource.google/protobuf/protobuf/+/main:src/google/protobuf/descriptor.proto;l=305;drc=5a68dddcf9564f92815296099f07f7dfe8713908 + return len(self.fields) == 1 and self.fields[0].has_presence() + def __init__(self, name: str): super().__init__(name) self._fields: list[ProtoMessageField] = [] + self._oneofs: list[ProtoMessage.OneOf] = [] self._dependencies: list[ProtoMessage] | None = None self._dependency_cycles: list[ProtoMessage] = [] @@ -402,9 +421,24 @@ def type(self) -> ProtoNode.Type: def fields(self) -> list[ProtoMessageField]: return list(self._fields) - def add_field(self, field: ProtoMessageField) -> None: + def oneofs(self) -> list[ProtoMessage.OneOf]: + return list(self._oneofs) + + def add_field( + self, + field: ProtoMessageField, + oneof_index: int | None = None, + ) -> None: self._fields.append(field) + if oneof_index is not None: + self._oneofs[oneof_index].fields.append(field) + # pylint: disable=protected-access + field._oneof = self._oneofs[oneof_index] + + def add_oneof(self, name) -> None: + self._oneofs.append(ProtoMessage.OneOf(name)) + def _supports_child(self, child: ProtoNode) -> bool: return ( child.type() == self.Type.ENUM or child.type() == self.Type.MESSAGE @@ -496,6 +530,7 @@ def __init__( self._has_presence: bool = has_presence self._repeated: bool = repeated self._options: CodegenOptions | None = codegen_options + self._oneof: ProtoMessage.OneOf | None = None def name(self) -> str: return self.upper_camel_case(self._field_name) @@ -529,6 +564,11 @@ def is_repeated(self) -> bool: def options(self) -> CodegenOptions | None: return self._options + def oneof(self) -> ProtoMessage.OneOf | None: + if self._oneof is not None and not self._oneof.is_synthetic(): + return self._oneof + return None + @staticmethod def upper_camel_case(field_name: str) -> str: """Converts a field name to UpperCamelCase.""" @@ -714,6 +754,10 @@ def _add_message_fields( elif codegen_options: merged_options = codegen_options + oneof_index = ( + field.oneof_index if field.HasField('oneof_index') else None + ) + message.add_field( ProtoMessageField( field.name, @@ -723,7 +767,8 @@ def _add_message_fields( has_presence, repeated, merged_options, - ) + ), + oneof_index=oneof_index, ) @@ -809,6 +854,8 @@ def _build_hierarchy( def build_message_subtree(proto_message): node = ProtoMessage(proto_message.name) + for oneof in proto_message.oneof_decl: + node.add_oneof(oneof.name) for proto_enum in proto_message.enum_type: node.add_child(ProtoEnum(proto_enum.name)) for submessage in proto_message.nested_type: diff --git a/pw_protobuf/size_report/oneof_codegen_comparison.cc b/pw_protobuf/size_report/oneof_codegen_comparison.cc index 263677292b..2640621518 100644 --- a/pw_protobuf/size_report/oneof_codegen_comparison.cc +++ b/pw_protobuf/size_report/oneof_codegen_comparison.cc @@ -331,14 +331,18 @@ PW_NO_INLINE void BasicEncode() { } which_key = KeyType::KEY_STRING; volatile bool has_timestamp = true; volatile bool has_has_value = false; - if (which_key == KeyType::KEY_STRING) { - message.key_string.SetEncoder( - [](ResponseInfo::StreamEncoder& key_string_encoder) -> pw::Status { - return key_string_encoder.WriteKeyString("test"); - }); - } else if (which_key == KeyType::KEY_TOKEN) { - message.key_token = 99999; - } + + message.key.SetEncoder( + [&which_key](ResponseInfo::StreamEncoder& key_encoder) { + if (which_key == KeyType::KEY_STRING) { + return key_encoder.WriteKeyString("test"); + } + if (which_key == KeyType::KEY_TOKEN) { + return key_encoder.WriteKeyToken(99999); + } + return Status::InvalidArgument(); + }); + message.timestamp = has_timestamp ? std::optional(1663003467) : std::nullopt; message.has_value = has_has_value ? std::optional(false) : std::nullopt; @@ -361,6 +365,27 @@ PW_NO_INLINE void BasicDecode() { } which_key = KeyType::NONE; volatile bool has_timestamp = false; volatile bool has_has_value = false; + + message.key.SetDecoder( + [](ResponseInfo::Fields field, ResponseInfo::StreamDecoder& key_decoder) { + switch (field) { + case ResponseInfo::Fields::kKeyString: { + std::array key_string; + StatusWithSize sws = key_decoder.ReadKeyString(key_string); + ConsumeValue(sws); + break; + } + case ResponseInfo::Fields::kKeyToken: { + Result key_token = key_decoder.ReadKeyToken(); + ConsumeValue(key_token); + break; + } + default: + return Status::DataLoss(); + } + return OkStatus(); + }); + if (pw::Status status = decoder.Read(message); status.ok()) { ConsumeValue(status); has_timestamp = message.timestamp.has_value(); diff --git a/pw_protobuf/size_report/proto_bloat.cc b/pw_protobuf/size_report/proto_bloat.cc index b46485692e..b4876cfc80 100644 --- a/pw_protobuf/size_report/proto_bloat.cc +++ b/pw_protobuf/size_report/proto_bloat.cc @@ -22,6 +22,7 @@ #include "pw_preprocessor/concat.h" #include "pw_protobuf/decoder.h" #include "pw_protobuf/encoder.h" +#include "pw_protobuf/internal/codegen.h" #include "pw_protobuf/stream_decoder.h" #include "pw_status/status.h" #include "pw_stream/null_stream.h" @@ -75,7 +76,7 @@ constexpr protobuf::internal::MessageField kFakeTable[] = { true, true, true, - true, + protobuf::internal::CallbackType::kSingleField, 260, 840245, nullptr}, @@ -87,7 +88,7 @@ constexpr protobuf::internal::MessageField kFakeTable[] = { true, true, true, - true, + protobuf::internal::CallbackType::kSingleField, 260, 840245, nullptr}}; diff --git a/pw_protobuf/stream_decoder.cc b/pw_protobuf/stream_decoder.cc index f763410bd2..184ce7ab1d 100644 --- a/pw_protobuf/stream_decoder.cc +++ b/pw_protobuf/stream_decoder.cc @@ -543,13 +543,21 @@ Status StreamDecoder::Read(span message, // If the field is using callbacks, interpret the output field accordingly // and allow the caller to provide custom handling. - if (field->use_callback()) { + if (field->callback_type() == internal::CallbackType::kSingleField) { const Callback* callback = reinterpret_cast*>( out.data()); PW_TRY(callback->Decode(*this)); continue; } + if (field->callback_type() == internal::CallbackType::kOneOfGroup) { + const OneOf* callback = + reinterpret_cast*>( + out.data()); + PW_TRY(callback->Decode( + static_cast(current_field_.field_number()), *this)); + continue; + } // Switch on the expected wire type of the field, not the actual, to ensure // the remote encoder doesn't influence our decoding unexpectedly.