From 486cb6044548ca82209b712e3a007c759df64ff8 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Fri, 29 Mar 2024 16:13:58 +0200 Subject: [PATCH] prost-build: consolidate message field data When massaging field data in CodeGenerator::append_message, move it into lists of Field and OneofField structs so that later generation passes can operate on the data with less code duplication. Subsidiary append_* methods are changed to take references to these structs rather than moved data, as generation of lexical tokens does not actually consume any owned data, and we will need more passes over the same field lists for the upcoming builder code. --- prost-build/src/code_generator.rs | 168 ++++++++++++++++++------------ 1 file changed, 101 insertions(+), 67 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index ecd21852a..8330ad8d4 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -44,6 +44,45 @@ fn push_indent(buf: &mut String, depth: u8) { buf.push_str(" "); } } + +struct Field { + rust_name: String, + descriptor: FieldDescriptorProto, + path_index: i32, +} + +impl Field { + fn new(descriptor: FieldDescriptorProto, path_index: i32) -> Self { + Self { + rust_name: to_snake(descriptor.name()), + descriptor, + path_index, + } + } +} + +struct OneofField { + rust_name: String, + descriptor: OneofDescriptorProto, + fields: Vec<(FieldDescriptorProto, i32)>, + path_index: i32, +} + +impl OneofField { + fn new( + descriptor: OneofDescriptorProto, + fields: Vec<(FieldDescriptorProto, i32)>, + path_index: i32, + ) -> Self { + Self { + rust_name: to_snake(descriptor.name()), + descriptor, + fields, + path_index, + } + } +} + impl<'a> CodeGenerator<'a> { pub fn generate( config: &mut Config, @@ -159,21 +198,33 @@ impl<'a> CodeGenerator<'a> { // Split the fields into a vector of the normal fields, and oneof fields. // Path indexes are preserved so that comments can be retrieved. - type Fields = Vec<(FieldDescriptorProto, usize)>; - type OneofFields = MultiMap; - let (fields, mut oneof_fields): (Fields, OneofFields) = message + type OneofFieldsByIndex = MultiMap; + let (fields, mut oneof_map): (Vec, OneofFieldsByIndex) = message .field .into_iter() .enumerate() - .partition_map(|(idx, field)| { - if field.proto3_optional.unwrap_or(false) { - Either::Left((field, idx)) - } else if let Some(oneof_index) = field.oneof_index { - Either::Right((oneof_index, (field, idx))) + .partition_map(|(idx, proto)| { + let idx = idx as i32; + if proto.proto3_optional.unwrap_or(false) { + Either::Left(Field::new(proto, idx)) + } else if let Some(oneof_index) = proto.oneof_index { + Either::Right((oneof_index, (proto, idx))) } else { - Either::Left((field, idx)) + Either::Left(Field::new(proto, idx)) } }); + // Optional fields create a synthetic oneof that we want to skip + let oneof_fields: Vec = message + .oneof_decl + .into_iter() + .enumerate() + .filter_map(move |(idx, proto)| { + let idx = idx as i32; + oneof_map + .remove(&idx) + .map(|fields| OneofField::new(proto, fields, idx)) + }) + .collect(); self.append_doc(&fq_message_name, None); self.append_type_attributes(&fq_message_name); @@ -193,16 +244,15 @@ impl<'a> CodeGenerator<'a> { self.depth += 1; self.path.push(2); - for (field, idx) in fields { - self.path.push(idx as i32); + for field in &fields { + self.path.push(field.path_index); match field + .descriptor .type_name .as_ref() .and_then(|type_name| map_types.get(type_name)) { - Some(&(ref key, ref value)) => { - self.append_map_field(&fq_message_name, field, key, value) - } + Some((key, value)) => self.append_map_field(&fq_message_name, field, key, value), None => self.append_field(&fq_message_name, field), } self.path.pop(); @@ -210,16 +260,9 @@ impl<'a> CodeGenerator<'a> { self.path.pop(); self.path.push(8); - for (idx, oneof) in message.oneof_decl.iter().enumerate() { - let idx = idx as i32; - - let fields = match oneof_fields.get_vec(&idx) { - Some(fields) => fields, - None => continue, - }; - - self.path.push(idx); - self.append_oneof_field(&message_name, &fq_message_name, oneof, fields); + for oneof in &oneof_fields { + self.path.push(oneof.path_index); + self.append_oneof_field(&message_name, &fq_message_name, oneof); self.path.pop(); } self.path.pop(); @@ -246,14 +289,8 @@ impl<'a> CodeGenerator<'a> { } self.path.pop(); - for (idx, oneof) in message.oneof_decl.into_iter().enumerate() { - let idx = idx as i32; - // optional fields create a synthetic oneof that we want to skip - let fields = match oneof_fields.remove(&idx) { - Some(fields) => fields, - None => continue, - }; - self.append_oneof(&fq_message_name, oneof, idx, fields); + for oneof in &oneof_fields { + self.append_oneof(&fq_message_name, oneof); } self.pop_mod(); @@ -362,12 +399,14 @@ impl<'a> CodeGenerator<'a> { } } - fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) { + fn append_field(&mut self, fq_message_name: &str, field: &Field) { + let rust_name = &field.rust_name; + let field = &field.descriptor; let type_ = field.r#type(); let repeated = field.label == Some(Label::Repeated as i32); - let deprecated = self.deprecated(&field); - let optional = self.optional(&field); - let ty = self.resolve_type(&field, fq_message_name); + let deprecated = self.deprecated(field); + let optional = self.optional(field); + let ty = self.resolve_type(field, fq_message_name); let boxed = !repeated && ((type_ == Type::Message || type_ == Type::Group) @@ -396,7 +435,7 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field); + let type_tag = self.field_type_tag(field); self.buf.push_str(&type_tag); if type_ == Type::Bytes { @@ -419,7 +458,7 @@ impl<'a> CodeGenerator<'a> { Label::Required => self.buf.push_str(", required"), Label::Repeated => { self.buf.push_str(", repeated"); - if can_pack(&field) + if can_pack(field) && !field .options .as_ref() @@ -470,7 +509,7 @@ impl<'a> CodeGenerator<'a> { self.append_field_attributes(fq_message_name, field.name()); self.push_indent(); self.buf.push_str("pub "); - self.buf.push_str(&to_snake(field.name())); + self.buf.push_str(rust_name); self.buf.push_str(": "); let prost_path = self.config.prost_path.as_deref().unwrap_or("::prost"); @@ -498,10 +537,12 @@ impl<'a> CodeGenerator<'a> { fn append_map_field( &mut self, fq_message_name: &str, - field: FieldDescriptorProto, + field: &Field, key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { + let rust_name = &field.rust_name; + let field = &field.descriptor; let key_ty = self.resolve_type(key, fq_message_name); let value_ty = self.resolve_type(value, fq_message_name); @@ -535,7 +576,7 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", - to_snake(field.name()), + rust_name, map_type.rust_type(), key_ty, value_ty @@ -546,47 +587,40 @@ impl<'a> CodeGenerator<'a> { &mut self, message_name: &str, fq_message_name: &str, - oneof: &OneofDescriptorProto, - fields: &[(FieldDescriptorProto, usize)], + oneof: &OneofField, ) { - let name = format!( + let type_name = format!( "{}::{}", to_snake(message_name), - to_upper_camel(oneof.name()) + to_upper_camel(oneof.descriptor.name()) ); + let field_tags = oneof + .fields + .iter() + .map(|(field, _)| field.number()) + .join(", "); self.append_doc(fq_message_name, None); self.push_indent(); self.buf.push_str(&format!( "#[prost(oneof=\"{}\", tags=\"{}\")]\n", - name, - fields - .iter() - .map(|&(ref field, _)| field.number()) - .join(", ") + type_name, field_tags, )); - self.append_field_attributes(fq_message_name, oneof.name()); + self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", - to_snake(oneof.name()), - name + oneof.rust_name, type_name )); } - fn append_oneof( - &mut self, - fq_message_name: &str, - oneof: OneofDescriptorProto, - idx: i32, - fields: Vec<(FieldDescriptorProto, usize)>, - ) { + fn append_oneof(&mut self, fq_message_name: &str, oneof: &OneofField) { self.path.push(8); - self.path.push(idx); + self.path.push(oneof.path_index); self.append_doc(fq_message_name, None); self.path.pop(); self.path.pop(); - let oneof_name = format!("{}.{}", fq_message_name, oneof.name()); + let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name()); self.append_type_attributes(&oneof_name); self.append_enum_attributes(&oneof_name); self.push_indent(); @@ -599,20 +633,20 @@ impl<'a> CodeGenerator<'a> { self.append_skip_debug(&fq_message_name); self.push_indent(); self.buf.push_str("pub enum "); - self.buf.push_str(&to_upper_camel(oneof.name())); + self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); self.buf.push_str(" {\n"); self.path.push(2); self.depth += 1; - for (field, idx) in fields { + for (field, idx) in &oneof.fields { let type_ = field.r#type(); - self.path.push(idx as i32); + self.path.push(*idx); self.append_doc(fq_message_name, Some(field.name())); self.path.pop(); self.push_indent(); - let ty_tag = self.field_type_tag(&field); + let ty_tag = self.field_type_tag(field); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, @@ -621,7 +655,7 @@ impl<'a> CodeGenerator<'a> { self.append_field_attributes(&oneof_name, field.name()); self.push_indent(); - let ty = self.resolve_type(&field, fq_message_name); + let ty = self.resolve_type(field, fq_message_name); let boxed = ((type_ == Type::Message || type_ == Type::Group) && self