Skip to content

Commit

Permalink
Add Schema field to Spec for introspection (#629)
Browse files Browse the repository at this point in the history
Adds field Schema of type any on connect.Spec and
handler and client option WithSchema to populate it.
The protoc-gen-connect-go generator generates
code that uses this new option to populate the schema
with the protoreflect.MethodDescriptor corresponding
to the RPC.
  • Loading branch information
emcfarlane authored Nov 9, 2023
1 parent 20b5723 commit 8292c67
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 42 deletions.
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ type clientConfig struct {
URL *url.URL
Protocol protocol
Procedure string
Schema any
CompressMinBytes int
Interceptor Interceptor
CompressionPools map[string]*compressionPool
Expand Down Expand Up @@ -251,6 +252,7 @@ func (c *clientConfig) newSpec(t StreamType) Spec {
return Spec{
StreamType: t,
Procedure: c.Procedure,
Schema: c.Schema,
IsClient: true,
IdempotencyLevel: c.IdempotencyLevel,
}
Expand Down
87 changes: 87 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package connect_test
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"testing"
Expand All @@ -26,6 +27,7 @@ import (
pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1"
"connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect"
"connectrpc.com/connect/internal/memhttp/memhttptest"
"google.golang.org/protobuf/reflect/protoreflect"
)

func TestNewClient_InitFailure(t *testing.T) {
Expand Down Expand Up @@ -186,6 +188,44 @@ func TestGetNotModified(t *testing.T) {
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())
}

func TestSpecSchema(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(
pingServer{},
connect.WithInterceptors(&assertSchemaInterceptor{t}),
))
server := memhttptest.NewServer(t, mux)
ctx := context.Background()
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL(),
connect.WithInterceptors(&assertSchemaInterceptor{t}),
)
t.Run("unary", func(t *testing.T) {
t.Parallel()
unaryReq := connect.NewRequest[pingv1.PingRequest](nil)
_, err := client.Ping(ctx, unaryReq)
assert.NotNil(t, unaryReq.Spec().Schema)
assert.Nil(t, err)
text := strings.Repeat(".", 256)
r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text}))
assert.Nil(t, err)
assert.Equal(t, r.Msg.Text, text)
})
t.Run("bidi_stream", func(t *testing.T) {
t.Parallel()
bidiStream := client.CumSum(ctx)
t.Cleanup(func() {
assert.Nil(t, bidiStream.CloseRequest())
assert.Nil(t, bidiStream.CloseResponse())
})
assert.NotZero(t, bidiStream.Spec().Schema)
err := bidiStream.Send(&pingv1.CumSumRequest{})
assert.Nil(t, err)
})
}

type notModifiedPingServer struct {
pingv1connect.UnimplementedPingServiceHandler

Expand Down Expand Up @@ -233,3 +273,50 @@ func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandl
return next(ctx, conn)
}
}

type assertSchemaInterceptor struct {
tb testing.TB
}

func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if !assert.NotNil(a.tb, req.Spec().Schema) {
return next(ctx, req)
}
methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor)
if assert.True(a.tb, ok) {
procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name())
assert.Equal(a.tb, procedure, req.Spec().Procedure)
}
return next(ctx, req)
}
}

func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
if !assert.NotNil(a.tb, spec.Schema) {
return conn
}
methodDescriptor, ok := spec.Schema.(protoreflect.MethodDescriptor)
if assert.True(a.tb, ok) {
procedure := fmt.Sprintf("/%s/%s", methodDescriptor.Parent().FullName(), methodDescriptor.Name())
assert.Equal(a.tb, procedure, spec.Procedure)
}
return conn
}
}

func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
if !assert.NotNil(a.tb, conn.Spec().Schema) {
return next(ctx, conn)
}
methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor)
if assert.True(a.tb, ok) {
procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name())
assert.Equal(a.tb, procedure, conn.Spec().Procedure)
}
return next(ctx, conn)
}
}
53 changes: 30 additions & 23 deletions cmd/protoc-gen-connect-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,6 @@ func main() {
)
}

func needsWithIdempotency(file *protogen.File) bool {
for _, service := range file.Services {
for _, method := range service.Methods {
if methodIdempotency(method) != connect.IdempotencyUnknown {
return true
}
}
}
return false
}

func generate(plugin *protogen.Plugin, file *protogen.File) {
if len(file.Services) == 0 {
return
Expand All @@ -135,6 +124,7 @@ func generate(plugin *protogen.Plugin, file *protogen.File) {
generatedFile.Import(file.GoImportPath)
generatePreamble(generatedFile, file)
generateServiceNameConstants(generatedFile, file.Services)
generateServiceNameVariables(generatedFile, file)
for _, service := range file.Services {
generateService(generatedFile, service)
}
Expand Down Expand Up @@ -180,11 +170,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) {
"is not defined, this code was generated with a version of connect newer than the one ",
"compiled into your binary. You can fix the problem by either regenerating this code ",
"with an older version of connect or updating the connect version compiled into your binary.")
if needsWithIdempotency(file) {
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_7_0"))
} else {
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_1_0"))
}
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_13_0"))
g.P()
}

Expand Down Expand Up @@ -225,6 +211,23 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge
g.P()
}

