Skip to content

Commit

Permalink
Revive protoc plugin mode as protoc-gen-prost
Browse files Browse the repository at this point in the history
  • Loading branch information
neoeinstein committed Feb 24, 2022
1 parent ae84e7d commit 5cfb922
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 19 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ members = [
"prost-derive",
"prost-types",
"protobuf",
"protoc-gen-prost",
"tests",
"tests-2015",
"tests-no-std",
Expand Down
290 changes: 271 additions & 19 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ use std::process::Command;

use log::trace;
use prost::Message;
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use prost_types::{compiler, FileDescriptorProto, FileDescriptorSet};

pub use crate::ast::{Comments, Method, Service};
use crate::code_generator::CodeGenerator;
Expand Down Expand Up @@ -182,6 +182,25 @@ pub trait ServiceGenerator {
///
/// The default implementation is empty and does nothing.
fn finalize_package(&mut self, _package: &str, _buf: &mut String) {}

/// Handles parameters passed into a plugin for the generator as a customization point
///
/// This function will be called once for every parameter that is passed into the plugin. The
/// service generator is not expected to handle any of the paramters. Parameters may be
/// intended to modify `Config`, but are provided here in case they are of use.
///
/// The generator should only return an error when parsing the value of a parameter name it
/// expected to handle is not parsable. If it has handled a value, it should return `Ok(true)`
/// and `Ok(false)` otherwise.
///
/// The default implementation ignores all parameters and returns false.
fn handle_plugin_parameter(
&mut self,
_name: &str,
_value: &str,
) -> std::result::Result<bool, String> {
Ok(false)
}
}

/// The map collection type to output for Protobuf `map` fields.
Expand Down Expand Up @@ -847,16 +866,35 @@ impl Config {
)
})?;

