Skip to content

Commit

Permalink
feat(build): Add optional default unimplemented stubs (#1344)
Browse files Browse the repository at this point in the history
Co-authored-by: Lucio Franco <[email protected]>
  • Loading branch information
xanather and LucioFranco authored Aug 14, 2023
1 parent 35cf1f2 commit aff1daf
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ members = [
"tonic-web/tests/integration",
"tests/service_named_result",
"tests/use_arc_self",
"tests/default_stubs",
]
resolver = "2"

Expand Down
20 changes: 20 additions & 0 deletions tests/default_stubs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[package]
authors = ["Jordan Singh <[email protected]>"]
edition = "2021"
license = "MIT"
name = "default_stubs"
publish = false
version = "0.1.0"

[dependencies]
futures = "0.3"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tokio-stream = {version = "0.1", features = ["net"]}
prost = "0.11"
tonic = {path = "../../tonic"}

[build-dependencies]
tonic-build = {path = "../../tonic-build" }

[package.metadata.cargo-machete]
ignored = ["prost"]
9 changes: 9 additions & 0 deletions tests/default_stubs/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
fn main() {
tonic_build::configure()
.compile(&["proto/test.proto"], &["proto"])
.unwrap();
tonic_build::configure()
.generate_default_stubs(true)
.compile(&["proto/test_default.proto"], &["proto"])
.unwrap();
}
12 changes: 12 additions & 0 deletions tests/default_stubs/proto/test.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
syntax = "proto3";

package test;

import "google/protobuf/empty.proto";

service Test {
rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty);
rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty);
rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty);
rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty);
}
12 changes: 12 additions & 0 deletions tests/default_stubs/proto/test_default.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
syntax = "proto3";

package test_default;

import "google/protobuf/empty.proto";

service TestDefault {
rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty);
rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty);
rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty);
rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty);
}
47 changes: 47 additions & 0 deletions tests/default_stubs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#![allow(unused_imports)]

mod test_defaults;

use futures::{Stream, StreamExt};
use std::pin::Pin;
use tonic::{Request, Response, Status, Streaming};

tonic::include_proto!("test");
tonic::include_proto!("test_default");

#[derive(Debug, Default)]
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
type ServerStreamStream = Pin<Box<dyn Stream<Item = Result<(), Status>> + Send + 'static>>;
type BidirectionalStreamStream =
Pin<Box<dyn Stream<Item = Result<(), Status>> + Send + 'static>>;

async fn unary(&self, _: Request<()>) -> Result<Response<()>, Status> {
Err(Status::permission_denied(""))
}

async fn server_stream(
&self,
_: Request<()>,
) -> Result<Response<Self::ServerStreamStream>, Status> {
Err(Status::permission_denied(""))
}

async fn client_stream(&self, _: Request<Streaming<()>>) -> Result<Response<()>, Status> {
Err(Status::permission_denied(""))
}

async fn bidirectional_stream(
&self,
_: Request<Streaming<()>>,
) -> Result<Response<Self::BidirectionalStreamStream>, Status> {
Err(Status::permission_denied(""))
}
}

#[tonic::async_trait]
impl test_default_server::TestDefault for Svc {
// Default unimplemented stubs provided here.
}
112 changes: 112 additions & 0 deletions tests/default_stubs/src/test_defaults.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#![allow(unused_imports)]

use crate::*;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tonic::transport::Server;

#[cfg(test)]
fn echo_requests_iter() -> impl Stream<Item = ()> {
tokio_stream::iter(1..usize::MAX).map(|_| ())
}

