diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 678ea7b1e..edf91a6a8 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -123,6 +123,7 @@ use std::ffi::{OsStr, OsString}; use std::fmt; use std::fs; use std::io::{Error, ErrorKind, Result, Write}; +use std::ops::RangeToInclusive; use std::path::{Path, PathBuf}; use std::process::Command; @@ -137,8 +138,6 @@ use crate::ident::to_snake; use crate::message_graph::MessageGraph; use crate::path::PathMap; -type Module = Vec; - /// A service generator takes a service descriptor and generates Rust code. /// /// `ServiceGenerator` can be used to generate application-specific interfaces @@ -847,17 +846,33 @@ impl Config { ) })?; - let modules = self.generate(file_descriptor_set.file)?; + let requests = file_descriptor_set + .file + .into_iter() + .map(|descriptor| { + ( + Module::from_protobuf_package_name(descriptor.package()), + descriptor, + ) + }) + .collect::>(); + + let file_names = requests + .iter() + .map(|req| { + ( + req.0.clone(), + req.0.to_file_name_or(&self.default_package_filename), + ) + }) + .collect::>(); + + let modules = self.generate(requests)?; for (module, content) in &modules { - let mut filename = if module.is_empty() { - self.default_package_filename.clone() - } else { - module.join(".") - }; - - filename.push_str(".rs"); - - let output_path = target.join(&filename); + let file_name = file_names + .get(module) + .expect("every module should have a filename"); + let output_path = target.join(file_name); let previous_content = fs::read(&output_path); @@ -865,9 +880,9 @@ impl Config { .map(|previous_content| previous_content == content.as_bytes()) .unwrap_or(false) { - trace!("unchanged: {:?}", filename); + trace!("unchanged: {:?}", file_name); } else { - trace!("writing: {:?}", filename); + trace!("writing: {:?}", file_name); fs::write(output_path, content)?; } } @@ -896,17 +911,17 @@ impl Config { ) -> Result { let mut written = 0; while !entries.is_empty() { - let modident = &entries[0][depth]; + let modident = entries[0].part(depth); let matching: Vec<&Module> = entries .iter() - .filter(|&v| &v[depth] == modident) + .filter(|&v| v.part(depth) == modident) .copied() .collect(); { // Will NLL sort this mess out? let _temp = entries .drain(..) - .filter(|&v| &v[depth] != modident) + .filter(|&v| v.part(depth) != modident) .collect(); entries = _temp; } @@ -923,7 +938,7 @@ impl Config { )?; written += subwritten; if subwritten != matching.len() { - let modname = matching[0][..=depth].join("."); + let modname = matching[0].to_partial_file_name(..=depth); if basepath.is_some() { self.write_line( outfile, @@ -949,25 +964,32 @@ impl Config { outfile.write_all(format!("{}{}\n", (" ").to_owned().repeat(depth), line).as_bytes()) } - fn generate(&mut self, files: Vec) -> Result> { + /// Processes a set of modules and file descriptors, returning a map of modules to generated + /// code contents. + /// + /// This is generally used when control over the output should not be managed by Prost, + /// such as in a flow for a `protoc` code generating plugin. When compiling as part of a + /// `build.rs` file, instead use [`compile_protos()`]. + pub fn generate( + &mut self, + requests: Vec<(Module, FileDescriptorProto)>, + ) -> Result> { let mut modules = HashMap::new(); let mut packages = HashMap::new(); - let message_graph = MessageGraph::new(&files) + let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1)) .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?; let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types) .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?; - for file in files { - let module = self.module(&file); - + for request in requests { // Only record packages that have services - if !file.service.is_empty() { - packages.insert(module.clone(), file.package().to_string()); + if !request.1.service.is_empty() { + packages.insert(request.0.clone(), request.1.package().to_string()); } - let buf = modules.entry(module).or_insert_with(String::new); - CodeGenerator::generate(self, &message_graph, &extern_paths, file, buf); + let buf = modules.entry(request.0).or_insert_with(String::new); + CodeGenerator::generate(self, &message_graph, &extern_paths, request.1, buf); } if let Some(ref mut service_generator) = self.service_generator { @@ -979,14 +1001,6 @@ impl Config { Ok(modules) } - - fn module(&self, file: &FileDescriptorProto) -> Module { - file.package() - .split('.') - .filter(|s| !s.is_empty()) - .map(to_snake) - .collect() - } } impl default::Default for Config { @@ -1031,6 +1045,72 @@ impl fmt::Debug for Config { } } +/// A Rust module path for a Protobuf package. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Module { + components: Vec, +} + +impl Module { + /// Construct a module path from an iterator of parts. + pub fn from_parts(parts: I) -> Self + where + I: IntoIterator, + I::Item: Into, + { + Self { + components: parts.into_iter().map(|s| s.into()).collect(), + } + } + + /// Construct a module path from a Protobuf package name. + /// + /// Constituent parts are automatically converted to snake case in order to follow + /// Rust module naming conventions. + pub fn from_protobuf_package_name(name: &str) -> Self { + Self { + components: name + .split('.') + .filter(|s| !s.is_empty()) + .map(to_snake) + .collect(), + } + } + + /// An iterator over the parts of the path + pub fn parts(&self) -> impl Iterator { + self.components.iter().map(|s| s.as_str()) + } + + /// Format the module path into a filename for generated Rust code. + /// + /// If the module path is empty, `default` is used to provide the root of the filename. + pub fn to_file_name_or(&self, default: &str) -> String { + let mut root = if self.components.is_empty() { + default.to_owned() + } else { + self.components.join(".") + }; + + root.push_str(".rs"); + + root + } + + /// The number of parts in the module's path. + pub fn len(&self) -> usize { + self.components.len() + } + + fn to_partial_file_name(&self, range: RangeToInclusive) -> String { + self.components[range].join(".") + } + + fn part(&self, idx: usize) -> &str { + self.components[idx].as_str() + } +} + /// Compile `.proto` files into Rust files during a Cargo build. /// /// The generated `.rs` files are written to the Cargo `OUT_DIR` directory, suitable for use with diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index fb4a3187e..ac0ad1523 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -15,7 +15,9 @@ pub struct MessageGraph { } impl MessageGraph { - pub fn new(files: &[FileDescriptorProto]) -> Result { + pub fn new<'a>( + files: impl Iterator, + ) -> Result { let mut msg_graph = MessageGraph { index: HashMap::new(), graph: Graph::new(),