From 6933a3c1106431e3eb199a663e959ac807e674e8 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Sat, 17 Aug 2024 16:41:57 -0300 Subject: [PATCH] feat: rest proxy (#13) - move the seed updating logic to its own thing - add rest nodes to the seed updater - reused the previous proxy implementation for both rpc and rest calls - fixed status page to show both kinds of nodes - fixed a possible nil access when response is nil (introduced in #12) closes #9 --- .gitignore | 1 + index.html | 44 +++++++------- internal/proxy/proxy.go | 78 +++++++++++++++---------- internal/proxy/proxy_test.go | 107 +++++++++++++++++----------------- internal/proxy/server.go | 22 +++---- internal/seed/seed.go | 9 +-- internal/seed/updater.go | 56 ++++++++++++++++++ internal/seed/updater_test.go | 78 +++++++++++++++++++++++++ main.go | 30 +++++++--- 9 files changed, 297 insertions(+), 128 deletions(-) create mode 100644 internal/seed/updater.go create mode 100644 internal/seed/updater_test.go diff --git a/.gitignore b/.gitignore index b0b9f5b..60b6a4f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.pem dist/ +akash diff --git a/index.html b/index.html index fa68dce..f669cc2 100644 --- a/index.html +++ b/index.html @@ -10,33 +10,35 @@

Akash Proxy

- - - + + + + - {{ range .}} - - - - - - - - + {{ range $key, $value := . }} + {{ range $value }} + + + + + + + + + {{ end }} {{ end }}
NameURLAverage response timeServer Request CountAvg response time Error Rate StatusKind
{{.Name}}{{.URL}}{{.Avg}}{{.Requests}}{{.ErrorRate}}% - - {{ if not .Initialized}} - initializing - {{ else if .Degraded }} - degraded - {{else}} - OK - {{end}} -
{{ .Name }}{{ .Requests }}{{ .Avg }}{{ .ErrorRate }}% + {{ if not .Initialized }} + initializing + {{ else if .Degraded }} + degraded + {{ else }} + OK + {{ end }} + {{ $key }}
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 4199f19..8bfd068 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -7,23 +7,38 @@ import ( "net/http" "slices" "sort" + "strings" "sync" "sync/atomic" - "time" "github.com/akash-network/rpc-proxy/internal/config" "github.com/akash-network/rpc-proxy/internal/seed" ) -func New(cfg config.Config) *Proxy { +type ProxyKind uint8 + +const ( + RPC ProxyKind = iota + Rest ProxyKind = iota +) + +func New( + kind ProxyKind, + ch chan seed.Seed, + cfg config.Config, +) *Proxy { return &Proxy{ - cfg: cfg, + cfg: cfg, + ch: ch, + kind: kind, } } type Proxy struct { cfg config.Config + kind ProxyKind init sync.Once + ch chan seed.Seed round int mu sync.Mutex @@ -61,10 +76,16 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + switch p.kind { + case RPC: + r.URL.Path = strings.TrimPrefix(r.URL.Path, "/rpc") + case Rest: + r.URL.Path = strings.TrimPrefix(r.URL.Path, "/rest") + } + if srv := p.next(); srv != nil { srv.ServeHTTP(w, r) return - } slog.Error("no servers available") w.WriteHeader(http.StatusInternalServerError) @@ -90,17 +111,30 @@ func (p *Proxy) next() *Server { return p.next() } -func (p *Proxy) update(rpcs []seed.RPC) error { +func (p *Proxy) update(seed seed.Seed) { + var err error + switch p.kind { + case RPC: + err = p.doUpdate(seed.APIs.RPC) + case Rest: + err = p.doUpdate(seed.APIs.Rest) + } + if err != nil { + slog.Error("could not update seed", "err", err) + } +} + +func (p *Proxy) doUpdate(providers []seed.Provider) error { p.mu.Lock() defer p.mu.Unlock() // add new servers - for _, rpc := range rpcs { - idx := slices.IndexFunc(p.servers, func(srv *Server) bool { return srv.name == rpc.Provider }) + for _, provider := range providers { + idx := slices.IndexFunc(p.servers, func(srv *Server) bool { return srv.name == provider.Provider }) if idx == -1 { srv, err := newServer( - rpc.Provider, - rpc.Address, + provider.Provider, + provider.Address, p.cfg, ) if err != nil { @@ -112,8 +146,8 @@ func (p *Proxy) update(rpcs []seed.RPC) error { // remove deleted servers p.servers = slices.DeleteFunc(p.servers, func(srv *Server) bool { - for _, rpc := range rpcs { - if rpc.Provider == srv.name { + for _, provider := range providers { + if provider.Provider == srv.name { return false } } @@ -129,33 +163,15 @@ func (p *Proxy) update(rpcs []seed.RPC) error { func (p *Proxy) Start(ctx context.Context) { p.init.Do(func() { go func() { - t := time.NewTicker(p.cfg.SeedRefreshInterval) - defer t.Stop() for { select { - case <-t.C: - p.fetchAndUpdate() + case seed := <-p.ch: + p.update(seed) case <-ctx.Done(): p.shuttingDown.Store(true) return } } }() - p.fetchAndUpdate() }) } - -func (p *Proxy) fetchAndUpdate() { - result, err := seed.Fetch(p.cfg.SeedURL) - if err != nil { - slog.Error("could not get initial seed list", "err", err) - return - } - if result.ChainID != p.cfg.ChainID { - slog.Error("chain ID is different than expected", "got", result.ChainID, "expected", p.cfg.ChainID) - return - } - if err := p.update(result.Apis.RPC); err != nil { - slog.Error("could not update servers", "err", err) - } -} diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 7ad2d45..b9e0067 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -17,53 +16,33 @@ import ( ) func TestProxy(t *testing.T) { - const chainID = "unittest" + for name, kind := range map[string]ProxyKind{ + "rpc": RPC, + "rest": Rest, + } { + t.Run(name, func(t *testing.T) { + testProxy(t, kind) + }) + } +} + +func testProxy(tb testing.TB, kind ProxyKind) { srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "srv1 replied") })) - t.Cleanup(srv1.Close) + tb.Cleanup(srv1.Close) srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(time.Millisecond * 500) _, _ = io.WriteString(w, "srv2 replied") })) - t.Cleanup(srv2.Close) + tb.Cleanup(srv2.Close) srv3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) })) - t.Cleanup(srv2.Close) - - seed := seed.Seed{ - ChainID: chainID, - Apis: seed.Apis{ - RPC: []seed.RPC{ - { - Address: srv1.URL, - Provider: "srv1", - }, - { - Address: srv2.URL, - Provider: "srv2", - }, - { - Address: srv3.URL, - Provider: "srv3", - }, - }, - }, - } - - t.Logf("%+v", seed) - - seedSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - bts, _ := json.Marshal(seed) - _, _ = w.Write(bts) - })) - t.Cleanup(seedSrv.Close) + tb.Cleanup(srv2.Close) - proxy := New(config.Config{ - SeedURL: seedSrv.URL, - SeedRefreshInterval: 500 * time.Millisecond, - ChainID: chainID, + ch := make(chan seed.Seed, 1) + proxy := New(kind, ch, config.Config{ HealthyThreshold: 10 * time.Millisecond, ProxyRequestTimeout: time.Second, UnhealthyServerRecoverChancePct: 1, @@ -72,19 +51,43 @@ func TestProxy(t *testing.T) { }) ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) + tb.Cleanup(cancel) proxy.Start(ctx) - require.Len(t, proxy.servers, 3) + serverList := []seed.Provider{ + { + Address: srv1.URL, + Provider: "srv1", + }, + { + Address: srv2.URL, + Provider: "srv2", + }, + { + Address: srv3.URL, + Provider: "srv3", + }, + } + + ch <- seed.Seed{ + APIs: seed.Apis{ + Rest: serverList, + RPC: serverList, + }, + } + + require.Eventually(tb, func() bool { return proxy.initialized.Load() }, time.Second, time.Millisecond) + + require.Len(tb, proxy.servers, 3) proxySrv := httptest.NewServer(proxy) - t.Cleanup(proxySrv.Close) + tb.Cleanup(proxySrv.Close) var wg errgroup.Group wg.SetLimit(20) for i := 0; i < 100; i++ { wg.Go(func() error { - t.Log("go") + tb.Log("go") req, err := http.NewRequest(http.MethodGet, proxySrv.URL, nil) if err != nil { return err @@ -102,13 +105,13 @@ func TestProxy(t *testing.T) { return nil }) } - require.NoError(t, wg.Wait()) + require.NoError(tb, wg.Wait()) // stop the proxy cancel() stats := proxy.Stats() - require.Len(t, stats, 3) + require.Len(tb, stats, 3) var srv1Stats ServerStat var srv2Stats ServerStat @@ -124,13 +127,13 @@ func TestProxy(t *testing.T) { srv3Stats = st } } - require.Zero(t, srv1Stats.ErrorRate) - require.Zero(t, srv2Stats.ErrorRate) - require.Equal(t, float64(100), srv3Stats.ErrorRate) - require.Greater(t, srv1Stats.Requests, srv2Stats.Requests) - require.Greater(t, srv2Stats.Avg, srv1Stats.Avg) - require.False(t, srv1Stats.Degraded) - require.True(t, srv2Stats.Degraded) - require.True(t, srv1Stats.Initialized) - require.True(t, srv2Stats.Initialized) + require.Zero(tb, srv1Stats.ErrorRate) + require.Zero(tb, srv2Stats.ErrorRate) + require.Equal(tb, float64(100), srv3Stats.ErrorRate) + require.Greater(tb, srv1Stats.Requests, srv2Stats.Requests) + require.Greater(tb, srv2Stats.Avg, srv1Stats.Avg) + require.False(tb, srv1Stats.Degraded) + require.True(tb, srv2Stats.Degraded) + require.True(tb, srv1Stats.Initialized) + require.True(tb, srv2Stats.Initialized) } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 2c2f21a..afebf6e 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -7,7 +7,6 @@ import ( "log/slog" "net/http" "net/url" - "strings" "sync/atomic" "time" @@ -52,11 +51,12 @@ func (s *Server) ErrorRate() float64 { } func (s *Server) Healthy() bool { - return s.pings.Last() < s.cfg.HealthyThreshold + return s.pings.Last() < s.cfg.HealthyThreshold && + s.ErrorRate() < s.cfg.HealthyErrorRateThreshold } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var status int + var status int = -1 start := time.Now() defer func() { d := time.Since(start) @@ -64,14 +64,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { slog.Info("request done", "name", s.name, "avg", avg, "last", d, "status", status) }() + path := r.URL.Path proxiedURL := r.URL + proxiedURL.Path = s.url.Path + path proxiedURL.Host = s.url.Host proxiedURL.Scheme = s.url.Scheme - if !strings.HasSuffix(s.url.Path, "/rpc") { - proxiedURL.Path = strings.TrimSuffix(proxiedURL.Path, "/rpc") - } - slog.Info("proxying request", "name", s.name, "url", proxiedURL) rr := &http.Request{ @@ -87,7 +85,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer cancel() resp, err := http.DefaultClient.Do(rr.WithContext(ctx)) - status = resp.StatusCode + if resp != nil { + status = resp.StatusCode + } if err == nil { defer resp.Body.Close() for k, v := range resp.Header { @@ -102,10 +102,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } s.requestCount.Add(1) - if resp.StatusCode >= 200 && resp.StatusCode <= 300 { - s.successes.Append(resp.StatusCode, s.cfg.HealthyErrorRateBucketTimeout) + if status == 0 || (status >= 200 && status <= 300) { + s.successes.Append(status, s.cfg.HealthyErrorRateBucketTimeout) } else { - s.failures.Append(resp.StatusCode, s.cfg.HealthyErrorRateBucketTimeout) + s.failures.Append(status, s.cfg.HealthyErrorRateBucketTimeout) } if !s.Healthy() && ctx.Err() == nil && err == nil { diff --git a/internal/seed/seed.go b/internal/seed/seed.go index 0a7e56e..93ce094 100644 --- a/internal/seed/seed.go +++ b/internal/seed/seed.go @@ -10,19 +10,20 @@ import ( type Seed struct { Status string `json:"status"` ChainID string `json:"chain_id"` - Apis Apis `json:"apis"` + APIs Apis `json:"apis"` } -type RPC struct { +type Provider struct { Address string `json:"address"` Provider string `json:"provider"` } type Apis struct { - RPC []RPC `json:"rpc"` + RPC []Provider `json:"rpc"` + Rest []Provider `json:"rest"` } -func Fetch(url string) (Seed, error) { +func fetch(url string) (Seed, error) { var seed Seed resp, err := http.Get(url) if err != nil { diff --git a/internal/seed/updater.go b/internal/seed/updater.go new file mode 100644 index 0000000..49b59ab --- /dev/null +++ b/internal/seed/updater.go @@ -0,0 +1,56 @@ +package seed + +import ( + "context" + "log/slog" + "sync" + "time" + + "github.com/akash-network/rpc-proxy/internal/config" +) + +type Updater struct { + cfg config.Config + listeners []chan<- Seed + init sync.Once +} + +func New(cfg config.Config, listeners ...chan<- Seed) *Updater { + return &Updater{ + cfg: cfg, + listeners: listeners, + } +} + +func (u *Updater) Start(ctx context.Context) { + u.init.Do(func() { + go func() { + t := time.NewTicker(u.cfg.SeedRefreshInterval) + defer t.Stop() + for { + select { + case <-t.C: + u.fetchAndUpdate() + case <-ctx.Done(): + return + } + } + }() + u.fetchAndUpdate() + }) +} + +func (u *Updater) fetchAndUpdate() { + result, err := fetch(u.cfg.SeedURL) + if err != nil { + slog.Error("could not get initial seed list", "err", err) + return + } + if result.ChainID != u.cfg.ChainID { + slog.Error("chain ID is different than expected", "got", result.ChainID, "expected", u.cfg.ChainID) + return + } + for _, ch := range u.listeners { + ch <- result + } +} diff --git a/internal/seed/updater_test.go b/internal/seed/updater_test.go new file mode 100644 index 0000000..21feca9 --- /dev/null +++ b/internal/seed/updater_test.go @@ -0,0 +1,78 @@ +package seed + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/akash-network/rpc-proxy/internal/config" + "github.com/stretchr/testify/require" +) + +func TestUpdater(t *testing.T) { + chainID := "test" + seed := Seed{ + ChainID: chainID, + APIs: Apis{ + RPC: []Provider{ + { + Address: "http://rpc.local", + Provider: "rpc-provider", + }, + }, + Rest: []Provider{ + { + Address: "http://rest.local", + Provider: "rest-provider", + }, + }, + }, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bts, _ := json.Marshal(seed) + _, _ = w.Write(bts) + })) + t.Cleanup(srv.Close) + + rpc := make(chan Seed, 1) + rest := make(chan Seed, 1) + + up := New(config.Config{ + SeedRefreshInterval: time.Millisecond, + SeedURL: srv.URL, + ChainID: chainID, + }, rpc, rest) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + up.Start(ctx) + + go func() { + time.Sleep(time.Millisecond * 500) + cancel() + }() + + var rpcUpdates, restUpdates atomic.Uint32 + +outer: + for { + select { + case got := <-rpc: + rpcUpdates.Add(1) + require.Equal(t, seed, got) + case got := <-rest: + restUpdates.Add(1) + require.Equal(t, seed, got) + case <-ctx.Done(): + break outer + } + } + + require.NotZero(t, rpcUpdates.Load()) + require.NotZero(t, restUpdates.Load()) +} diff --git a/main.go b/main.go index feca581..e68ccba 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "github.com/akash-network/rpc-proxy/internal/config" "github.com/akash-network/rpc-proxy/internal/proxy" + "github.com/akash-network/rpc-proxy/internal/seed" "golang.org/x/crypto/acme/autocert" ) @@ -34,28 +35,39 @@ func main() { am.HostPolicy = autocert.HostWhitelist(hosts...) } - proxyHandler := proxy.New(cfg) + rpcListener := make(chan seed.Seed, 1) + restListener := make(chan seed.Seed, 1) - proxyCtx, proxyCtxCancel := context.WithCancel(context.Background()) + updater := seed.New(cfg, rpcListener, restListener) + rpcProxyHandler := proxy.New(proxy.RPC, rpcListener, cfg) + restProxyHandler := proxy.New(proxy.Rest, restListener, cfg) + + ctx, proxyCtxCancel := context.WithCancel(context.Background()) defer proxyCtxCancel() - proxyHandler.Start(proxyCtx) + updater.Start(ctx) + rpcProxyHandler.Start(ctx) + restProxyHandler.Start(ctx) indexTpl := template.Must(template.New("stats").Parse(string(index))) m := http.NewServeMux() m.Handle("/health/ready", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !proxyHandler.Ready() { + if !rpcProxyHandler.Ready() || !restProxyHandler.Ready() { w.WriteHeader(http.StatusServiceUnavailable) } })) m.Handle("/health/live", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !proxyHandler.Live() { + if !rpcProxyHandler.Live() || !restProxyHandler.Live() { w.WriteHeader(http.StatusServiceUnavailable) } })) - m.Handle("/rpc", proxyHandler) + m.Handle("/rpc", rpcProxyHandler) + m.Handle("/rest", restProxyHandler) m.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := indexTpl.Execute(w, proxyHandler.Stats()); err != nil { + if err := indexTpl.Execute(w, map[string][]proxy.ServerStat{ + "RPC": rpcProxyHandler.Stats(), + "Rest": restProxyHandler.Stats(), + }); err != nil { slog.Error("could render stats", "err", err) } })) @@ -97,9 +109,9 @@ func main() { proxyCtxCancel() - proxyCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - if err := srv.Shutdown(proxyCtx); err != nil { + if err := srv.Shutdown(ctx); err != nil { slog.Error("could not close server", "err", err) os.Exit(1) }