#[tokio::test()]
async fn test_default_stubs() {
use tonic::Code;

let addrs = run_services_in_background().await;

// First validate pre-existing functionality (trait has no default implementation, we explicitly return PermissionDenied in lib.rs).
let mut client = test_client::TestClient::connect(format!("http://{}", addrs.0))
.await
.unwrap();
assert_eq!(
client.unary(()).await.unwrap_err().code(),
Code::PermissionDenied
);
assert_eq!(
client.server_stream(()).await.unwrap_err().code(),
Code::PermissionDenied
);
assert_eq!(
client
.client_stream(echo_requests_iter().take(5))
.await
.unwrap_err()
.code(),
Code::PermissionDenied
);
assert_eq!(
client
.bidirectional_stream(echo_requests_iter().take(5))
.await
.unwrap_err()
.code(),
Code::PermissionDenied
);

// Then validate opt-in new functionality (trait has default implementation of returning Unimplemented).
let mut client_default_stubs = test_client::TestClient::connect(format!("http://{}", addrs.1))
.await
.unwrap();
assert_eq!(
client_default_stubs.unary(()).await.unwrap_err().code(),
Code::Unimplemented
);
assert_eq!(
client_default_stubs
.server_stream(())
.await
.unwrap_err()
.code(),
Code::Unimplemented
);
assert_eq!(
client_default_stubs
.client_stream(echo_requests_iter().take(5))
.await
.unwrap_err()
.code(),
Code::Unimplemented
);
assert_eq!(
client_default_stubs
.bidirectional_stream(echo_requests_iter().take(5))
.await
.unwrap_err()
.code(),
Code::Unimplemented
);
}

#[cfg(test)]
async fn run_services_in_background() -> (SocketAddr, SocketAddr) {
let svc = test_server::TestServer::new(Svc {});
let svc_default_stubs = test_default_server::TestDefaultServer::new(Svc {});

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

let listener_default_stubs = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr_default_stubs = listener_default_stubs.local_addr().unwrap();

tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
.await
.unwrap();
});

tokio::spawn(async move {
Server::builder()
.add_service(svc_default_stubs)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(
listener_default_stubs,
))
.await
.unwrap();
});

(addr, addr_default_stubs)
}
9 changes: 9 additions & 0 deletions tonic-build/src/code_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub struct CodeGenBuilder {
build_transport: bool,
disable_comments: HashSet<String>,
use_arc_self: bool,
generate_default_stubs: bool,
}

impl CodeGenBuilder {
Expand Down Expand Up @@ -64,6 +65,12 @@ impl CodeGenBuilder {
self
}

/// Enable or disable returning automatic unimplemented gRPC error code for generated traits.
pub fn generate_default_stubs(&mut self, generate_default_stubs: bool) -> &mut Self {
self.generate_default_stubs = generate_default_stubs;
self
}

/// Generate client code based on `Service`.
///
/// This takes some `Service` and will generate a `TokenStream` that contains
Expand Down Expand Up @@ -93,6 +100,7 @@ impl CodeGenBuilder {
&self.attributes,
&self.disable_comments,
self.use_arc_self,
self.generate_default_stubs,
)
}
}
Expand All @@ -106,6 +114,7 @@ impl Default for CodeGenBuilder {
build_transport: true,
disable_comments: HashSet::default(),
use_arc_self: false,
generate_default_stubs: false,
}
}
}
14 changes: 14 additions & 0 deletions tonic-build/src/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub fn configure() -> Builder {
emit_rerun_if_changed: std::env::var_os("CARGO").is_some(),
disable_comments: HashSet::default(),
use_arc_self: false,
generate_default_stubs: false,
}
}

Expand Down Expand Up @@ -174,6 +175,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
.attributes(self.builder.server_attributes.clone())
.disable_comments(self.builder.disable_comments.clone())
.use_arc_self(self.builder.use_arc_self)
.generate_default_stubs(self.builder.generate_default_stubs)
.generate_server(&service, &self.builder.proto_path);

self.servers.extend(server);
Expand Down Expand Up @@ -249,6 +251,7 @@ pub struct Builder {
pub(crate) emit_rerun_if_changed: bool,
pub(crate) disable_comments: HashSet<String>,
pub(crate) use_arc_self: bool,
pub(crate) generate_default_stubs: bool,

out_dir: Option<PathBuf>,
}
Expand Down Expand Up @@ -510,6 +513,17 @@ impl Builder {
self
}

/// Enable or disable directing service generation to providing a default implementation for service methods.
/// When this is false all gRPC methods must be explicitly implemented.
/// When this is true any unimplemented service methods will return 'unimplemented' gRPC error code.
/// When this is true all streaming server request RPC types explicitly use tonic::codegen::BoxStream type.
///
/// This defaults to `false`.
pub fn generate_default_stubs(mut self, enable: bool) -> Self {
self.generate_default_stubs = enable;
self
}

/// Compile the .proto files and execute code generation.
pub fn compile(
self,
Expand Down
Loading

0 comments on commit aff1daf

Please sign in to comment.