diff --git a/api/handlers/endpoint.go b/api/handlers/endpoint.go index 5c3c7f058b..8135d01b05 100644 --- a/api/handlers/endpoint.go +++ b/api/handlers/endpoint.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/frain-dev/convoy/internal/pkg/fflag" "github.com/frain-dev/convoy/pkg/circuit_breaker" "github.com/frain-dev/convoy/pkg/msgpack" "net/http" @@ -211,29 +212,31 @@ func (h *Handler) GetEndpoints(w http.ResponseWriter, r *http.Request) { return } - // fetch keys from redis and mutate endpoints slice - keys := make([]string, len(endpoints)) - for i := 0; i < len(endpoints); i++ { - keys[i] = fmt.Sprintf("breaker:%s", endpoints[i].UID) - } + if h.A.FFlag.CanAccessFeature(fflag.CircuitBreaker) && h.A.Licenser.CircuitBreaking() && len(endpoints) > 0 { + // fetch keys from redis and mutate endpoints slice + keys := make([]string, len(endpoints)) + for i := 0; i < len(endpoints); i++ { + keys[i] = fmt.Sprintf("breaker:%s", endpoints[i].UID) + } - cbs, err := h.A.Redis.MGet(r.Context(), keys...).Result() - if err != nil { - _ = render.Render(w, r, util.NewServiceErrResponse(err)) - return - } + cbs, err := h.A.Redis.MGet(r.Context(), keys...).Result() + if err != nil { + _ = render.Render(w, r, util.NewServiceErrResponse(err)) + return + } - for i := 0; i < len(cbs); i++ { - if cbs[i] != nil { - str, ok := cbs[i].(string) - if ok { - var c circuit_breaker.CircuitBreaker - asBytes := []byte(str) - innerErr := msgpack.DecodeMsgPack(asBytes, &c) - if innerErr != nil { - continue + for i := 0; i < len(cbs); i++ { + if cbs[i] != nil { + str, ok := cbs[i].(string) + if ok { + var c circuit_breaker.CircuitBreaker + asBytes := []byte(str) + innerErr := msgpack.DecodeMsgPack(asBytes, &c) + if innerErr != nil { + continue + } + endpoints[i].FailureRate = c.FailureRate } - endpoints[i].FailureRate = c.FailureRate } } } @@ -505,6 +508,11 @@ func (h *Handler) PauseEndpoint(w http.ResponseWriter, r *http.Request) { // @Security ApiKeyAuth // @Router /v1/projects/{projectID}/endpoints/{endpointID}/activate [post] func (h *Handler) ActivateEndpoint(w http.ResponseWriter, r *http.Request) { + if !h.A.Licenser.CircuitBreaking() || !h.A.FFlag.CanAccessFeature(fflag.CircuitBreaker) { + _ = render.Render(w, r, util.NewErrorResponse("feature not enabled", http.StatusBadRequest)) + return + } + project, err := h.retrieveProject(r) if err != nil { _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusBadRequest)) diff --git a/api/server_suite_test.go b/api/server_suite_test.go index ddb2ee9ac6..86bea52029 100644 --- a/api/server_suite_test.go +++ b/api/server_suite_test.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/frain-dev/convoy/internal/pkg/fflag" "io" "math/rand" "net/http" @@ -137,6 +138,7 @@ func buildServer() *ApplicationHandler { Redis: rd.Client(), Logger: logger, Cache: noopCache, + FFlag: fflag.NewFFlag([]string{string(fflag.Prometheus), string(fflag.FullTextSearch)}), Rate: r, Licenser: noopLicenser.NewLicenser(), }) diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 08274b9841..a68c684f5f 100644 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -137,7 +137,7 @@ func startServerComponent(_ context.Context, a *cli.App) error { lo.WithError(err).Fatal("failed to initialize realm chain") } - flag := fflag.NewFFlag(&cfg) + flag := fflag.NewFFlag(cfg.EnableFeatureFlag) lvl, err := log.ParseLevel(cfg.Logger.Level) if err != nil { diff --git a/cmd/ff/feature_flags.go b/cmd/ff/feature_flags.go index 88e578bf50..ce8bacf2bd 100644 --- a/cmd/ff/feature_flags.go +++ b/cmd/ff/feature_flags.go @@ -21,7 +21,7 @@ func AddFeatureFlagsCommand() *cobra.Command { log.WithError(err).Fatal("Error fetching the config.") } - f := fflag2.NewFFlag(&cfg) + f := fflag2.NewFFlag(cfg.EnableFeatureFlag) return f.ListFeatures() }, PersistentPostRun: func(cmd *cobra.Command, args []string) {}, diff --git a/cmd/hooks/hooks.go b/cmd/hooks/hooks.go index 8cbcadfac8..9dee06fc94 100644 --- a/cmd/hooks/hooks.go +++ b/cmd/hooks/hooks.go @@ -519,13 +519,31 @@ func buildCliConfiguration(cmd *cobra.Command) (*config.Configuration, error) { c.RetentionPolicy.IsRetentionPolicyEnabled = retentionPolicyEnabled } - // Feature flags + // CONVOY_ENABLE_FEATURE_FLAG fflag, err := cmd.Flags().GetStringSlice("enable-feature-flag") if err != nil { return nil, err } c.EnableFeatureFlag = fflag + // CONVOY_DISPATCHER_BLOCK_LIST + ipBlockList, err := cmd.Flags().GetStringSlice("ip-block-list") + if err != nil { + return nil, err + } + if len(ipBlockList) > 0 { + c.Dispatcher.BlockList = ipBlockList + } + + // CONVOY_DISPATCHER_ALLOW_LIST + ipAllowList, err := cmd.Flags().GetStringSlice("ip-allow-list") + if err != nil { + return nil, err + } + if len(ipAllowList) > 0 { + c.Dispatcher.AllowList = ipAllowList + } + // tracing tracingProvider, err := cmd.Flags().GetString("tracer-type") if err != nil { @@ -585,7 +603,7 @@ func buildCliConfiguration(cmd *cobra.Command) (*config.Configuration, error) { } - flag := fflag2.NewFFlag(c) + flag := fflag2.NewFFlag(c.EnableFeatureFlag) c.Metrics = config.MetricsConfiguration{ IsEnabled: false, } diff --git a/cmd/main.go b/cmd/main.go index 22d530c523..a7c3b01b41 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -49,6 +49,8 @@ func main() { var dbDatabase string var fflag []string + var ipAllowList []string + var ipBLockList []string var enableProfiling bool var redisPort int @@ -105,6 +107,9 @@ func main() { // misc c.Flags().StringSliceVar(&fflag, "enable-feature-flag", []string{}, "List of feature flags to enable e.g. \"full-text-search,prometheus\"") + c.Flags().StringSliceVar(&ipAllowList, "ip-allow-list", []string{}, "List of IPs CIDRs to allow e.g. \" 0.0.0.0/0,127.0.0.0/8\"") + c.Flags().StringSliceVar(&ipBLockList, "ip-block-list", []string{}, "List of IPs CIDRs to block e.g. \" 0.0.0.0/0,127.0.0.0/8\"") + c.Flags().IntVar(&instanceIngestRate, "instance-ingest-rate", 0, "Instance ingest Rate") c.Flags().IntVar(&apiRateLimit, "api-rate-limit", 0, "API rate limit") diff --git a/cmd/server/server.go b/cmd/server/server.go index edfe0c7a51..b9d1247321 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -107,7 +107,7 @@ func startConvoyServer(a *cli.App) error { a.Logger.WithError(err).Fatal("failed to initialize realm chain") } - flag := fflag.NewFFlag(&cfg) + flag := fflag.NewFFlag(cfg.EnableFeatureFlag) if cfg.Server.HTTP.Port <= 0 { return errors.New("please provide the HTTP port in the convoy.json file") diff --git a/cmd/worker/worker.go b/cmd/worker/worker.go index 547c28b24b..6b82eb275f 100644 --- a/cmd/worker/worker.go +++ b/cmd/worker/worker.go @@ -249,18 +249,26 @@ func StartWorker(ctx context.Context, a *cli.App, cfg config.Configuration, inte go memorystore.DefaultStore.Sync(ctx, interval) + featureFlag := fflag.NewFFlag(cfg.EnableFeatureFlag) newTelemetry := telemetry.NewTelemetry(lo, configuration, telemetry.OptionTracker(counter), telemetry.OptionBackend(pb), telemetry.OptionBackend(mb)) - dispatcher, err := net.NewDispatcher(cfg.Server.HTTP.HttpProxy, a.Licenser, false) + dispatcher, err := net.NewDispatcher( + a.Licenser, + featureFlag, + net.LoggerOption(lo), + net.ProxyOption(cfg.Server.HTTP.HttpProxy), + net.AllowListOption(cfg.Dispatcher.AllowList), + net.BlockListOption(cfg.Dispatcher.BlockList), + net.InsecureSkipVerifyOption(cfg.Dispatcher.InsecureSkipVerify), + ) if err != nil { lo.WithError(err).Fatal("Failed to create new net dispatcher") return err } - featureFlag := fflag.NewFFlag(&cfg) var circuitBreakerManager *cb.CircuitBreakerManager if featureFlag.CanAccessFeature(fflag.CircuitBreaker) { diff --git a/config/config.go b/config/config.go index 613fa4f561..ed6ee1fe2d 100644 --- a/config/config.go +++ b/config/config.go @@ -107,6 +107,11 @@ var DefaultConfiguration = Configuration{ SampleTime: 5, }, }, + Dispatcher: DispatcherConfiguration{ + InsecureSkipVerify: true, + AllowList: []string{"0.0.0.0/0", "::/0"}, + BlockList: []string{"127.0.0.0/8", "::1/128"}, + }, InstanceIngestRate: 25, ApiRateLimit: 25, WorkerExecutionMode: DefaultExecutionMode, @@ -388,6 +393,13 @@ type Configuration struct { WorkerExecutionMode ExecutionMode `json:"worker_execution_mode" envconfig:"CONVOY_WORKER_EXECUTION_MODE"` MaxRetrySeconds uint64 `json:"max_retry_seconds,omitempty" envconfig:"CONVOY_MAX_RETRY_SECONDS"` LicenseKey string `json:"license_key" envconfig:"CONVOY_LICENSE_KEY"` + Dispatcher DispatcherConfiguration `json:"dispatcher"` +} + +type DispatcherConfiguration struct { + InsecureSkipVerify bool `json:"insecure_skip_verify" envconfig:"CONVOY_DISPATCHER_INSECURE_SKIP_VERIFY"` + AllowList []string `json:"allow_list" envconfig:"CONVOY_DISPATCHER_ALLOW_LIST"` + BlockList []string `json:"block_list" envconfig:"CONVOY_DISPATCHER_BLOCK_LIST"` } type PyroscopeConfiguration struct { diff --git a/config/config_test.go b/config/config_test.go index 7d594cf955..6ed2ee5250 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -178,6 +178,11 @@ func TestLoadConfig(t *testing.T) { SampleTime: 5, }, }, + Dispatcher: DispatcherConfiguration{ + InsecureSkipVerify: true, + AllowList: []string{"0.0.0.0/0", "::/0"}, + BlockList: []string{"127.0.0.0/8", "::1/128"}, + }, WorkerExecutionMode: DefaultExecutionMode, InstanceIngestRate: 25, ApiRateLimit: 25, @@ -265,6 +270,11 @@ func TestLoadConfig(t *testing.T) { SampleTime: 5, }, }, + Dispatcher: DispatcherConfiguration{ + InsecureSkipVerify: true, + AllowList: []string{"0.0.0.0/0", "::/0"}, + BlockList: []string{"127.0.0.0/8", "::1/128"}, + }, InstanceIngestRate: 25, ApiRateLimit: 25, WorkerExecutionMode: DefaultExecutionMode, @@ -351,6 +361,11 @@ func TestLoadConfig(t *testing.T) { SampleTime: 5, }, }, + Dispatcher: DispatcherConfiguration{ + InsecureSkipVerify: true, + AllowList: []string{"0.0.0.0/0", "::/0"}, + BlockList: []string{"127.0.0.0/8", "::1/128"}, + }, InstanceIngestRate: 25, ApiRateLimit: 25, WorkerExecutionMode: DefaultExecutionMode, diff --git a/go.mod b/go.mod index f232891c2a..b1a78203c0 100644 --- a/go.mod +++ b/go.mod @@ -205,6 +205,7 @@ require ( github.com/shirou/gopsutil/v3 v3.23.12 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect + github.com/stealthrocket/netjail v0.1.2 // indirect github.com/theupdateframework/notary v0.7.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index 8366a5e832..4e69fd8b38 100644 --- a/go.sum +++ b/go.sum @@ -1841,6 +1841,8 @@ github.com/spf13/viper v0.0.0-20150530192845-be5ff3e4840c/go.mod h1:A8kyI5cUJhb8 github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= github.com/spf13/viper v1.8.1 h1:Kq1fyeebqsBfbjZj4EL7gj2IO0mMaiyjYUWcUsl2O44= github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns= +github.com/stealthrocket/netjail v0.1.2 h1:nOgFLer7XrkYcn8cJk5kI9aUFRkV7LC/8VjmJ2GjBQU= +github.com/stealthrocket/netjail v0.1.2/go.mod h1:LmslfwZTxTchb7koch3C/MNvEzF111G9HwZQrT23No4= github.com/stefanberger/go-pkcs11uri v0.0.0-20201008174630-78d3cae3a980/go.mod h1:AO3tvPzVZ/ayst6UlUKUv6rcPQInYe3IknH3jYhAKu8= github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= diff --git a/internal/pkg/fflag/fflag.go b/internal/pkg/fflag/fflag.go index 1492fca812..31ea935577 100644 --- a/internal/pkg/fflag/fflag.go +++ b/internal/pkg/fflag/fflag.go @@ -3,7 +3,6 @@ package fflag import ( "errors" "fmt" - "github.com/frain-dev/convoy/config" "os" "sort" "text/tabwriter" @@ -18,9 +17,10 @@ type ( ) const ( + IpRules FeatureFlagKey = "ip-rules" Prometheus FeatureFlagKey = "prometheus" - FullTextSearch FeatureFlagKey = "full-text-search" CircuitBreaker FeatureFlagKey = "circuit-breaker" + FullTextSearch FeatureFlagKey = "full-text-search" ) type ( @@ -33,6 +33,7 @@ const ( ) var DefaultFeaturesState = map[FeatureFlagKey]FeatureFlagState{ + IpRules: disabled, Prometheus: disabled, FullTextSearch: disabled, CircuitBreaker: disabled, @@ -42,13 +43,15 @@ type FFlag struct { Features map[FeatureFlagKey]FeatureFlagState } -func NewFFlag(c *config.Configuration) *FFlag { +func NewFFlag(enableFeatureFlags []string) *FFlag { f := &FFlag{ Features: clone(DefaultFeaturesState), } - for _, flag := range c.EnableFeatureFlag { + for _, flag := range enableFeatureFlags { switch flag { + case string(IpRules): + f.Features[IpRules] = enabled case string(Prometheus): f.Features[Prometheus] = enabled case string(FullTextSearch): diff --git a/internal/pkg/fflag/fflag_test.go b/internal/pkg/fflag/fflag_test.go index b41489f80b..1ba1e27e0c 100644 --- a/internal/pkg/fflag/fflag_test.go +++ b/internal/pkg/fflag/fflag_test.go @@ -26,6 +26,7 @@ func TestFFlag_CanAccessFeature(t *testing.T) { }{ Features: map[FeatureFlagKey]FeatureFlagState{ Prometheus: disabled, + IpRules: disabled, FullTextSearch: enabled, CircuitBreaker: disabled, }, @@ -44,6 +45,7 @@ func TestFFlag_CanAccessFeature(t *testing.T) { }{ Features: map[FeatureFlagKey]FeatureFlagState{ Prometheus: disabled, + IpRules: disabled, FullTextSearch: enabled, CircuitBreaker: disabled, }, @@ -62,6 +64,7 @@ func TestFFlag_CanAccessFeature(t *testing.T) { }{ Features: map[FeatureFlagKey]FeatureFlagState{ Prometheus: enabled, + IpRules: disabled, FullTextSearch: enabled, CircuitBreaker: disabled, }, @@ -80,6 +83,7 @@ func TestFFlag_CanAccessFeature(t *testing.T) { }{ Features: map[FeatureFlagKey]FeatureFlagState{ Prometheus: enabled, + IpRules: disabled, FullTextSearch: enabled, CircuitBreaker: disabled, }, @@ -100,6 +104,7 @@ func TestFFlag_CanAccessFeature(t *testing.T) { Prometheus: disabled, FullTextSearch: disabled, CircuitBreaker: disabled, + IpRules: disabled, }, }, args: struct { @@ -118,6 +123,7 @@ func TestFFlag_CanAccessFeature(t *testing.T) { Prometheus: disabled, FullTextSearch: disabled, CircuitBreaker: disabled, + IpRules: disabled, }, }, args: struct { @@ -170,6 +176,7 @@ func TestNewFFlag(t *testing.T) { Prometheus: disabled, FullTextSearch: disabled, CircuitBreaker: disabled, + IpRules: disabled, }, }, wantErr: false, @@ -186,6 +193,7 @@ func TestNewFFlag(t *testing.T) { Prometheus: enabled, FullTextSearch: disabled, CircuitBreaker: disabled, + IpRules: disabled, }, }, wantErr: false, @@ -200,6 +208,7 @@ func TestNewFFlag(t *testing.T) { Prometheus: disabled, FullTextSearch: disabled, CircuitBreaker: disabled, + IpRules: disabled, }, }, wantErr: false, @@ -207,7 +216,7 @@ func TestNewFFlag(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := NewFFlag(tt.args.c) + got := NewFFlag(tt.args.c.EnableFeatureFlag) if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewFFlag() got = %v, want %v", got, tt.want) diff --git a/internal/pkg/license/keygen/feature.go b/internal/pkg/license/keygen/feature.go index 5a3711ff99..6921fdea8d 100644 --- a/internal/pkg/license/keygen/feature.go +++ b/internal/pkg/license/keygen/feature.go @@ -28,6 +28,7 @@ const ( MultiPlayerMode Feature = "MULTI_PLAYER_MODE" IngestRate Feature = "INGEST_RATE" AgentExecutionMode Feature = "AGENT_EXECUTION_MODE" + IpRules Feature = "IP_RULES" ) const ( diff --git a/internal/pkg/license/keygen/keygen.go b/internal/pkg/license/keygen/keygen.go index 89c488a1ab..c84ee45817 100644 --- a/internal/pkg/license/keygen/keygen.go +++ b/internal/pkg/license/keygen/keygen.go @@ -475,6 +475,14 @@ func (k *Licenser) AgentExecutionMode() bool { return ok } +func (k *Licenser) IpRules() bool { + if checkExpiry(k.license) != nil { + return false + } + _, ok := k.featureList[IpRules] + return ok +} + func (k *Licenser) FeatureListJSON(ctx context.Context) (json.RawMessage, error) { // only these guys have dynamic limits for now for f := range k.featureList { diff --git a/internal/pkg/license/license.go b/internal/pkg/license/license.go index 2fbea33f01..7c724bae14 100644 --- a/internal/pkg/license/license.go +++ b/internal/pkg/license/license.go @@ -25,6 +25,7 @@ type Licenser interface { MultiPlayerMode() bool IngestRate() bool AgentExecutionMode() bool + IpRules() bool // need more fleshing out AdvancedRetentionPolicy() bool diff --git a/internal/pkg/license/noop/noop.go b/internal/pkg/license/noop/noop.go index 3386ad1721..7759f1ff01 100644 --- a/internal/pkg/license/noop/noop.go +++ b/internal/pkg/license/noop/noop.go @@ -11,7 +11,7 @@ import ( type Licenser struct{} -func (Licenser) FeatureListJSON(ctx context.Context) (json.RawMessage, error) { +func (Licenser) FeatureListJSON(_ context.Context) (json.RawMessage, error) { return []byte{}, nil } @@ -19,15 +19,15 @@ func NewLicenser() *Licenser { return &Licenser{} } -func (Licenser) CreateOrg(ctx context.Context) (bool, error) { +func (Licenser) CreateOrg(_ context.Context) (bool, error) { return true, nil } -func (Licenser) CreateUser(ctx context.Context) (bool, error) { +func (Licenser) CreateUser(_ context.Context) (bool, error) { return true, nil } -func (Licenser) CreateProject(ctx context.Context) (bool, error) { +func (Licenser) CreateProject(_ context.Context) (bool, error) { return true, nil } @@ -99,8 +99,8 @@ func (Licenser) PortalLinks() bool { return true } -func (Licenser) CircuitBreaking() bool { - return true +func (Licenser) CircuitBreaking() bool { + return true } func (Licenser) MultiPlayerMode() bool { @@ -114,3 +114,7 @@ func (Licenser) IngestRate() bool { func (Licenser) AgentExecutionMode() bool { return true } + +func (Licenser) IpRules() bool { + return true +} diff --git a/mocks/license.go b/mocks/license.go index 13144101d5..cac62a846d 100644 --- a/mocks/license.go +++ b/mocks/license.go @@ -280,6 +280,20 @@ func (mr *MockLicenserMockRecorder) IngestRate() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IngestRate", reflect.TypeOf((*MockLicenser)(nil).IngestRate)) } +// IpRules mocks base method. +func (m *MockLicenser) IpRules() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IpRules") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IpRules indicates an expected call of IpRules. +func (mr *MockLicenserMockRecorder) IpRules() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IpRules", reflect.TypeOf((*MockLicenser)(nil).IpRules)) +} + // MultiPlayerMode mocks base method. func (m *MockLicenser) MultiPlayerMode() bool { m.ctrl.T.Helper() diff --git a/mocks/repository.go b/mocks/repository.go index 36453d5536..dba4ad4eb3 100644 --- a/mocks/repository.go +++ b/mocks/repository.go @@ -644,7 +644,7 @@ func (m *MockEventRepository) UpdateEventEndpoints(arg0 context.Context, arg1 *d } // UpdateEventEndpoints indicates an expected call of UpdateEventEndpoints. -func (mr *MockEventRepositoryMockRecorder) UpdateEventEndpoints(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockEventRepositoryMockRecorder) UpdateEventEndpoints(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEventEndpoints", reflect.TypeOf((*MockEventRepository)(nil).UpdateEventEndpoints), arg0, arg1, arg2) } @@ -658,7 +658,7 @@ func (m *MockEventRepository) UpdateEventStatus(arg0 context.Context, arg1 *data } // UpdateEventStatus indicates an expected call of UpdateEventStatus. -func (mr *MockEventRepositoryMockRecorder) UpdateEventStatus(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockEventRepositoryMockRecorder) UpdateEventStatus(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEventStatus", reflect.TypeOf((*MockEventRepository)(nil).UpdateEventStatus), arg0, arg1, arg2) } diff --git a/net/dispatcher.go b/net/dispatcher.go index b1b28288a7..bc840fc801 100644 --- a/net/dispatcher.go +++ b/net/dispatcher.go @@ -3,11 +3,16 @@ package net import ( "bytes" "context" + "crypto/tls" "encoding/json" "errors" + "fmt" + "github.com/frain-dev/convoy/internal/pkg/fflag" + "github.com/stealthrocket/netjail" "io" "net/http" "net/http/httptrace" + "net/netip" "net/url" "time" @@ -19,49 +24,180 @@ import ( "github.com/frain-dev/convoy/util" ) +var ( + ErrAllowListIsRequired = errors.New("allowlist is required") + ErrBlockListIsRequired = errors.New("blocklist is required") + ErrLoggerIsRequired = errors.New("logger is required") + ErrInvalidIPPrefix = errors.New("invalid IP prefix") +) + +type DispatcherOption func(d *Dispatcher) error + type Dispatcher struct { - client *http.Client + // gating mechanisms + ff *fflag.FFlag + l license.Licenser + + logger *log.Logger + transport *http.Transport + client *http.Client + rules *netjail.Rules } -func NewDispatcher(httpProxy string, licenser license.Licenser, enforceSecure bool) (*Dispatcher, error) { - d := &Dispatcher{client: &http.Client{}} - - tr := &http.Transport{ - MaxIdleConns: 100, - IdleConnTimeout: 10 * time.Second, - MaxIdleConnsPerHost: 10, - TLSHandshakeTimeout: 3 * time.Second, - ExpectContinueTimeout: 1 * time.Second, +func NewDispatcher(l license.Licenser, ff *fflag.FFlag, options ...DispatcherOption) (*Dispatcher, error) { + d := &Dispatcher{ + l: l, + ff: ff, + client: &http.Client{}, + rules: &netjail.Rules{}, + transport: &http.Transport{ + MaxIdleConns: 100, + IdleConnTimeout: 10 * time.Second, + MaxIdleConnsPerHost: 10, + TLSHandshakeTimeout: 3 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, } - if licenser.UseForwardProxy() { - proxyUrl, isValid, err := d.setProxy(httpProxy) - if err != nil { + for _, option := range options { + if err := option(d); err != nil { return nil, err } + } - if isValid { - tr.Proxy = http.ProxyURL(proxyUrl) + if d.logger == nil { + return nil, ErrLoggerIsRequired + } + + if ff.CanAccessFeature(fflag.IpRules) && len(d.rules.Allow) == 0 && len(d.rules.Block) == 0 { + d.rules = &netjail.Rules{ + Allow: []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/8"), + netip.MustParsePrefix("::/0"), + }, + Block: []netip.Prefix{ + netip.MustParsePrefix("127.0.0.0/8"), + netip.MustParsePrefix("::1/128"), + }, } } - // if enforceSecure is false, allow self-signed certificates, susceptible to MITM attacks. - // if !enforceSecure { - // tr.TLSClientConfig = &tls.Config{ - // InsecureSkipVerify: true, - // } - // } else { - // tr.TLSClientConfig = &tls.Config{ - // MinVersion: tls.VersionTLS12, - // } - // } + netJailTransport := &netjail.Transport{ + New: func() *http.Transport { + return d.transport.Clone() + }, + } - d.client.Transport = tr + if ff.CanAccessFeature(fflag.IpRules) { + d.client.Transport = netJailTransport + } else { + d.client.Transport = d.transport + } return d, nil } -func (d *Dispatcher) setProxy(proxyURL string) (*url.URL, bool, error) { +// ProxyOption defines an HTTP proxy which the client will use. It fails-open the string isn't a valid HTTP URL +func ProxyOption(httpProxy string) DispatcherOption { + return func(d *Dispatcher) error { + if httpProxy == "" { + return nil + } + + if d.l.UseForwardProxy() { + proxyUrl, isValid, err := d.validateProxy(httpProxy) + if err != nil { + return err + } + + if isValid { + d.transport.Proxy = http.ProxyURL(proxyUrl) + } + } + + return nil + } +} + +// AllowListOption sets a list of IP prefixes which will outgoing traffic will be granted access +func AllowListOption(allowList []string) DispatcherOption { + return func(d *Dispatcher) error { + if len(allowList) == 0 { + return ErrAllowListIsRequired + } + + if !d.l.IpRules() || !d.ff.CanAccessFeature(fflag.IpRules) { + return nil + } + + netAllow := make([]netip.Prefix, len(allowList)) + for i, prefix := range allowList { + parsed, err := netip.ParsePrefix(prefix) + if err != nil { + return fmt.Errorf("%w: %v in allowlist", ErrInvalidIPPrefix, err) + } + netAllow[i] = parsed + d.rules.Allow = netAllow + } + + return nil + } +} + +// BlockListOption sets a list of IP prefixes which will outgoing traffic will be denied access +func BlockListOption(blockList []string) DispatcherOption { + return func(d *Dispatcher) error { + if len(blockList) == 0 { + return ErrBlockListIsRequired + } + + if !d.l.IpRules() || !d.ff.CanAccessFeature(fflag.IpRules) { + return nil + } + + netBlock := make([]netip.Prefix, len(blockList)) + for i, prefix := range blockList { + parsed, err := netip.ParsePrefix(prefix) + if err != nil { + return fmt.Errorf("%w: %v in blocklist", ErrInvalidIPPrefix, err) + } + netBlock[i] = parsed + } + + d.rules.Block = netBlock + return nil + } +} + +// InsecureSkipVerifyOption allow self-signed certificates if set to false which is susceptible to MITM attacks. +func InsecureSkipVerifyOption(insecureSkipVerify bool) DispatcherOption { + return func(d *Dispatcher) error { + if insecureSkipVerify { + d.transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } else { + d.transport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } + + return nil + } +} + +func LoggerOption(logger *log.Logger) DispatcherOption { + return func(d *Dispatcher) error { + if logger == nil { + return ErrLoggerIsRequired + } + + d.logger = logger + return nil + } +} + +func (d *Dispatcher) validateProxy(proxyURL string) (*url.URL, bool, error) { if !util.IsStringEmpty(proxyURL) { pUrl, err := url.Parse(proxyURL) if err != nil { @@ -79,20 +215,26 @@ func (d *Dispatcher) setProxy(proxyURL string) (*url.URL, bool, error) { } func (d *Dispatcher) SendRequest(ctx context.Context, endpoint, method string, jsonData json.RawMessage, signatureHeader string, hmac string, maxResponseSize int64, headers httpheader.HTTPHeader, idempotencyKey string, timeout time.Duration) (*Response, error) { + d.logger.Debugf("rules: %+v", d.rules) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() r := &Response{} if util.IsStringEmpty(signatureHeader) || util.IsStringEmpty(hmac) { err := errors.New("signature header and hmac are required") - log.WithError(err).Error("Dispatcher invalid arguments") + d.logger.WithError(err).Error("Dispatcher invalid arguments") r.Error = err.Error() return r, err } + if d.ff.CanAccessFeature(fflag.IpRules) { + ctx = netjail.ContextWithRules(ctx, d.rules) + } + req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewBuffer(jsonData)) if err != nil { - log.WithError(err).Error("error occurred while creating request") + d.logger.WithError(err).Error("error occurred while creating request") return r, err } @@ -135,7 +277,6 @@ func updateDispatchHeaders(r *Response, res *http.Response) { r.ResponseHeader = res.Header } -// TODO(subomi): Refactor this to support Enterprise Editions func defaultUserAgent() string { return "Convoy/" + convoy.GetVersion() } @@ -144,7 +285,7 @@ func (d *Dispatcher) do(req *http.Request, res *Response, maxResponseSize int64) trace := &httptrace.ClientTrace{ GotConn: func(connInfo httptrace.GotConnInfo) { res.IP = connInfo.Conn.RemoteAddr().String() - log.Debugf("IP address resolved to: %s", connInfo.Conn.RemoteAddr()) + d.logger.Debugf("IP address resolved to: %s", connInfo.Conn.RemoteAddr()) }, } @@ -152,7 +293,7 @@ func (d *Dispatcher) do(req *http.Request, res *Response, maxResponseSize int64) response, err := d.client.Do(req) if err != nil { - log.WithError(err).Error("error sending request to API endpoint") + d.logger.WithError(err).Error("error sending request to API endpoint") res.Error = err.Error() return err } @@ -170,7 +311,7 @@ func (d *Dispatcher) do(req *http.Request, res *Response, maxResponseSize int64) res.Body = buf if err != nil { - log.WithError(err).Error("couldn't parse response body") + d.logger.WithError(err).Error("couldn't parse response body") return err } diff --git a/net/dispatcher_test.go b/net/dispatcher_test.go index 1af4788e69..68f1accc1d 100644 --- a/net/dispatcher_test.go +++ b/net/dispatcher_test.go @@ -5,7 +5,12 @@ import ( "context" "crypto/rand" "encoding/json" + "github.com/frain-dev/convoy/internal/pkg/fflag" + "github.com/frain-dev/convoy/pkg/log" + "github.com/stealthrocket/netjail" "net/http" + "net/http/httptest" + "os" "testing" "time" @@ -276,7 +281,7 @@ func TestDispatcher_SendRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := &Dispatcher{client: client} + d := &Dispatcher{client: client, logger: log.NewLogger(os.Stdout), ff: fflag.NewFFlag([]string{})} if tt.nFn != nil { deferFn := tt.nFn() @@ -308,7 +313,7 @@ func TestNewDispatcher(t *testing.T) { tests := []struct { name string args args - mockFn func(licenser license.Licenser) + mockFn func(license.Licenser) wantProxy bool wantErr bool wantErrMsg string @@ -351,7 +356,14 @@ func TestNewDispatcher(t *testing.T) { if tt.mockFn != nil { tt.mockFn(licenser) } - d, err := NewDispatcher(tt.args.httpProxy, licenser, tt.args.enforceSecure) + + d, err := NewDispatcher( + licenser, + fflag.NewFFlag([]string{string(fflag.IpRules)}), + LoggerOption(log.NewLogger(os.Stdout)), + InsecureSkipVerifyOption(tt.args.enforceSecure), + ProxyOption(tt.args.httpProxy), + ) if tt.wantErr { require.Error(t, err) require.Equal(t, tt.wantErrMsg, err.Error()) @@ -361,8 +373,158 @@ func TestNewDispatcher(t *testing.T) { require.NoError(t, err) if tt.wantProxy { - require.NotNil(t, d.client.Transport.(*http.Transport).Proxy) + require.NotNil(t, d.client.Transport.(*netjail.Transport).New().Proxy) } }) } } + +// TestDispatcherSendRequest tests the basic functionality of SendRequest +func TestDispatcherSendRequest(t *testing.T) { + // Start a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "POST", r.Method) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + require.Equal(t, "test-hmac", r.Header.Get("X-Signature")) + require.Equal(t, "test-key", r.Header.Get("X-Convoy-Idempotency-Key")) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status": "success"}`)) + })) + defer server.Close() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + licenser := mocks.NewMockLicenser(ctrl) + licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(2).Return(true) + + // Create a new dispatcher + dispatcher, err := NewDispatcher( + licenser, + fflag.NewFFlag([]string{string(fflag.IpRules)}), + LoggerOption(log.NewLogger(os.Stdout)), + ProxyOption("nil"), + AllowListOption([]string{"0.0.0.0/0"}), + BlockListOption([]string{"10.0.0.0/8"}), + ) + require.NoError(t, err) + + // Prepare request data + jsonData := json.RawMessage(`{"key": "value"}`) + headers := httpheader.HTTPHeader{ + "X-Custom-Header": []string{"custom-value"}, + } + + // Send request + resp, err := dispatcher.SendRequest( + context.Background(), + server.URL, + "POST", + jsonData, + "X-Signature", + "test-hmac", + 1024, + headers, + "test-key", + 5*time.Second, + ) + + // Assert response + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, `{"status": "success"}`, string(resp.Body)) + require.Equal(t, "custom-value", resp.RequestHeader.Get("X-Custom-Header")) +} + +// TestDispatcherWithTimeout tests the timeout functionality +func TestDispatcherWithTimeout(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + licenser := mocks.NewMockLicenser(ctrl) + licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(2).Return(true) + + dispatcher, err := NewDispatcher( + licenser, + fflag.NewFFlag([]string{string(fflag.IpRules)}), + LoggerOption(log.NewLogger(os.Stdout)), + ProxyOption("nil"), + AllowListOption([]string{"0.0.0.0/0"}), + BlockListOption([]string{"10.0.0.0/8"}), + ) + require.NoError(t, err) + + // Send request with a short timeout + _, err = dispatcher.SendRequest( + context.Background(), + server.URL, + "GET", + nil, + "X-Signature", + "test-hmac", + 1024, + nil, + "", + 1*time.Second, + ) + + // Assert that we got a timeout error + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Contains(t, err.Error(), "context deadline exceeded") +} + +// TestDispatcherWithBlockedIP tests the IP blocking functionality +func TestDispatcherWithBlockedIP(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + licenser := mocks.NewMockLicenser(ctrl) + licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(2).Return(true) + + // Create a dispatcher with a block list that includes the test server's IP + dispatcher, err := NewDispatcher( + licenser, + fflag.NewFFlag([]string{string(fflag.IpRules)}), + LoggerOption(log.NewLogger(os.Stdout)), + ProxyOption("nil"), + AllowListOption([]string{"0.0.0.0/0"}), + BlockListOption([]string{"127.0.0.0/8"}), + ) + require.NoError(t, err) + + // Attempt to send a request + _, err = dispatcher.SendRequest( + context.Background(), + server.URL, + "GET", + nil, + "X-Signature", + "test-hmac", + 1024, + nil, + "", + 5*time.Second, + ) + + // Assert that the request was blocked + require.Error(t, err) + require.ErrorIs(t, err, netjail.ErrDenied) + require.Contains(t, err.Error(), "127.0.0.1: address not allowed") +} diff --git a/testcon/docker_e2e_integration_test.go b/testcon/docker_e2e_integration_test.go index fee94a384a..f3a5f97db4 100644 --- a/testcon/docker_e2e_integration_test.go +++ b/testcon/docker_e2e_integration_test.go @@ -45,7 +45,10 @@ func (d *DockerE2EIntegrationTestSuite) SetupSuite() { WaitForService("migrate", wait.NewLogStrategy("migration up succeeded").WithStartupTimeout(60*time.Second)). Up(ctx, tc.Wait(true), tc.WithRecreate(api.RecreateNever)) - if err != nil && !strings.Contains(err.Error(), "Ryuk") && !strings.Contains(err.Error(), "container exited with code 0") { + if err != nil && + !strings.Contains(err.Error(), "Ryuk") && + !strings.Contains(err.Error(), "container exited with code 0") && + !strings.Contains(err.Error(), "exited (0)") { require.NoError(t, err) } diff --git a/testcon/testdata/convoy-docker.json b/testcon/testdata/convoy-docker.json index 1ec985fe62..e0114e518f 100644 --- a/testcon/testdata/convoy-docker.json +++ b/testcon/testdata/convoy-docker.json @@ -18,7 +18,6 @@ } }, "instance_ingest_rate": 50, - "api_rate_limit_enabled": false, "auth": { "jwt": { "enabled": true diff --git a/testcon/testdata/convoy-host.json b/testcon/testdata/convoy-host.json index 20cb484e30..bdf7d7d58d 100644 --- a/testcon/testdata/convoy-host.json +++ b/testcon/testdata/convoy-host.json @@ -18,7 +18,6 @@ } }, "instance_ingest_rate": 50, - "api_rate_limit_enabled": false, "auth": { "jwt": { "enabled": true diff --git a/worker/task/process_event_delivery_test.go b/worker/task/process_event_delivery_test.go index 549b3742d7..b9ad4ee48e 100644 --- a/worker/task/process_event_delivery_test.go +++ b/worker/task/process_event_delivery_test.go @@ -929,6 +929,7 @@ func TestProcessEventDelivery(t *testing.T) { licenser := mocks.NewMockLicenser(ctrl) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) err := config.LoadConfig(tc.cfgPath) if err != nil { @@ -954,7 +955,15 @@ func TestProcessEventDelivery(t *testing.T) { tc.dbFn(endpointRepo, projectRepo, msgRepo, q, rateLimiter, attemptsRepo, licenser) } - dispatcher, err := net.NewDispatcher("", licenser, false) + featureFlag := fflag.NewFFlag(cfg.EnableFeatureFlag) + + dispatcher, err := net.NewDispatcher( + licenser, + fflag.NewFFlag([]string{string(fflag.IpRules)}), + net.LoggerOption(log.NewLogger(os.Stdout)), + net.BlockListOption([]string{"10.0.0.0/8"}), + net.ProxyOption("nil"), + ) require.NoError(t, err) mockStore := cb.NewTestStore() @@ -977,7 +986,6 @@ func TestProcessEventDelivery(t *testing.T) { ) require.NoError(t, err) - featureFlag := fflag.NewFFlag(&cfg) processFn := ProcessEventDelivery(endpointRepo, msgRepo, licenser, projectRepo, q, rateLimiter, dispatcher, attemptsRepo, manager, featureFlag) payload := EventDelivery{ diff --git a/worker/task/process_meta_event_test.go b/worker/task/process_meta_event_test.go index c35c0111ec..754fabec9a 100644 --- a/worker/task/process_meta_event_test.go +++ b/worker/task/process_meta_event_test.go @@ -3,6 +3,9 @@ package task import ( "context" "encoding/json" + "github.com/frain-dev/convoy/internal/pkg/fflag" + "github.com/frain-dev/convoy/pkg/log" + "os" "testing" "time" @@ -136,7 +139,12 @@ func TestProcessMetaEvent(t *testing.T) { licenser := mocks.NewMockLicenser(ctrl) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) - dispatcher, err := net.NewDispatcher("", licenser, false) + dispatcher, err := net.NewDispatcher( + licenser, + fflag.NewFFlag([]string{string(fflag.IpRules)}), + net.LoggerOption(log.NewLogger(os.Stdout)), + net.ProxyOption("nil"), + ) require.NoError(t, err) err = config.LoadConfig(tc.cfgPath) diff --git a/worker/task/process_retry_event_delivery_test.go b/worker/task/process_retry_event_delivery_test.go index d187711a57..14ced6e5a4 100644 --- a/worker/task/process_retry_event_delivery_test.go +++ b/worker/task/process_retry_event_delivery_test.go @@ -70,6 +70,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser, _ := l.(*mocks.MockLicenser) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, }, { @@ -109,6 +110,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser, _ := l.(*mocks.MockLicenser) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, }, { @@ -174,6 +176,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser, _ := l.(*mocks.MockLicenser) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) m.EXPECT(). UpdateStatusOfEventDelivery(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). @@ -272,6 +275,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -362,6 +366,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -452,6 +457,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -544,6 +550,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -636,6 +643,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -728,6 +736,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(false) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -817,6 +826,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -911,6 +921,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -1006,6 +1017,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { licenser.EXPECT().CircuitBreaking().Times(1).Return(false) licenser.EXPECT().AdvancedEndpointMgmt().Times(1).Return(true) licenser.EXPECT().UseForwardProxy().Times(1).Return(true) + licenser.EXPECT().IpRules().Times(1).Return(true) }, nFn: func() func() { httpmock.Activate() @@ -1061,7 +1073,13 @@ func TestProcessRetryEventDelivery(t *testing.T) { tc.dbFn(endpointRepo, projectRepo, msgRepo, q, rateLimiter, attemptsRepo, licenser) } - dispatcher, err := net.NewDispatcher("", licenser, false) + dispatcher, err := net.NewDispatcher( + licenser, + fflag.NewFFlag([]string{string(fflag.IpRules)}), + net.BlockListOption([]string{"10.0.0.0/8"}), + net.LoggerOption(log.NewLogger(os.Stdout)), + net.ProxyOption("nil"), + ) require.NoError(t, err) mockStore := cb.NewTestStore() @@ -1084,7 +1102,7 @@ func TestProcessRetryEventDelivery(t *testing.T) { ) require.NoError(t, err) - featureFlag := fflag.NewFFlag(&cfg) + featureFlag := fflag.NewFFlag(cfg.EnableFeatureFlag) processFn := ProcessRetryEventDelivery(endpointRepo, msgRepo, licenser, projectRepo, q, rateLimiter, dispatcher, attemptsRepo, manager, featureFlag) diff --git a/worker/task/search_tokenizer.go b/worker/task/search_tokenizer.go index 7434ed5f0d..4e870e8c68 100644 --- a/worker/task/search_tokenizer.go +++ b/worker/task/search_tokenizer.go @@ -87,7 +87,7 @@ func tokenize(ctx context.Context, eventRepo datastore.EventRepository, jobRepo return err } - fflag := fflag2.NewFFlag(&cfg) + fflag := fflag2.NewFFlag(cfg.EnableFeatureFlag) if !fflag.CanAccessFeature(fflag2.FullTextSearch) { return fflag2.ErrFullTextSearchNotEnabled