let modules = self.generate(file_descriptor_set.file)?;
for (module, content) in &modules {
let mut filename = if module.is_empty() {
self.default_package_filename.clone()
let request = compiler::CodeGeneratorRequest {
file_to_generate: file_descriptor_set
.file
.iter()
.filter_map(|file| file.name.as_ref())
.map(|name| name.to_owned())
.collect(),
parameter: None,
proto_file: file_descriptor_set.file,
compiler_version: None,
};


let response = self
.codegen(request)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;

for file in response
.file
.iter()
.filter(|file| file.content.is_some())
{
let filename = if let Some(name) = file.name.as_deref().filter(|n| n.is_empty()) {
name.to_string()
} else {
module.join(".")
self.default_package_filename.clone() + ".rs"
};

filename.push_str(".rs");

let content = file.content.as_deref().unwrap_or_default();
let output_path = target.join(&filename);

let previous_content = fs::read(&output_path);
Expand All @@ -875,8 +913,9 @@ impl Config {
if let Some(ref include_file) = self.include_file {
trace!("Writing include file: {:?}", target.join(include_file));
let mut file = fs::File::create(target.join(include_file))?;
let entries: Vec<Module> = response.file.iter().filter_map(|f| f.name.as_deref()).map(Self::split_module).collect();
self.write_includes(
modules.keys().collect(),
entries.iter().map(|f| f).collect(),
&mut file,
0,
if target_is_env { None } else { Some(&target) },
Expand All @@ -887,6 +926,14 @@ impl Config {
Ok(())
}

fn split_module(s: &str) -> Vec<String> {
s
.trim_end_matches(".rs")
.split(".")
.map(|s| s.to_string())
.collect::<Vec<String>>()
}

fn write_includes(
&self,
mut entries: Vec<&Module>,
Expand Down Expand Up @@ -949,16 +996,44 @@ impl Config {
outfile.write_all(format!("{}{}\n", (" ").to_owned().repeat(depth), line).as_bytes())
}

fn generate(&mut self, files: Vec<FileDescriptorProto>) -> Result<HashMap<Module, String>> {
let mut modules = HashMap::new();
let mut packages = HashMap::new();
/// Handle the request protoc provides to plugins and generate an appropriate response.
///
/// The protoc plugin mechanism provides a way for an arbitrary executable to specify the code
/// generation for a set of protofiles. The plugin reads `CodeGeneratorRequest` and reports a
/// `CodeGeneratorResponse`. Errors for plugins are reported via the error field within the
/// `CodeGeneratorResponse`. Errors that are the fault of protoc are reported using a non-zero
/// exit code in the plugin. An example of using this can be seen in the `protoc-gen-prost`
/// binary.
pub fn run_plugin(&mut self, request: compiler::CodeGeneratorRequest) -> compiler::CodeGeneratorResponse {
self.codegen(request)
.unwrap_or_else(|message| compiler::CodeGeneratorResponse {
error: Some(message),
..compiler::CodeGeneratorResponse::default()
})
}

let message_graph = MessageGraph::new(&files)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;

for file in files {
fn codegen(
&mut self,
request: compiler::CodeGeneratorRequest,
) -> std::result::Result<compiler::CodeGeneratorResponse, String> {
self.handle_plugin_parameters(&request)?;

let mut modules = HashMap::new();
let mut packages = HashMap::new();
let files: HashMap<&String, &FileDescriptorProto> = request
.proto_file
.iter()
.filter_map(|x| x.name.as_ref().map(|name| (name, x)))
.collect();

let message_graph = MessageGraph::new(&request.proto_file)?;
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)?;

for filename in request.file_to_generate {
let file: &FileDescriptorProto = files
.get(&filename)
.ok_or_else(|| "filename to generate not found in protofile field".to_string())?;
let module = self.module(&file);

// Only record packages that have services
Expand All @@ -967,7 +1042,13 @@ impl Config {
}

let buf = modules.entry(module).or_insert_with(String::new);
CodeGenerator::generate(self, &message_graph, &extern_paths, file, buf);
CodeGenerator::generate(
self,
&message_graph,
&extern_paths,
file.to_owned(),
buf,
);
}

if let Some(ref mut service_generator) = self.service_generator {
Expand All @@ -977,7 +1058,67 @@ impl Config {
}
}

Ok(modules)
Ok(compiler::CodeGeneratorResponse {
error: None,
supported_features: Some(compiler::code_generator_response::Feature::Proto3Optional as u64),
file: modules
.iter()
.map(
|(module, content)| compiler::code_generator_response::File {
name: {
let mut filename = module.join(".");
filename.push_str(".rs");
Some(filename)
},
insertion_point: None,
generated_code_info: None,
content: Some(content.to_owned()),
},
)
.collect(),
})
}

fn handle_plugin_parameters(
&mut self,
request: &compiler::CodeGeneratorRequest,
) -> std::result::Result<(), String> {
if request.parameter.is_none() {
return Ok(());
}
for param in request.parameter.as_ref().unwrap().split_whitespace() {
match param.find('=') {
Some(eq) => {
let name = &param[..eq];
let value = &param[eq + 1..];
let used_by_prost = self.handle_single_parameter(name, value)?;
let used_by_service = self.service_generator.as_mut().map_or_else(
|| Ok(false),
|ref mut gen| gen.handle_plugin_parameter(name, value),
)?;

if !used_by_prost && !used_by_service {
return Err(format!("Unrecognized parameter name \"{}\"", name));
}
}
None => {
return Err(format!(
"Invalid parameter \"{}\". Expected param_name=value",
param
));
}
}
}
Ok(())
}

fn handle_single_parameter(
&mut self,
_name: &str,
_value: &str,
) -> std::result::Result<bool, String> {
// TODO: provide a way to adjust Config via options parameters
Ok(false)
}

fn module(&self, file: &FileDescriptorProto) -> Module {
Expand Down Expand Up @@ -1131,12 +1272,19 @@ mod tests {
state: Rc<RefCell<MockState>>,
}

#[derive(Debug, Default, PartialEq)]
struct Parameter {
name: String,
value: String,
}

/// Holds state for `MockServiceGenerator`
#[derive(Default)]
struct MockState {
service_names: Vec<String>,
package_names: Vec<String>,
finalized: u32,
parameters: Vec<Parameter>,
}

impl MockServiceGenerator {
Expand All @@ -1160,6 +1308,19 @@ mod tests {
let mut state = self.state.borrow_mut();
state.package_names.push(package.to_string());
}

fn handle_plugin_parameter(
&mut self,
name: &str,
value: &str,
) -> std::result::Result<bool, String> {
let mut state = self.state.borrow_mut();
state.parameters.push(Parameter {
name: name.to_owned(),
value: value.to_owned(),
});
Ok(true)
}
}

#[test]
Expand Down Expand Up @@ -1188,5 +1349,96 @@ mod tests {
assert_eq!(&state.service_names, &["Greeting", "Farewell"]);
assert_eq!(&state.package_names, &["helloworld"]);
assert_eq!(state.finalized, 3);
assert!(state.parameters.is_empty());
}

#[test]
fn no_parameters() {
let _ = env_logger::try_init();

let state = Rc::new(RefCell::new(MockState::default()));
let gen = MockServiceGenerator::new(Rc::clone(&state));

let response =
Config::new()
.service_generator(Box::new(gen))
.run_plugin(compiler::CodeGeneratorRequest {
file_to_generate: Vec::new(),
parameter: None,
proto_file: Vec::new(),
compiler_version: None,
});

let state = state.borrow();
assert!(&state.service_names.is_empty());
assert!(&state.package_names.is_empty());
assert_eq!(state.finalized, 0);
assert!(state.parameters.is_empty());

assert_eq!(response.error, None);
assert!(response.file.is_empty());
}

#[test]
fn valid_parameter() {
let _ = env_logger::try_init();

let state = Rc::new(RefCell::new(MockState::default()));
let gen = MockServiceGenerator::new(Rc::clone(&state));

let response =
Config::new()
.service_generator(Box::new(gen))
.run_plugin(compiler::CodeGeneratorRequest {
file_to_generate: Vec::new(),
parameter: Some("service_quirk=0".to_string()),
proto_file: Vec::new(),
compiler_version: None,
});

let state = state.borrow();
assert!(&state.service_names.is_empty());
assert!(&state.package_names.is_empty());
assert_eq!(state.finalized, 0);
assert_eq!(
state.parameters,
&[Parameter {
name: "service_quirk".to_string(),
value: "0".to_string(),
}]
);

assert_eq!(response.error, None);
assert!(response.file.is_empty());
}

#[test]
fn invalid_parameter() {
let _ = env_logger::try_init();

let state = Rc::new(RefCell::new(MockState::default()));
let gen = MockServiceGenerator::new(Rc::clone(&state));

let response =
Config::new()
.service_generator(Box::new(gen))
.run_plugin(compiler::CodeGeneratorRequest {
file_to_generate: Vec::new(),
parameter: Some("service_quirk".to_string()),
proto_file: Vec::new(),
compiler_version: None,
});

let state = state.borrow();
assert!(&state.service_names.is_empty());
assert!(&state.package_names.is_empty());
assert_eq!(state.finalized, 0);
assert!(&state.parameters.is_empty());

assert_eq!(
response.error,
Some("Invalid parameter \"service_quirk\". Expected param_name=value".to_string())
);
assert!(response.file.is_empty());
}
}
12 changes: 12 additions & 0 deletions protoc-gen-prost/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "protoc-gen-prost"
version = "0.1.0"
authors = ["Marcus Griep <[email protected]>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
prost-build = { version = "0.9.0", path = "../prost-build", default-features = false }
prost-types = { version = "0.9.0", path = "../prost-types", default-features = false }
prost = { version = "0.9.0", path = "..", default-features = false }
Loading

0 comments on commit 5cfb922

Please sign in to comment.