From 47a67f1d6fa514f64b20f05f963de4eda168a8db Mon Sep 17 00:00:00 2001 From: Harshit Agrawal Date: Tue, 8 Oct 2024 23:22:22 +0530 Subject: [PATCH] Add event hook for redirect function in ServeHTTP --- Makefile | 2 +- runtime/events.go | 7 +++++++ runtime/gateway.go | 11 +++++++++-- runtime/gateway_test.go | 40 +++++++++++++++++++++++++++++++++++++--- runtime/router.go | 9 ++++++++- 5 files changed, 62 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index dc923d70b..8fe7a55df 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ install-packages: @echo "Mounting git pre-push hook" cp .git-pre-push-hook .git/hooks/pre-push @echo "Installing python packages..." - pip3 install --user yq + #pip3 install --user yq .PHONY: install-tools # set GO111MODULE to off to compile ancient tools within the vendor directory diff --git a/runtime/events.go b/runtime/events.go index e85def872..6c9b41e7f 100644 --- a/runtime/events.go +++ b/runtime/events.go @@ -20,6 +20,8 @@ package zanzibar +import "net/http" + // Context Variables const ( // ToCapture set to true if events have to be captured @@ -35,6 +37,7 @@ const ( type EventHandlerFn func([]Event) error type EnableEventGenFn func(string, string) bool +type RedirectFn func(w http.ResponseWriter, r *http.Request) bool type Event interface { Name() string @@ -127,3 +130,7 @@ func NoOpEventHandler(events []Event) error { func NoOpEventGen(_, _ string) bool { return false } + +func NoOpRedirectFn(_ http.ResponseWriter, _ *http.Request) bool { + return false +} diff --git a/runtime/gateway.go b/runtime/gateway.go index 76b73a39d..9d845957d 100644 --- a/runtime/gateway.go +++ b/runtime/gateway.go @@ -35,13 +35,11 @@ import ( "time" metricCollector "github.com/afex/hystrix-go/hystrix/metric_collector" - "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/uber-go/tally" "github.com/uber-go/tally/m3" jaegerConfig "github.com/uber/jaeger-client-go/config" jaegerLibTally "github.com/uber/jaeger-lib/metrics/tally" - "github.com/uber/tchannel-go" "github.com/uber/zanzibar/v2/runtime/jsonwrapper" "github.com/uber/zanzibar/v2/runtime/plugins" "go.uber.org/yarpc" @@ -83,6 +81,7 @@ type Options struct { NotFoundHandler func(*Gateway) http.HandlerFunc TracerProvider func(*Gateway) (opentracing.Tracer, io.Closer, error) EventProvider func(*Gateway) (EnableEventGenFn, EventHandlerFn) + RedirectProvider func(*Gateway) RedirectFn // If present, request uuid is retrieved from the incoming request // headers using the key, and put on the context. Otherwise, a new // uuid is created for the incoming request. @@ -114,6 +113,7 @@ type Gateway struct { JSONWrapper jsonwrapper.JSONWrapper EventHandler EventHandlerFn EnableEventGen EnableEventGenFn + RedirectHandlerFn RedirectFn // gRPC client dispatcher for gRPC client lifecycle management GRPCClientDispatcher *yarpc.Dispatcher @@ -274,6 +274,13 @@ func CreateGateway( gateway.EventHandler = NoOpEventHandler } + if opts.RedirectProvider != nil { + redirectFn := opts.RedirectProvider(gateway) + gateway.RedirectHandlerFn = redirectFn + } else { + gateway.RedirectHandlerFn = NoOpRedirectFn + } + if opts.NotFoundHandler != nil && config.ContainsKey("http.notFoundHandler.custom") && config.MustGetBoolean("http.notFoundHandler.custom") { diff --git a/runtime/gateway_test.go b/runtime/gateway_test.go index 2ca885356..b13c3ab25 100644 --- a/runtime/gateway_test.go +++ b/runtime/gateway_test.go @@ -29,11 +29,8 @@ import ( "sync" "testing" - "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" "github.com/uber-go/tally" - "github.com/uber/jaeger-client-go" - "github.com/uber/tchannel-go" "github.com/uber/zanzibar/v2/runtime/jsonwrapper" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -367,3 +364,40 @@ func TestGatewayWithEventHandler(t *testing.T) { }) } + +func TestGatewayWithRedirectHandler(t *testing.T) { + + rawCfgMap := map[string]interface{}{} + var metricsBackend tally.CachedStatsReporter + + opts := &Options{ + GetContextScopeExtractors: nil, + GetContextFieldExtractors: nil, + JSONWrapper: jsonwrapper.NewDefaultJSONWrapper(), + MetricsBackend: metricsBackend, + NotFoundHandler: func(gateway *Gateway) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) {} + }, + } + + t.Run("without redirect handler", func(t *testing.T) { + cfg := NewStaticConfigOrDie(nil, rawCfgMap) + g, err := CreateGateway(cfg, opts) + assert.Nil(t, err) + assert.NotEqual(t, NoOpRedirectFn, g.RedirectHandlerFn) + }) + + t.Run("with redirect handler", func(t *testing.T) { + redirectFn := func(_ http.ResponseWriter, _ *http.Request) bool { + return false + } + opts.RedirectProvider = func(gateway *Gateway) RedirectFn { + return redirectFn + } + cfg := NewStaticConfigOrDie(nil, rawCfgMap) + g, err := CreateGateway(cfg, opts) + assert.Nil(t, err) + assert.Equal(t, redirectFn, g.RedirectHandlerFn) + + }) +} diff --git a/runtime/router.go b/runtime/router.go index 0430de625..3ec51ff27 100644 --- a/runtime/router.go +++ b/runtime/router.go @@ -26,7 +26,6 @@ import ( "net/http" "net/url" - "github.com/opentracing/opentracing-go" "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/uber-go/tally" @@ -91,6 +90,7 @@ type RouterEndpoint struct { config *StaticConfig eventHandler EventHandlerFn enableEventGen EnableEventGenFn + redirectFn RedirectFn } // NewRouterEndpoint creates an endpoint that can be registered to HTTPRouter @@ -106,9 +106,11 @@ func NewRouterEndpoint( // continue working as is. eh := NoOpEventHandler eg := NoOpEventGen + rh := NoOpRedirectFn if deps.Gateway != nil { eh = deps.Gateway.EventHandler eg = deps.Gateway.EnableEventGen + rh = deps.Gateway.RedirectHandlerFn } return &RouterEndpoint{ @@ -123,6 +125,7 @@ func NewRouterEndpoint( config: deps.Config, eventHandler: eh, enableEventGen: eg, + redirectFn: rh, } } @@ -139,6 +142,10 @@ func (endpoint *RouterEndpoint) HandleRequest( // defer cancel() //} + if ok := endpoint.redirectFn(w, r); ok { + return + } + urlValues := ParamsFromContext(r.Context()) req := NewServerHTTPRequest(w, r, urlValues, endpoint) ctx := req.Context()