func generateServiceNameVariables(g *protogen.GeneratedFile, file *protogen.File) {
wrapComments(g, "These variables are the protoreflect.Descriptor objects for the RPCs defined in this package.")
g.P("var (")
for _, service := range file.Services {
serviceDescName := unexport(fmt.Sprintf("%sServiceDescriptor", service.Desc.Name()))
g.P(serviceDescName, ` = `,
g.QualifiedGoIdent(file.GoDescriptorIdent),
`.Services().ByName("`, service.Desc.Name(), `")`)
for _, method := range service.Methods {
g.P(procedureVarMethodDescriptor(method), ` = `,
serviceDescName,
`.Methods().ByName("`, method.Desc.Name(), `")`)
}
}
g.P(")")
}

func generateService(g *protogen.GeneratedFile, service *protogen.Service) {
names := newNames(service)
generateClientInterface(g, service, names)
Expand Down Expand Up @@ -273,7 +276,9 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S
}
g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"),
", baseURL string, opts ...", clientOption, ") ", names.Client, " {")
g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`)
if len(service.Methods) > 0 {
g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`)
}
g.P("return &", names.ClientImpl, "{")
for _, method := range service.Methods {
g.P(unexport(method.GoName), ": ",
Expand All @@ -283,17 +288,16 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S
)
g.P("httpClient,")
g.P(`baseURL + `, procedureConstName(method), `,`)
g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),")
idempotency := methodIdempotency(method)
switch idempotency {
case connect.IdempotencyNoSideEffects:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),")
g.P(connectPackage.Ident("WithClientOptions"), "(opts...),")
case connect.IdempotencyIdempotent:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),")
g.P(connectPackage.Ident("WithClientOptions"), "(opts...),")
case connect.IdempotencyUnknown:
g.P("opts...,")
}
g.P(connectPackage.Ident("WithClientOptions"), "(opts...),")
g.P("),")
}
g.P("}")
Expand Down Expand Up @@ -419,16 +423,15 @@ func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Serv
}
g.P(procedureConstName(method), `,`)
g.P("svc.", method.GoName, ",")
g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),")
switch idempotency {
case connect.IdempotencyNoSideEffects:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),")
g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),")
case connect.IdempotencyIdempotent:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),")
g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),")
case connect.IdempotencyUnknown:
g.P("opts...,")
}
g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),")
g.P(")")
}
g.P(`return "/`, service.Desc.FullName(), `/", `, httpPackage.Ident("HandlerFunc"), `(func(w `, httpPackage.Ident("ResponseWriter"), `, r *`, httpPackage.Ident("Request"), `){`)
Expand Down Expand Up @@ -516,6 +519,10 @@ func procedureHandlerName(m *protogen.Method) string {
return fmt.Sprintf("%s%sHandler", unexport(m.Parent.GoName), m.GoName)
}

func procedureVarMethodDescriptor(m *protogen.Method) string {
return unexport(fmt.Sprintf("%s%sMethodDescriptor", m.Parent.GoName, m.GoName))
}

func isDeprecatedService(service *protogen.Service) bool {
serviceOptions, ok := service.Desc.Options().(*descriptorpb.ServiceOptions)
return ok && serviceOptions.GetDeprecated()
Expand Down
8 changes: 5 additions & 3 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ const Version = "1.13.0-dev"
// These constants are used in compile-time handshakes with connect's generated
// code.
const (
IsAtLeastVersion0_0_1 = true
IsAtLeastVersion0_1_0 = true
IsAtLeastVersion1_7_0 = true
IsAtLeastVersion0_0_1 = true
IsAtLeastVersion0_1_0 = true
IsAtLeastVersion1_7_0 = true
IsAtLeastVersion1_13_0 = true
)

// StreamType describes whether the client, server, neither, or both is
Expand Down Expand Up @@ -314,6 +315,7 @@ type HTTPClient interface {
// fully-qualified Procedure corresponding to each RPC in your schema.
type Spec struct {
StreamType StreamType
Schema any // for protobuf RPCs, a protoreflect.MethodDescriptor
Procedure string // for example, "/acme.foo.v1.FooService/Bar"
IsClient bool // otherwise we're in a handler
IdempotencyLevel IdempotencyLevel
Expand Down
2 changes: 2 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ type handlerConfig struct {
CompressMinBytes int
Interceptor Interceptor
Procedure string
Schema any
HandleGRPC bool
HandleGRPCWeb bool
RequireConnectProtocolHeader bool
Expand Down Expand Up @@ -279,6 +280,7 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler
func (c *handlerConfig) newSpec() Spec {
return Spec{
Procedure: c.Procedure,
Schema: c.Schema,
StreamType: c.StreamType,
IdempotencyLevel: c.IdempotencyLevel,
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions internal/gen/connect/import/v1/importv1connect/import.connect.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8292c67

Please sign in to comment.