diff --git a/runtime/events.go b/runtime/events.go index e85def872..6c9b41e7f 100644 --- a/runtime/events.go +++ b/runtime/events.go @@ -20,6 +20,8 @@ package zanzibar +import "net/http" + // Context Variables const ( // ToCapture set to true if events have to be captured @@ -35,6 +37,7 @@ const ( type EventHandlerFn func([]Event) error type EnableEventGenFn func(string, string) bool +type RedirectFn func(w http.ResponseWriter, r *http.Request) bool type Event interface { Name() string @@ -127,3 +130,7 @@ func NoOpEventHandler(events []Event) error { func NoOpEventGen(_, _ string) bool { return false } + +func NoOpRedirectFn(_ http.ResponseWriter, _ *http.Request) bool { + return false +} diff --git a/runtime/gateway.go b/runtime/gateway.go index 76b73a39d..485934d15 100644 --- a/runtime/gateway.go +++ b/runtime/gateway.go @@ -83,6 +83,7 @@ type Options struct { NotFoundHandler func(*Gateway) http.HandlerFunc TracerProvider func(*Gateway) (opentracing.Tracer, io.Closer, error) EventProvider func(*Gateway) (EnableEventGenFn, EventHandlerFn) + RedirectProvider func(*Gateway) RedirectFn // If present, request uuid is retrieved from the incoming request // headers using the key, and put on the context. Otherwise, a new // uuid is created for the incoming request. @@ -114,6 +115,7 @@ type Gateway struct { JSONWrapper jsonwrapper.JSONWrapper EventHandler EventHandlerFn EnableEventGen EnableEventGenFn + RedirectFn RedirectFn // gRPC client dispatcher for gRPC client lifecycle management GRPCClientDispatcher *yarpc.Dispatcher @@ -274,6 +276,13 @@ func CreateGateway( gateway.EventHandler = NoOpEventHandler } + if opts.RedirectProvider != nil { + redirectFn := opts.RedirectProvider(gateway) + gateway.RedirectFn = redirectFn + } else { + gateway.RedirectFn = NoOpRedirectFn + } + if opts.NotFoundHandler != nil && config.ContainsKey("http.notFoundHandler.custom") && config.MustGetBoolean("http.notFoundHandler.custom") { diff --git a/runtime/gateway_test.go b/runtime/gateway_test.go index 2ca885356..db5d61a5c 100644 --- a/runtime/gateway_test.go +++ b/runtime/gateway_test.go @@ -367,3 +367,40 @@ func TestGatewayWithEventHandler(t *testing.T) { }) } + +func TestGatewayWithRedirectHandler(t *testing.T) { + + rawCfgMap := map[string]interface{}{} + var metricsBackend tally.CachedStatsReporter + + opts := &Options{ + GetContextScopeExtractors: nil, + GetContextFieldExtractors: nil, + JSONWrapper: jsonwrapper.NewDefaultJSONWrapper(), + MetricsBackend: metricsBackend, + NotFoundHandler: func(gateway *Gateway) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) {} + }, + } + + t.Run("without redirect handler", func(t *testing.T) { + cfg := NewStaticConfigOrDie(nil, rawCfgMap) + g, err := CreateGateway(cfg, opts) + assert.Nil(t, err) + assert.NotEqual(t, NoOpRedirectFn, g.RedirectFn) + }) + + t.Run("with redirect handler", func(t *testing.T) { + redirectFn := func(_ http.ResponseWriter, _ *http.Request) bool { + return false + } + opts.RedirectProvider = func(gateway *Gateway) RedirectFn { + return redirectFn + } + cfg := NewStaticConfigOrDie(nil, rawCfgMap) + g, err := CreateGateway(cfg, opts) + assert.Nil(t, err) + assert.Equal(t, redirectFn, g.RedirectFn) + + }) +} diff --git a/runtime/router.go b/runtime/router.go index 0430de625..f3728e3f1 100644 --- a/runtime/router.go +++ b/runtime/router.go @@ -91,6 +91,7 @@ type RouterEndpoint struct { config *StaticConfig eventHandler EventHandlerFn enableEventGen EnableEventGenFn + redirectFn RedirectFn } // NewRouterEndpoint creates an endpoint that can be registered to HTTPRouter @@ -106,9 +107,11 @@ func NewRouterEndpoint( // continue working as is. eh := NoOpEventHandler eg := NoOpEventGen + rh := NoOpRedirectFn if deps.Gateway != nil { eh = deps.Gateway.EventHandler eg = deps.Gateway.EnableEventGen + rh = deps.Gateway.RedirectFn } return &RouterEndpoint{ @@ -123,6 +126,7 @@ func NewRouterEndpoint( config: deps.Config, eventHandler: eh, enableEventGen: eg, + redirectFn: rh, } } @@ -139,6 +143,10 @@ func (endpoint *RouterEndpoint) HandleRequest( // defer cancel() //} + if ok := endpoint.redirectFn(w, r); ok { + return + } + urlValues := ParamsFromContext(r.Context()) req := NewServerHTTPRequest(w, r, urlValues, endpoint) ctx := req.Context()