From 1763fd14aec3abda9c5afaeddc15ff59dbcff5d3 Mon Sep 17 00:00:00 2001 From: Sorin Stanculeanu Date: Mon, 30 Sep 2024 15:50:45 +0300 Subject: [PATCH 1/2] get number of shards from observers --- cmd/proxy/config/config.toml | 15 +- cmd/proxy/main.go | 30 +++- config/config.go | 3 +- observer/baseNodeProvider.go | 10 -- observer/baseNodeProvider_test.go | 12 -- observer/circularQueueNodesProvider_test.go | 3 - observer/nodesProviderFactory.go | 12 +- observer/nodesProviderFactory_test.go | 6 +- process/errors.go | 9 +- process/interface.go | 7 + process/mock/httpClientMock.go | 16 ++ process/numShardsProcessor.go | 147 ++++++++++++++++ process/numShardsProcessor_test.go | 176 ++++++++++++++++++++ 13 files changed, 397 insertions(+), 49 deletions(-) create mode 100644 process/mock/httpClientMock.go create mode 100644 process/numShardsProcessor.go create mode 100644 process/numShardsProcessor_test.go diff --git a/cmd/proxy/config/config.toml b/cmd/proxy/config/config.toml index 3072af75..03ef65fc 100644 --- a/cmd/proxy/config/config.toml +++ b/cmd/proxy/config/config.toml @@ -39,15 +39,18 @@ # With this flag disabled, /transaction/pool route will return an error AllowEntireTxPoolFetch = false - # NumberOfShards represents the total number of shards from the network (excluding metachain) - NumberOfShards = 3 + # NumShardsTimeoutInSec represents the maximum number of seconds to wait for at least one observer online until throwing an error + NumShardsTimeoutInSec = 90 + + # TimeBetweenNodesRequestsInSec represents time to wait before retry to get the number of shards from observers + TimeBetweenNodesRequestsInSec = 2 [AddressPubkeyConverter] - #Length specifies the length in bytes of an address - Length = 32 + #Length specifies the length in bytes of an address + Length = 32 - # Type specifies the type of public keys: hex or bech32 - Type = "bech32" + # Type specifies the type of public keys: hex or bech32 + Type = "bech32" [Marshalizer] Type = "gogo protobuf" diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 154f7e5c..21b66cb7 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -328,7 +328,6 @@ func createVersionsRegistryTestOrProduction( ValStatsCacheValidityDurationSec: 60, EconomicsMetricsCacheValidityDurationSec: 6, FaucetValue: "10000000000", - NumberOfShards: 3, }, ApiLogging: config.ApiLoggingConfig{ LoggingEnabled: true, @@ -409,12 +408,32 @@ func createVersionsRegistry( return nil, err } - shardCoord, err := sharding.NewMultiShardCoordinator(cfg.GeneralSettings.NumberOfShards, 0) + httpClient := &http.Client{} + httpClient.Timeout = time.Duration(cfg.GeneralSettings.RequestTimeoutSec) * time.Second + observersList := make([]string, 0, len(cfg.Observers)) + for _, node := range cfg.Observers { + observersList = append(observersList, node.Address) + } + argsNumShardsProcessor := process.ArgNumShardsProcessor{ + HttpClient: httpClient, + Observers: observersList, + TimeBetweenNodesRequestsInSec: cfg.GeneralSettings.TimeBetweenNodesRequestsInSec, + NumShardsTimeoutInSec: cfg.GeneralSettings.NumShardsTimeoutInSec, + RequestTimeoutInSec: cfg.GeneralSettings.RequestTimeoutSec, + } + numShardsProcessor, err := process.NewNumShardsProcessor(argsNumShardsProcessor) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + numShards, err := numShardsProcessor.GetNetworkNumShards(ctx) + cancel() if err != nil { return nil, err } - nodesProviderFactory, err := observer.NewNodesProviderFactory(*cfg, configurationFilePath) + nodesProviderFactory, err := observer.NewNodesProviderFactory(*cfg, configurationFilePath, numShards) if err != nil { return nil, err } @@ -431,6 +450,11 @@ func createVersionsRegistry( } } + shardCoord, err := sharding.NewMultiShardCoordinator(numShards, 0) + if err != nil { + return nil, err + } + bp, err := process.NewBaseProcessor( cfg.GeneralSettings.RequestTimeoutSec, shardCoord, diff --git a/config/config.go b/config/config.go index 2357dc72..90317616 100644 --- a/config/config.go +++ b/config/config.go @@ -16,7 +16,8 @@ type GeneralSettingsConfig struct { BalancedObservers bool BalancedFullHistoryNodes bool AllowEntireTxPoolFetch bool - NumberOfShards uint32 + NumShardsTimeoutInSec int + TimeBetweenNodesRequestsInSec int } // Config will hold the whole config file's data diff --git a/observer/baseNodeProvider.go b/observer/baseNodeProvider.go index 82ebe82c..49b964ec 100644 --- a/observer/baseNodeProvider.go +++ b/observer/baseNodeProvider.go @@ -139,16 +139,6 @@ func (bnp *baseNodeProvider) ReloadNodes(nodesType data.NodeType) data.NodesRelo } } - numOldShards := bnp.numOfShards - numNewShards := newConfig.GeneralSettings.NumberOfShards - if numOldShards != numNewShards { - return data.NodesReloadResponse{ - OkRequest: false, - Description: "not reloaded", - Error: fmt.Sprintf("different number of shards. before: %d, now: %d", numOldShards, numNewShards), - } - } - nodes := newConfig.Observers if nodesType == data.FullHistoryNode { nodes = newConfig.FullHistoryNodes diff --git a/observer/baseNodeProvider_test.go b/observer/baseNodeProvider_test.go index 34db5660..f5c09c92 100644 --- a/observer/baseNodeProvider_test.go +++ b/observer/baseNodeProvider_test.go @@ -73,18 +73,6 @@ func TestBaseNodeProvider_InvalidShardForObserver(t *testing.T) { require.True(t, strings.Contains(err.Error(), "addr1")) } -func TestBaseNodeProvider_ReloadNodesDifferentNumberOfNewShard(t *testing.T) { - bnp := &baseNodeProvider{ - configurationFilePath: configurationPath, - shardIds: []uint32{0, 1}, - numOfShards: 2, - } - - response := bnp.ReloadNodes(data.Observer) - require.False(t, response.OkRequest) - require.Contains(t, response.Error, "different number of shards") -} - func TestBaseNodeProvider_ReloadNodesConfigurationFileNotFound(t *testing.T) { bnp := &baseNodeProvider{ configurationFilePath: "wrong config path", diff --git a/observer/circularQueueNodesProvider_test.go b/observer/circularQueueNodesProvider_test.go index a24aadc8..f94f2bed 100644 --- a/observer/circularQueueNodesProvider_test.go +++ b/observer/circularQueueNodesProvider_test.go @@ -23,9 +23,6 @@ func getDummyConfig() config.Config { ShardId: 1, }, }, - GeneralSettings: config.GeneralSettingsConfig{ - NumberOfShards: 2, - }, } } diff --git a/observer/nodesProviderFactory.go b/observer/nodesProviderFactory.go index 83153d7c..ea18322e 100644 --- a/observer/nodesProviderFactory.go +++ b/observer/nodesProviderFactory.go @@ -11,13 +11,15 @@ var log = logger.GetOrCreate("observer") type nodesProviderFactory struct { cfg config.Config configurationFilePath string + numberOfShards uint32 } // NewNodesProviderFactory returns a new instance of nodesProviderFactory -func NewNodesProviderFactory(cfg config.Config, configurationFilePath string) (*nodesProviderFactory, error) { +func NewNodesProviderFactory(cfg config.Config, configurationFilePath string, numberOfShards uint32) (*nodesProviderFactory, error) { return &nodesProviderFactory{ cfg: cfg, configurationFilePath: configurationFilePath, + numberOfShards: numberOfShards, }, nil } @@ -27,13 +29,13 @@ func (npf *nodesProviderFactory) CreateObservers() (NodesProviderHandler, error) return NewCircularQueueNodesProvider( npf.cfg.Observers, npf.configurationFilePath, - npf.cfg.GeneralSettings.NumberOfShards) + npf.numberOfShards) } return NewSimpleNodesProvider( npf.cfg.Observers, npf.configurationFilePath, - npf.cfg.GeneralSettings.NumberOfShards) + npf.numberOfShards) } // CreateFullHistoryNodes will create and return an object of type NodesProviderHandler based on a flag @@ -42,7 +44,7 @@ func (npf *nodesProviderFactory) CreateFullHistoryNodes() (NodesProviderHandler, nodesProviderHandler, err := NewCircularQueueNodesProvider( npf.cfg.FullHistoryNodes, npf.configurationFilePath, - npf.cfg.GeneralSettings.NumberOfShards) + npf.numberOfShards) if err != nil { return getDisabledFullHistoryNodesProviderIfNeeded(err) } @@ -53,7 +55,7 @@ func (npf *nodesProviderFactory) CreateFullHistoryNodes() (NodesProviderHandler, nodesProviderHandler, err := NewSimpleNodesProvider( npf.cfg.FullHistoryNodes, npf.configurationFilePath, - npf.cfg.GeneralSettings.NumberOfShards) + npf.numberOfShards) if err != nil { return getDisabledFullHistoryNodesProviderIfNeeded(err) } diff --git a/observer/nodesProviderFactory_test.go b/observer/nodesProviderFactory_test.go index ee095e42..9e9a26c8 100644 --- a/observer/nodesProviderFactory_test.go +++ b/observer/nodesProviderFactory_test.go @@ -10,7 +10,7 @@ import ( func TestNewObserversProviderFactory_ShouldWork(t *testing.T) { t.Parallel() - opf, err := NewNodesProviderFactory(config.Config{}, "path") + opf, err := NewNodesProviderFactory(config.Config{}, "path", 2) assert.Nil(t, err) assert.NotNil(t, opf) } @@ -21,7 +21,7 @@ func TestObserversProviderFactory_CreateShouldReturnSimple(t *testing.T) { cfg := getDummyConfig() cfg.GeneralSettings.BalancedObservers = false - opf, _ := NewNodesProviderFactory(cfg, "path") + opf, _ := NewNodesProviderFactory(cfg, "path", 2) op, err := opf.CreateObservers() assert.Nil(t, err) _, ok := op.(*simpleNodesProvider) @@ -34,7 +34,7 @@ func TestObserversProviderFactory_CreateShouldReturnCircularQueue(t *testing.T) cfg := getDummyConfig() cfg.GeneralSettings.BalancedObservers = true - opf, _ := NewNodesProviderFactory(cfg, "path") + opf, _ := NewNodesProviderFactory(cfg, "path", 2) op, err := opf.CreateObservers() assert.Nil(t, err) _, ok := op.(*circularQueueNodesProvider) diff --git a/process/errors.go b/process/errors.go index 134e7bf9..8060b7b4 100644 --- a/process/errors.go +++ b/process/errors.go @@ -56,18 +56,12 @@ var ErrNoFaucetAccountForGivenShard = errors.New("no faucet account found for th // ErrNilNodesProvider signals that a nil observers provider has been provided var ErrNilNodesProvider = errors.New("nil nodes provider") -// ErrInvalidShardId signals that a invalid shard id has been provided -var ErrInvalidShardId = errors.New("invalid shard id") - // ErrNilPubKeyConverter signals that a nil pub key converter has been provided var ErrNilPubKeyConverter = errors.New("nil pub key converter provided") // ErrNoValidTransactionToSend signals that no valid transaction were received var ErrNoValidTransactionToSend = errors.New("no valid transaction to send") -// ErrNilDatabaseConnector signals that a nil database connector was provided -var ErrNilDatabaseConnector = errors.New("not valid database connector") - // ErrCannotParseNodeStatusMetrics signals that the node status metrics cannot be parsed var ErrCannotParseNodeStatusMetrics = errors.New("cannot parse node status metrics") @@ -115,3 +109,6 @@ var ErrEmptyCommitString = errors.New("empty commit id string") // ErrEmptyPubKey signals that an empty public key has been provided var ErrEmptyPubKey = errors.New("public key is empty") + +// ErrNilHttpClient signals that a nil http client has been provided +var ErrNilHttpClient = errors.New("nil http client") diff --git a/process/interface.go b/process/interface.go index eb7e558c..f424bdfc 100644 --- a/process/interface.go +++ b/process/interface.go @@ -1,6 +1,8 @@ package process import ( + "net/http" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/data/vm" @@ -78,3 +80,8 @@ type StatusMetricsProvider interface { GetMetricsForPrometheus() string IsInterfaceNil() bool } + +// HttpClient defines an interface for the http client +type HttpClient interface { + Do(req *http.Request) (*http.Response, error) +} diff --git a/process/mock/httpClientMock.go b/process/mock/httpClientMock.go new file mode 100644 index 00000000..9e09fb2a --- /dev/null +++ b/process/mock/httpClientMock.go @@ -0,0 +1,16 @@ +package mock + +import "net/http" + +// HttpClientMock - +type HttpClientMock struct { + DoCalled func(req *http.Request) (*http.Response, error) +} + +// Do - +func (mock *HttpClientMock) Do(req *http.Request) (*http.Response, error) { + if mock.DoCalled != nil { + return mock.DoCalled(req) + } + return &http.Response{}, nil +} diff --git a/process/numShardsProcessor.go b/process/numShardsProcessor.go new file mode 100644 index 00000000..8a9d907c --- /dev/null +++ b/process/numShardsProcessor.go @@ -0,0 +1,147 @@ +package process + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" +) + +var errTimeIsOut = errors.New("time is out") + +const ( + networkConfigPath = "/network/config" +) + +type networkConfigResponseData struct { + Config struct { + NumShards uint32 `json:"erd_num_shards_without_meta"` + } `json:"config"` +} + +type networkConfigResponse struct { + Data networkConfigResponseData `json:"data"` + Error string `json:"error"` + Code string `json:"code"` +} + +// ArgNumShardsProcessor is the DTO used to create a new instance of numShardsProcessor +type ArgNumShardsProcessor struct { + HttpClient HttpClient + Observers []string + TimeBetweenNodesRequestsInSec int + NumShardsTimeoutInSec int + RequestTimeoutInSec int +} + +type numShardsProcessor struct { + observers []string + httpClient HttpClient + timeBetweenNodesRequests time.Duration + numShardsTimeout time.Duration + requestTimeout time.Duration +} + +// NewNumShardsProcessor returns a new instance of numShardsProcessor +func NewNumShardsProcessor(args ArgNumShardsProcessor) (*numShardsProcessor, error) { + err := checkArgs(args) + if err != nil { + return nil, err + } + + return &numShardsProcessor{ + observers: args.Observers, + httpClient: args.HttpClient, + timeBetweenNodesRequests: time.Second * time.Duration(args.TimeBetweenNodesRequestsInSec), + numShardsTimeout: time.Second * time.Duration(args.NumShardsTimeoutInSec), + requestTimeout: time.Second * time.Duration(args.RequestTimeoutInSec), + }, nil +} + +func checkArgs(args ArgNumShardsProcessor) error { + if check.IfNilReflect(args.HttpClient) { + return ErrNilHttpClient + } + if len(args.Observers) == 0 { + return fmt.Errorf("%w for Observers, empty list provided", core.ErrInvalidValue) + } + if args.TimeBetweenNodesRequestsInSec == 0 { + return fmt.Errorf("%w for TimeBetweenNodesRequestsInSec, %d provided", core.ErrInvalidValue, args.TimeBetweenNodesRequestsInSec) + } + if args.NumShardsTimeoutInSec == 0 { + return fmt.Errorf("%w for NumShardsTimeoutInSec, %d provided", core.ErrInvalidValue, args.NumShardsTimeoutInSec) + } + if args.RequestTimeoutInSec == 0 { + return fmt.Errorf("%w for RequestTimeoutInSec, %d provided", core.ErrInvalidValue, args.RequestTimeoutInSec) + } + + return nil +} + +// GetNetworkNumShards tries to get the number of shards from the network +func (processor *numShardsProcessor) GetNetworkNumShards(ctx context.Context) (uint32, error) { + log.Info("getting the number of shards from observers...") + + waitNodeTicker := time.NewTicker(processor.timeBetweenNodesRequests) + for { + select { + case <-waitNodeTicker.C: + for _, observerAddress := range processor.observers { + numShards, httpStatus := processor.tryGetnumShardsFromObserver(observerAddress) + if httpStatus == http.StatusOK { + log.Info("fetched the number of shards", "shards", numShards) + return numShards, nil + } + } + case <-time.After(processor.numShardsTimeout): + return 0, fmt.Errorf("%w, no observer online", errTimeIsOut) + case <-ctx.Done(): + log.Debug("closing the getNetworkNumShards loop due to context done...") + return 0, errTimeIsOut + } + } +} + +func (processor *numShardsProcessor) tryGetnumShardsFromObserver(observerAddress string) (uint32, int) { + ctx, cancel := context.WithTimeout(context.Background(), processor.requestTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, observerAddress+networkConfigPath, nil) + if err != nil { + return 0, http.StatusNotFound + } + + resp, err := processor.httpClient.Do(req) + if err != nil { + return 0, http.StatusNotFound + } + + defer func() { + if resp != nil && resp.Body != nil { + log.LogIfError(resp.Body.Close()) + } + }() + + if resp.StatusCode != http.StatusOK { + return 0, resp.StatusCode + } + + responseBodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return 0, http.StatusInternalServerError + } + + var response networkConfigResponse + err = json.Unmarshal(responseBodyBytes, &response) + if err != nil { + return 0, http.StatusInternalServerError + } + + return response.Data.Config.NumShards, resp.StatusCode +} diff --git a/process/numShardsProcessor_test.go b/process/numShardsProcessor_test.go new file mode 100644 index 00000000..2111c784 --- /dev/null +++ b/process/numShardsProcessor_test.go @@ -0,0 +1,176 @@ +package process + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-proxy-go/process/mock" + "github.com/stretchr/testify/require" +) + +func createMockArgNumShardsProcessor() ArgNumShardsProcessor { + return ArgNumShardsProcessor{ + HttpClient: &mock.HttpClientMock{}, + Observers: []string{"obs1, obs2"}, + TimeBetweenNodesRequestsInSec: 2, + NumShardsTimeoutInSec: 10, + RequestTimeoutInSec: 5, + } +} + +func TestNewNumShardsProcessor(t *testing.T) { + t.Parallel() + + t.Run("nil HttpClient should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgNumShardsProcessor() + args.HttpClient = nil + + proc, err := NewNumShardsProcessor(args) + require.Equal(t, ErrNilHttpClient, err) + require.Nil(t, proc) + }) + t.Run("empty observers list should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgNumShardsProcessor() + args.Observers = []string{} + + proc, err := NewNumShardsProcessor(args) + require.True(t, errors.Is(err, core.ErrInvalidValue)) + require.True(t, strings.Contains(err.Error(), "Observers")) + require.Nil(t, proc) + }) + t.Run("invalid TimeBetweenNodesRequestsInSec should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgNumShardsProcessor() + args.TimeBetweenNodesRequestsInSec = 0 + + proc, err := NewNumShardsProcessor(args) + require.True(t, errors.Is(err, core.ErrInvalidValue)) + require.True(t, strings.Contains(err.Error(), "TimeBetweenNodesRequestsInSec")) + require.Nil(t, proc) + }) + t.Run("invalid NumShardsTimeoutInSec should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgNumShardsProcessor() + args.NumShardsTimeoutInSec = 0 + + proc, err := NewNumShardsProcessor(args) + require.True(t, errors.Is(err, core.ErrInvalidValue)) + require.True(t, strings.Contains(err.Error(), "NumShardsTimeoutInSec")) + require.Nil(t, proc) + }) + t.Run("invalid RequestTimeoutInSec should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgNumShardsProcessor() + args.RequestTimeoutInSec = 0 + + proc, err := NewNumShardsProcessor(args) + require.True(t, errors.Is(err, core.ErrInvalidValue)) + require.True(t, strings.Contains(err.Error(), "RequestTimeoutInSec")) + require.Nil(t, proc) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + proc, err := NewNumShardsProcessor(createMockArgNumShardsProcessor()) + require.NoError(t, err) + require.NotNil(t, proc) + }) +} + +func TestNumShardsProcessor_GetNetworkNumShards(t *testing.T) { + t.Parallel() + + t.Run("context done should exit with timeout", func(t *testing.T) { + t.Parallel() + + args := createMockArgNumShardsProcessor() + args.TimeBetweenNodesRequestsInSec = 30 + args.NumShardsTimeoutInSec = 30 + + proc, err := NewNumShardsProcessor(args) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Millisecond * 200) + cancel() + }() + numShards, err := proc.GetNetworkNumShards(ctx) + require.Equal(t, errTimeIsOut, err) + require.Zero(t, numShards) + }) + t.Run("timeout should exit with timeout", func(t *testing.T) { + t.Parallel() + + args := createMockArgNumShardsProcessor() + args.TimeBetweenNodesRequestsInSec = 30 + args.NumShardsTimeoutInSec = 1 + + proc, err := NewNumShardsProcessor(args) + require.NoError(t, err) + numShards, err := proc.GetNetworkNumShards(context.Background()) + require.True(t, errors.Is(err, errTimeIsOut)) + require.Zero(t, numShards) + }) + t.Run("should work on 4th observer", func(t *testing.T) { + t.Parallel() + + providedBody := &networkConfigResponse{ + Data: networkConfigResponseData{ + Config: struct { + NumShards uint32 `json:"erd_num_shards_without_meta"` + }(struct{ NumShards uint32 }{NumShards: 2}), + }, + } + providedBodyBuff, _ := json.Marshal(providedBody) + + args := createMockArgNumShardsProcessor() + args.TimeBetweenNodesRequestsInSec = 1 + args.NumShardsTimeoutInSec = 15 + cnt := 0 + args.HttpClient = &mock.HttpClientMock{ + DoCalled: func(req *http.Request) (*http.Response, error) { + cnt++ + switch cnt { + case 1: // error on Do + return nil, errors.New("observer offline") + case 2: // status code not 200 + return &http.Response{ + StatusCode: http.StatusBadRequest, + }, nil + case 3: // status code ok, but invalid response + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("not the expected response")), + }, nil + default: // response ok + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(providedBodyBuff)), + }, nil + } + }, + } + + proc, err := NewNumShardsProcessor(args) + require.NoError(t, err) + numShards, err := proc.GetNetworkNumShards(context.Background()) + require.NoError(t, err) + require.Equal(t, uint32(2), numShards) + }) +} From 2e35868d64c61a233eea9d2d7a36cefc0267b135 Mon Sep 17 00:00:00 2001 From: Sorin Stanculeanu Date: Tue, 1 Oct 2024 11:56:36 +0300 Subject: [PATCH 2/2] fixes after review --- cmd/proxy/main.go | 48 ++++++++++++++++++++--------------- process/numShardsProcessor.go | 6 +---- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 21b66cb7..05295c1f 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -408,27 +408,7 @@ func createVersionsRegistry( return nil, err } - httpClient := &http.Client{} - httpClient.Timeout = time.Duration(cfg.GeneralSettings.RequestTimeoutSec) * time.Second - observersList := make([]string, 0, len(cfg.Observers)) - for _, node := range cfg.Observers { - observersList = append(observersList, node.Address) - } - argsNumShardsProcessor := process.ArgNumShardsProcessor{ - HttpClient: httpClient, - Observers: observersList, - TimeBetweenNodesRequestsInSec: cfg.GeneralSettings.TimeBetweenNodesRequestsInSec, - NumShardsTimeoutInSec: cfg.GeneralSettings.NumShardsTimeoutInSec, - RequestTimeoutInSec: cfg.GeneralSettings.RequestTimeoutSec, - } - numShardsProcessor, err := process.NewNumShardsProcessor(argsNumShardsProcessor) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithCancel(context.Background()) - numShards, err := numShardsProcessor.GetNetworkNumShards(ctx) - cancel() + numShards, err := getNumOfShards(cfg) if err != nil { return nil, err } @@ -636,6 +616,32 @@ func waitForServerShutdown(httpServer *http.Server, closableComponents *data.Clo _ = httpServer.Close() } +// getNumOfShards will delay the start of proxy until it successfully gets the number of shards +func getNumOfShards(cfg *config.Config) (uint32, error) { + httpClient := &http.Client{} + httpClient.Timeout = time.Duration(cfg.GeneralSettings.RequestTimeoutSec) * time.Second + observersList := make([]string, 0, len(cfg.Observers)) + for _, node := range cfg.Observers { + observersList = append(observersList, node.Address) + } + argsNumShardsProcessor := process.ArgNumShardsProcessor{ + HttpClient: httpClient, + Observers: observersList, + TimeBetweenNodesRequestsInSec: cfg.GeneralSettings.TimeBetweenNodesRequestsInSec, + NumShardsTimeoutInSec: cfg.GeneralSettings.NumShardsTimeoutInSec, + RequestTimeoutInSec: cfg.GeneralSettings.RequestTimeoutSec, + } + numShardsProcessor, err := process.NewNumShardsProcessor(argsNumShardsProcessor) + if err != nil { + return 0, err + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + return numShardsProcessor.GetNetworkNumShards(ctx) +} + func removeLogColors() { err := logger.RemoveLogObserver(os.Stdout) if err != nil { diff --git a/process/numShardsProcessor.go b/process/numShardsProcessor.go index 8a9d907c..3062e361 100644 --- a/process/numShardsProcessor.go +++ b/process/numShardsProcessor.go @@ -15,10 +15,6 @@ import ( var errTimeIsOut = errors.New("time is out") -const ( - networkConfigPath = "/network/config" -) - type networkConfigResponseData struct { Config struct { NumShards uint32 `json:"erd_num_shards_without_meta"` @@ -112,7 +108,7 @@ func (processor *numShardsProcessor) tryGetnumShardsFromObserver(observerAddress ctx, cancel := context.WithTimeout(context.Background(), processor.requestTimeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, observerAddress+networkConfigPath, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, observerAddress+NetworkConfigPath, nil) if err != nil { return 0, http.StatusNotFound }