diff --git a/conformance/src/main.rs b/conformance/src/main.rs index 1aee7c35f..db404c322 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -32,8 +32,9 @@ fn main() -> io::Result<()> { Err(error) => conformance_response::Result::ParseError(format!("{:?}", error)), }; - let mut response = ConformanceResponse::default(); - response.result = Some(result); + let response = ConformanceResponse { + result: Some(result), + }; let len = response.encoded_len(); bytes.clear(); diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 8a1f61770..af50c28c4 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -16,9 +16,9 @@ use prost_types::{ use crate::ast::{Comments, Method, Service}; use crate::extern_paths::ExternPaths; -use crate::ident::{match_ident, to_snake, to_upper_camel}; +use crate::ident::{to_snake, to_upper_camel}; use crate::message_graph::MessageGraph; -use crate::Config; +use crate::{BytesType, Config, MapType}; #[derive(PartialEq)] enum Syntax { @@ -26,21 +26,6 @@ enum Syntax { Proto3, } -#[derive(PartialEq)] -enum BytesTy { - Vec, - Bytes, -} - -impl BytesTy { - fn as_str(&self) -> &'static str { - match self { - BytesTy::Vec => "\"vec\"", - BytesTy::Bytes => "\"bytes\"", - } - } -} - pub struct CodeGenerator<'a> { config: &'a mut Config, package: String, @@ -267,24 +252,25 @@ impl<'a> CodeGenerator<'a> { fn append_type_attributes(&mut self, fq_message_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); // TODO: this clone is dirty, but expedious. - for (matcher, attribute) in self.config.type_attributes.clone() { - if match_ident(&matcher, fq_message_name, None) { - self.push_indent(); - self.buf.push_str(&attribute); - self.buf.push('\n'); - } + if let Some(attributes) = self.config.type_attributes.get(fq_message_name).cloned() { + self.push_indent(); + self.buf.push_str(&attributes); + self.buf.push('\n'); } } fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); // TODO: this clone is dirty, but expedious. - for (matcher, attribute) in self.config.field_attributes.clone() { - if match_ident(&matcher, fq_message_name, Some(field_name)) { - self.push_indent(); - self.buf.push_str(&attribute); - self.buf.push('\n'); - } + if let Some(attributes) = self + .config + .field_attributes + .get_field(fq_message_name, field_name) + .cloned() + { + self.push_indent(); + self.buf.push_str(&attributes); + self.buf.push('\n'); } } @@ -321,9 +307,14 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(&type_tag); if type_ == Type::Bytes { - self.buf.push('='); + let bytes_type = self + .config + .bytes_type + .get_field(fq_message_name, field.name()) + .copied() + .unwrap_or_default(); self.buf - .push_str(self.bytes_backing_type(&field, fq_message_name).as_str()); + .push_str(&format!("={:?}", bytes_type.annotation())); } match field.label() { @@ -431,23 +422,18 @@ impl<'a> CodeGenerator<'a> { self.append_doc(fq_message_name, Some(field.name())); self.push_indent(); - let btree_map = self + let map_type = self .config - .btree_map - .iter() - .any(|matcher| match_ident(matcher, fq_message_name, Some(field.name()))); - let (annotation_ty, lib_name, rust_ty) = if btree_map { - ("btree_map", "::prost::alloc::collections", "BTreeMap") - } else { - ("map", "::std::collections", "HashMap") - }; - + .map_type + .get_field(fq_message_name, field.name()) + .copied() + .unwrap_or_default(); let key_tag = self.field_type_tag(key); let value_tag = self.map_value_type_tag(value); self.buf.push_str(&format!( "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", - annotation_ty, + map_type.annotation(), key_tag, value_tag, field.number() @@ -455,10 +441,9 @@ impl<'a> CodeGenerator<'a> { self.append_field_attributes(fq_message_name, field.name()); self.push_indent(); self.buf.push_str(&format!( - "pub {}: {}::{}<{}, {}>,\n", + "pub {}: {}<{}, {}>,\n", to_snake(field.name()), - lib_name, - rust_ty, + map_type.rust_type(), key_ty, value_ty )); @@ -580,12 +565,15 @@ impl<'a> CodeGenerator<'a> { } fn append_doc(&mut self, fq_name: &str, field_name: Option<&str>) { - if !self - .config - .disable_comments - .iter() - .any(|matcher| match_ident(matcher, fq_name, field_name)) - { + let append_doc = if let Some(field_name) = field_name { + self.config + .disable_comments + .get_field(fq_name, field_name) + .is_none() + } else { + self.config.disable_comments.get(fq_name).is_none() + }; + if append_doc { Comments::from_location(self.location()).append_with_indent(self.depth, &mut self.buf) } } @@ -759,10 +747,14 @@ impl<'a> CodeGenerator<'a> { Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), Type::Bool => String::from("bool"), Type::String => String::from("::prost::alloc::string::String"), - Type::Bytes => match self.bytes_backing_type(field, fq_message_name) { - BytesTy::Bytes => String::from("::prost::bytes::Bytes"), - BytesTy::Vec => String::from("::prost::alloc::vec::Vec"), - }, + Type::Bytes => self + .config + .bytes_type + .get_field(fq_message_name, field.name()) + .copied() + .unwrap_or_default() + .rust_type() + .to_owned(), Type::Group | Type::Message => self.resolve_ident(field.type_name()), } } @@ -845,19 +837,6 @@ impl<'a> CodeGenerator<'a> { } } - fn bytes_backing_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> BytesTy { - let bytes = self - .config - .bytes - .iter() - .any(|matcher| match_ident(matcher, fq_message_name, Some(field.name()))); - if bytes { - BytesTy::Bytes - } else { - BytesTy::Vec - } - } - /// Returns `true` if the field options includes the `deprecated` option. fn deprecated(&self, field: &FieldDescriptorProto) -> bool { field @@ -1019,6 +998,42 @@ fn strip_enum_prefix<'a>(prefix: &str, name: &'a str) -> &'a str { } } +impl MapType { + /// The `prost-derive` annotation type corresponding to the map type. + fn annotation(&self) -> &'static str { + match self { + MapType::HashMap => "map", + MapType::BTreeMap => "btree_map", + } + } + + /// The fully-qualified Rust type corresponding to the map type. + fn rust_type(&self) -> &'static str { + match self { + MapType::HashMap => "::std::collections::HashMap", + MapType::BTreeMap => "::prost::alloc::collections::BTreeMap", + } + } +} + +impl BytesType { + /// The `prost-derive` annotation type corresponding to the bytes type. + fn annotation(&self) -> &'static str { + match self { + BytesType::Vec => "vec", + BytesType::Bytes => "bytes", + } + } + + /// The fully-qualified Rust type corresponding to the bytes type. + fn rust_type(&self) -> &'static str { + match self { + BytesType::Vec => "::prost::alloc::vec::Vec", + BytesType::Bytes => "::prost::bytes::Bytes", + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/prost-build/src/ident.rs b/prost-build/src/ident.rs index 0b6890847..1dd6495ad 100644 --- a/prost-build/src/ident.rs +++ b/prost-build/src/ident.rs @@ -40,40 +40,6 @@ pub fn to_upper_camel(s: &str) -> String { ident } -/// Matches a 'matcher' against a fully qualified identifier. -pub fn match_ident(matcher: &str, fq_name: &str, field_name: Option<&str>) -> bool { - assert_eq!(b'.', fq_name.as_bytes()[0]); - - if matcher.is_empty() { - return false; - } else if matcher == "." { - return true; - } - - let match_paths = matcher.split('.').collect::>(); - let field_paths = { - let mut paths = fq_name.split('.').collect::>(); - if let Some(field_name) = field_name { - paths.push(field_name); - } - paths - }; - - if &matcher[..1] == "." { - // Prefix match. - if match_paths.len() > field_paths.len() { - false - } else { - match_paths[..] == field_paths[..match_paths.len()] - } - // Suffix match. - } else if match_paths.len() > field_paths.len() { - false - } else { - match_paths[..] == field_paths[field_paths.len() - match_paths.len()..] - } -} - #[cfg(test)] mod tests { @@ -193,44 +159,4 @@ mod tests { assert_eq!("FuzzBuster", &to_upper_camel("FuzzBuster")); assert_eq!("Self_", &to_upper_camel("self")); } - - #[test] - fn test_match_ident() { - // Prefix matches - assert!(match_ident(".", ".foo.bar.Baz", Some("buzz"))); - assert!(match_ident(".foo", ".foo.bar.Baz", Some("buzz"))); - assert!(match_ident(".foo.bar", ".foo.bar.Baz", Some("buzz"))); - assert!(match_ident(".foo.bar.Baz", ".foo.bar.Baz", Some("buzz"))); - assert!(match_ident( - ".foo.bar.Baz.buzz", - ".foo.bar.Baz", - Some("buzz") - )); - - assert!(!match_ident(".fo", ".foo.bar.Baz", Some("buzz"))); - assert!(!match_ident(".foo.", ".foo.bar.Baz", Some("buzz"))); - assert!(!match_ident(".buzz", ".foo.bar.Baz", Some("buzz"))); - assert!(!match_ident(".Baz.buzz", ".foo.bar.Baz", Some("buzz"))); - - // Suffix matches - assert!(match_ident("buzz", ".foo.bar.Baz", Some("buzz"))); - assert!(match_ident("Baz.buzz", ".foo.bar.Baz", Some("buzz"))); - assert!(match_ident("bar.Baz.buzz", ".foo.bar.Baz", Some("buzz"))); - assert!(match_ident( - "foo.bar.Baz.buzz", - ".foo.bar.Baz", - Some("buzz") - )); - - assert!(!match_ident("buz", ".foo.bar.Baz", Some("buzz"))); - assert!(!match_ident("uz", ".foo.bar.Baz", Some("buzz"))); - - // Type names - assert!(match_ident("Baz", ".foo.bar.Baz", None)); - assert!(match_ident(".", ".foo.bar.Baz", None)); - assert!(match_ident(".foo.bar", ".foo.bar.Baz", None)); - assert!(match_ident(".foo.bar.Baz", ".foo.bar.Baz", None)); - assert!(!match_ident(".fo", ".foo.bar.Baz", None)); - assert!(!match_ident(".buzz.Baz", ".foo.bar.Baz", None)); - } } diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 6f8d60914..6353b7d0d 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -114,6 +114,7 @@ mod code_generator; mod extern_paths; mod ident; mod message_graph; +mod path; use std::collections::HashMap; use std::default; @@ -134,6 +135,7 @@ use crate::code_generator::CodeGenerator; use crate::extern_paths::ExternPaths; use crate::ident::to_snake; use crate::message_graph::MessageGraph; +use crate::path::PathMap; type Module = Vec; @@ -182,22 +184,54 @@ pub trait ServiceGenerator { fn finalize_package(&mut self, _package: &str, _buf: &mut String) {} } +/// The map collection type to output for Protobuf `map` fields. +#[non_exhaustive] +#[derive(Clone, Copy, Debug, PartialEq)] +enum MapType { + /// The [`std::collections::HashMap`] type. + HashMap, + /// The [`std::collections::BTreeMap`] type. + BTreeMap, +} + +impl Default for MapType { + fn default() -> MapType { + MapType::HashMap + } +} + +/// The bytes collection type to output for Protobuf `bytes` fields. +#[non_exhaustive] +#[derive(Clone, Copy, Debug, PartialEq)] +enum BytesType { + /// The [`alloc::collections::Vec::`] type. + Vec, + /// The [`bytes::Bytes`] type. + Bytes, +} + +impl Default for BytesType { + fn default() -> BytesType { + BytesType::Vec + } +} + /// Configuration options for Protobuf code generation. /// /// This configuration builder can be used to set non-default code generation options. pub struct Config { file_descriptor_set_path: Option, service_generator: Option>, - btree_map: Vec, - bytes: Vec, - type_attributes: Vec<(String, String)>, - field_attributes: Vec<(String, String)>, + map_type: PathMap, + bytes_type: PathMap, + type_attributes: PathMap, + field_attributes: PathMap, prost_types: bool, strip_enum_prefix: bool, out_dir: Option, extern_paths: Vec<(String, String)>, protoc_args: Vec, - disable_comments: Vec, + disable_comments: PathMap<()>, } impl Config { @@ -259,7 +293,11 @@ impl Config { I: IntoIterator, S: AsRef, { - self.btree_map = paths.into_iter().map(|s| s.as_ref().to_string()).collect(); + self.map_type.clear(); + for matcher in paths { + self.map_type + .insert(matcher.as_ref().to_string(), MapType::BTreeMap); + } self } @@ -316,7 +354,11 @@ impl Config { I: IntoIterator, S: AsRef, { - self.bytes = paths.into_iter().map(|s| s.as_ref().to_string()).collect(); + self.bytes_type.clear(); + for matcher in paths { + self.bytes_type + .insert(matcher.as_ref().to_string(), BytesType::Bytes); + } self } @@ -349,7 +391,7 @@ impl Config { A: AsRef, { self.field_attributes - .push((path.as_ref().to_string(), attribute.as_ref().to_string())); + .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); self } @@ -398,7 +440,7 @@ impl Config { A: AsRef, { self.type_attributes - .push((path.as_ref().to_string(), attribute.as_ref().to_string())); + .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); self } @@ -440,7 +482,11 @@ impl Config { I: IntoIterator, S: AsRef, { - self.disable_comments = paths.into_iter().map(|s| s.as_ref().to_string()).collect(); + self.disable_comments.clear(); + for matcher in paths { + self.disable_comments + .insert(matcher.as_ref().to_string(), ()); + } self } @@ -556,10 +602,11 @@ impl Config { self } - /// When set, the `FileDescriptorSet` generated by `protoc` is written to the provided path. + /// When set, the `FileDescriptorSet` generated by `protoc` is written to the provided + /// filesystem path. /// - /// This option can be used in conjunction with the `include_bytes!` macro and the types in the - /// `prost-types` crate for implementing reflection capabilities, among other things. + /// This option can be used in conjunction with the [`include_bytes!`] macro and the types in + /// the `prost-types` crate for implementing reflection capabilities, among other things. /// /// ## Example /// @@ -790,16 +837,16 @@ impl default::Default for Config { Config { file_descriptor_set_path: None, service_generator: None, - btree_map: Vec::new(), - bytes: Vec::new(), - type_attributes: Vec::new(), - field_attributes: Vec::new(), + map_type: PathMap::default(), + bytes_type: PathMap::default(), + type_attributes: PathMap::default(), + field_attributes: PathMap::default(), prost_types: true, strip_enum_prefix: true, out_dir: None, extern_paths: Vec::new(), protoc_args: Vec::new(), - disable_comments: Vec::new(), + disable_comments: PathMap::default(), } } } @@ -807,13 +854,21 @@ impl default::Default for Config { impl fmt::Debug for Config { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Config") - .field("btree_map", &self.btree_map) + .field("file_descriptor_set_path", &self.file_descriptor_set_path) + .field( + "service_generator", + &self.file_descriptor_set_path.is_some(), + ) + .field("map_type", &self.map_type) + .field("bytes_type", &self.bytes_type) .field("type_attributes", &self.type_attributes) .field("field_attributes", &self.field_attributes) .field("prost_types", &self.prost_types) .field("strip_enum_prefix", &self.strip_enum_prefix) .field("out_dir", &self.out_dir) .field("extern_paths", &self.extern_paths) + .field("protoc_args", &self.protoc_args) + .field("disable_comments", &self.disable_comments) .finish() } } diff --git a/prost-build/src/path.rs b/prost-build/src/path.rs new file mode 100644 index 000000000..619d0abf0 --- /dev/null +++ b/prost-build/src/path.rs @@ -0,0 +1,165 @@ +//! Utilities for working with Protobuf paths. + +use std::collections::HashMap; +use std::iter; + +/// Maps a fully-qualified Protobuf path to a value using path matchers. +#[derive(Debug, Default)] +pub(crate) struct PathMap { + matchers: HashMap, +} + +impl PathMap { + /// Inserts a new matcher and associated value to the path map. + pub(crate) fn insert(&mut self, matcher: String, value: T) { + self.matchers.insert(matcher, value); + } + + /// Returns the value which matches the provided fully-qualified Protobuf path. + pub(crate) fn get(&self, fq_path: &'_ str) -> Option<&T> { + // First, try matching the full path. + iter::once(fq_path) + // Then, try matching path suffixes. + .chain(suffixes(fq_path)) + // Then, try matching path prefixes. + .chain(prefixes(fq_path)) + // Then, match the global path. This matcher must never fail, since the constructor + // initializes it. + .chain(iter::once(".")) + .flat_map(|path| self.matchers.get(path)) + .next() + } + + /// Returns the value which matches the provided fully-qualified Protobuf path and field name. + pub(crate) fn get_field(&self, fq_path: &'_ str, field: &'_ str) -> Option<&T> { + let full_path = format!("{}.{}", fq_path, field); + let full_path = full_path.as_str(); + + // First, try matching the path. + let value = iter::once(full_path) + // Then, try matching path suffixes. + .chain(suffixes(full_path)) + // Then, try matching path suffixes without the field name. + .chain(suffixes(fq_path)) + // Then, try matching path prefixes. + .chain(prefixes(full_path)) + // Then, match the global path. This matcher must never fail, since the constructor + // initializes it. + .chain(iter::once(".")) + .flat_map(|path| self.matchers.get(path)) + .next(); + + value + } + + /// Removes all matchers from the path map. + pub(crate) fn clear(&mut self) { + self.matchers.clear(); + } +} + +/// Given a fully-qualified path, returns a sequence of fully-qualified paths which match a prefix +/// of the input path, in decreasing path-length order. +/// +/// Example: prefixes(".a.b.c.d") -> [".a.b.c", ".a.b", ".a"] +fn prefixes(fq_path: &str) -> impl Iterator { + std::iter::successors(Some(fq_path), |path| { + path.rsplitn(2, '.').nth(1).filter(|path| !path.is_empty()) + }) + .skip(1) +} + +/// Given a fully-qualified path, returns a sequence of paths which match the suffix of the input +/// path, in decreasing path-length order. +/// +/// Example: suffixes(".a.b.c.d") -> ["a.b.c.d", "b.c.d", "c.d", "d"] +fn suffixes(fq_path: &str) -> impl Iterator { + std::iter::successors(Some(fq_path), |path| { + path.splitn(2, '.').nth(1).filter(|path| !path.is_empty()) + }) + .skip(1) +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_prefixes() { + assert_eq!( + prefixes(".a.b.c.d").collect::>(), + vec![".a.b.c", ".a.b", ".a"], + ); + assert_eq!(prefixes(".a").count(), 0); + assert_eq!(prefixes(".").count(), 0); + } + + #[test] + fn test_suffixes() { + assert_eq!( + suffixes(".a.b.c.d").collect::>(), + vec!["a.b.c.d", "b.c.d", "c.d", "d"], + ); + assert_eq!(suffixes(".a").collect::>(), vec!["a"]); + assert_eq!(suffixes(".").collect::>(), Vec::<&str>::new()); + } + + #[test] + fn test_path_map_get() { + let mut path_map = PathMap::default(); + path_map.insert(".a.b.c.d".to_owned(), 1); + path_map.insert(".a.b".to_owned(), 2); + path_map.insert("M1".to_owned(), 3); + path_map.insert("M1.M2".to_owned(), 4); + path_map.insert("M1.M2.f1".to_owned(), 5); + path_map.insert("M1.M2.f2".to_owned(), 6); + + assert_eq!(None, path_map.get(".a.other")); + assert_eq!(None, path_map.get(".a.bother")); + assert_eq!(None, path_map.get(".other")); + assert_eq!(None, path_map.get(".M1.other")); + assert_eq!(None, path_map.get(".M1.M2.other")); + + assert_eq!(Some(&1), path_map.get(".a.b.c.d")); + assert_eq!(Some(&1), path_map.get(".a.b.c.d.other")); + + assert_eq!(Some(&2), path_map.get(".a.b")); + assert_eq!(Some(&2), path_map.get(".a.b.c")); + assert_eq!(Some(&2), path_map.get(".a.b.other")); + assert_eq!(Some(&2), path_map.get(".a.b.other.Other")); + assert_eq!(Some(&2), path_map.get(".a.b.c.dother")); + + assert_eq!(Some(&3), path_map.get(".M1")); + assert_eq!(Some(&3), path_map.get(".a.b.c.d.M1")); + assert_eq!(Some(&3), path_map.get(".a.b.M1")); + + assert_eq!(Some(&4), path_map.get(".M1.M2")); + assert_eq!(Some(&4), path_map.get(".a.b.c.d.M1.M2")); + assert_eq!(Some(&4), path_map.get(".a.b.M1.M2")); + + assert_eq!(Some(&5), path_map.get(".M1.M2.f1")); + assert_eq!(Some(&5), path_map.get(".a.M1.M2.f1")); + assert_eq!(Some(&5), path_map.get(".a.b.M1.M2.f1")); + + assert_eq!(Some(&6), path_map.get(".M1.M2.f2")); + assert_eq!(Some(&6), path_map.get(".a.M1.M2.f2")); + assert_eq!(Some(&6), path_map.get(".a.b.M1.M2.f2")); + + // get_field + + assert_eq!(Some(&2), path_map.get_field(".a.b.Other", "other")); + + assert_eq!(Some(&4), path_map.get_field(".M1.M2", "other")); + assert_eq!(Some(&4), path_map.get_field(".a.M1.M2", "other")); + assert_eq!(Some(&4), path_map.get_field(".a.b.M1.M2", "other")); + + assert_eq!(Some(&5), path_map.get_field(".M1.M2", "f1")); + assert_eq!(Some(&5), path_map.get_field(".a.M1.M2", "f1")); + assert_eq!(Some(&5), path_map.get_field(".a.b.M1.M2", "f1")); + + assert_eq!(Some(&6), path_map.get_field(".M1.M2", "f2")); + assert_eq!(Some(&6), path_map.get_field(".a.M1.M2", "f2")); + assert_eq!(Some(&6), path_map.get_field(".a.b.M1.M2", "f2")); + } +} diff --git a/src/encoding.rs b/src/encoding.rs index d54efc2cc..c599f662d 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -314,7 +314,7 @@ pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut B) where B: BufMut, { - debug_assert!(tag >= MIN_TAG && tag <= MAX_TAG); + debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag)); let key = (tag << 3) | wire_type as u32; encode_varint(u64::from(key), buf); }