Skip to content

Commit

Permalink
Expose generation methods for buf (#598)
Browse files Browse the repository at this point in the history
Co-authored-by: Lucio Franco <[email protected]>
  • Loading branch information
neoeinstein and LucioFranco authored Mar 18, 2022
1 parent 73dea0e commit 1f7546c
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 36 deletions.
150 changes: 115 additions & 35 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ use std::ffi::{OsStr, OsString};
use std::fmt;
use std::fs;
use std::io::{Error, ErrorKind, Result, Write};
use std::ops::RangeToInclusive;
use std::path::{Path, PathBuf};
use std::process::Command;

Expand All @@ -137,8 +138,6 @@ use crate::ident::to_snake;
use crate::message_graph::MessageGraph;
use crate::path::PathMap;

type Module = Vec<String>;

/// A service generator takes a service descriptor and generates Rust code.
///
/// `ServiceGenerator` can be used to generate application-specific interfaces
Expand Down Expand Up @@ -847,27 +846,43 @@ impl Config {
)
})?;

let modules = self.generate(file_descriptor_set.file)?;
let requests = file_descriptor_set
.file
.into_iter()
.map(|descriptor| {
(
Module::from_protobuf_package_name(descriptor.package()),
descriptor,
)
})
.collect::<Vec<_>>();

let file_names = requests
.iter()
.map(|req| {
(
req.0.clone(),
req.0.to_file_name_or(&self.default_package_filename),
)
})
.collect::<HashMap<Module, String>>();

let modules = self.generate(requests)?;
for (module, content) in &modules {
let mut filename = if module.is_empty() {
self.default_package_filename.clone()
} else {
module.join(".")
};

filename.push_str(".rs");

let output_path = target.join(&filename);
let file_name = file_names
.get(module)
.expect("every module should have a filename");
let output_path = target.join(file_name);

let previous_content = fs::read(&output_path);

if previous_content
.map(|previous_content| previous_content == content.as_bytes())
.unwrap_or(false)
{
trace!("unchanged: {:?}", filename);
trace!("unchanged: {:?}", file_name);
} else {
trace!("writing: {:?}", filename);
trace!("writing: {:?}", file_name);
fs::write(output_path, content)?;
}
}
Expand Down Expand Up @@ -896,17 +911,17 @@ impl Config {
) -> Result<usize> {
let mut written = 0;
while !entries.is_empty() {
let modident = &entries[0][depth];
let modident = entries[0].part(depth);
let matching: Vec<&Module> = entries
.iter()
.filter(|&v| &v[depth] == modident)
.filter(|&v| v.part(depth) == modident)
.copied()
.collect();
{
// Will NLL sort this mess out?
let _temp = entries
.drain(..)
.filter(|&v| &v[depth] != modident)
.filter(|&v| v.part(depth) != modident)
.collect();
entries = _temp;
}
Expand All @@ -923,7 +938,7 @@ impl Config {
)?;
written += subwritten;
if subwritten != matching.len() {
let modname = matching[0][..=depth].join(".");
let modname = matching[0].to_partial_file_name(..=depth);
if basepath.is_some() {
self.write_line(
outfile,
Expand All @@ -949,25 +964,32 @@ impl Config {
outfile.write_all(format!("{}{}\n", (" ").to_owned().repeat(depth), line).as_bytes())
}

fn generate(&mut self, files: Vec<FileDescriptorProto>) -> Result<HashMap<Module, String>> {
/// Processes a set of modules and file descriptors, returning a map of modules to generated
/// code contents.
///
/// This is generally used when control over the output should not be managed by Prost,
/// such as in a flow for a `protoc` code generating plugin. When compiling as part of a
/// `build.rs` file, instead use [`compile_protos()`].
pub fn generate(
&mut self,
requests: Vec<(Module, FileDescriptorProto)>,
) -> Result<HashMap<Module, String>> {
let mut modules = HashMap::new();
let mut packages = HashMap::new();

let message_graph = MessageGraph::new(&files)
let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1))
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;

for file in files {
let module = self.module(&file);

for request in requests {
// Only record packages that have services
if !file.service.is_empty() {
packages.insert(module.clone(), file.package().to_string());
if !request.1.service.is_empty() {
packages.insert(request.0.clone(), request.1.package().to_string());
}

let buf = modules.entry(module).or_insert_with(String::new);
CodeGenerator::generate(self, &message_graph, &extern_paths, file, buf);
let buf = modules.entry(request.0).or_insert_with(String::new);
CodeGenerator::generate(self, &message_graph, &extern_paths, request.1, buf);
}

if let Some(ref mut service_generator) = self.service_generator {
Expand All @@ -979,14 +1001,6 @@ impl Config {

Ok(modules)
}

fn module(&self, file: &FileDescriptorProto) -> Module {
file.package()
.split('.')
.filter(|s| !s.is_empty())
.map(to_snake)
.collect()
}
}

impl default::Default for Config {
Expand Down Expand Up @@ -1031,6 +1045,72 @@ impl fmt::Debug for Config {
}
}

/// A Rust module path for a Protobuf package.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Module {
components: Vec<String>,
}

impl Module {
/// Construct a module path from an iterator of parts.
pub fn from_parts<I>(parts: I) -> Self
where
I: IntoIterator,
I::Item: Into<String>,
{
Self {
components: parts.into_iter().map(|s| s.into()).collect(),
}
}

/// Construct a module path from a Protobuf package name.
///
/// Constituent parts are automatically converted to snake case in order to follow
/// Rust module naming conventions.
pub fn from_protobuf_package_name(name: &str) -> Self {
Self {
components: name
.split('.')
.filter(|s| !s.is_empty())
.map(to_snake)
.collect(),
}
}

/// An iterator over the parts of the path
pub fn parts(&self) -> impl Iterator<Item = &str> {
self.components.iter().map(|s| s.as_str())
}

/// Format the module path into a filename for generated Rust code.
///
/// If the module path is empty, `default` is used to provide the root of the filename.
pub fn to_file_name_or(&self, default: &str) -> String {
let mut root = if self.components.is_empty() {
default.to_owned()
} else {
self.components.join(".")
};

root.push_str(".rs");

root
}

/// The number of parts in the module's path.
pub fn len(&self) -> usize {
self.components.len()
}

fn to_partial_file_name(&self, range: RangeToInclusive<usize>) -> String {
self.components[range].join(".")
}

fn part(&self, idx: usize) -> &str {
self.components[idx].as_str()
}
}

/// Compile `.proto` files into Rust files during a Cargo build.
///
/// The generated `.rs` files are written to the Cargo `OUT_DIR` directory, suitable for use with
Expand Down
4 changes: 3 additions & 1 deletion prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ pub struct MessageGraph {
}

impl MessageGraph {
pub fn new(files: &[FileDescriptorProto]) -> Result<MessageGraph, String> {
pub fn new<'a>(
files: impl Iterator<Item = &'a FileDescriptorProto>,
) -> Result<MessageGraph, String> {
let mut msg_graph = MessageGraph {
index: HashMap::new(),
graph: Graph::new(),
Expand Down

0 comments on commit 1f7546c

Please sign in to comment.