diff --git a/api/apihttp/apihttp.go b/api/apihttp/apihttp.go index ed84ad759..e42a6b9b5 100644 --- a/api/apihttp/apihttp.go +++ b/api/apihttp/apihttp.go @@ -18,7 +18,6 @@ package apihttp import ( - "bytes" "encoding/json" "net/http" "time" @@ -37,33 +36,26 @@ type HealthCheckResponse struct { // HealthCheckHandler checks the system status and returns it accordinly. // The http call it answer is: -// GET /health-check +// HEAD / // // The following statuses are expected: // -// If everything is alright, the HTTP status is 200 and the body contains: -// {"version": "0", "status":"ok"} +// If everything is alright, the HTTP response will have a 204 status code +// and no body. func HealthCheckHandler(w http.ResponseWriter, r *http.Request) { metrics.QedAPIHealthcheckRequestsTotal.Inc() - result := HealthCheckResponse{ - Version: 0, - Status: "ok", + // Make sure we can only be called with an HTTP POST request. + if r.Method != "HEAD" { + w.Header().Set("Allow", "HEAD") + w.WriteHeader(http.StatusMethodNotAllowed) + return } - resultJson, _ := json.Marshal(result) - // A very simple health check. - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNoContent) - // In the future we could report back on the status of our DB, or our cache - // (e.g. Redis) by performing a simple PING, and include them in the response. 
- out := new(bytes.Buffer) - _ = json.Compact(out, resultJson) - - _, _ = w.Write(out.Bytes()) } // Add posts an event into the system: @@ -304,7 +296,7 @@ func AuthHandlerMiddleware(handler http.HandlerFunc) http.HandlerFunc { func NewApiHttp(balloon raftwal.RaftBalloonApi) *http.ServeMux { api := http.NewServeMux() - api.HandleFunc("/health-check", AuthHandlerMiddleware(HealthCheckHandler)) + api.HandleFunc("/healthcheck", AuthHandlerMiddleware(HealthCheckHandler)) api.HandleFunc("/events", AuthHandlerMiddleware(Add(balloon))) api.HandleFunc("/proofs/membership", AuthHandlerMiddleware(Membership(balloon))) api.HandleFunc("/proofs/digest-membership", AuthHandlerMiddleware(DigestMembership(balloon))) @@ -343,9 +335,12 @@ func LogHandler(handle http.Handler) http.HandlerFunc { latency := time.Now().Sub(start) log.Debugf("Request: lat %d %+v", latency, request) - if writer.status >= 400 { + if writer.status >= 400 && writer.status < 500 { log.Infof("Bad Request: %d %+v", latency, request) } + if writer.status >= 500 { + log.Infof("Server error: %d %+v", latency, request) + } } } diff --git a/api/apihttp/apihttp_test.go b/api/apihttp/apihttp_test.go index d0cbb5a95..e0734efaf 100644 --- a/api/apihttp/apihttp_test.go +++ b/api/apihttp/apihttp_test.go @@ -97,7 +97,7 @@ func (b fakeRaftBalloon) Info() map[string]interface{} { func TestHealthCheckHandler(t *testing.T) { // Create a request to pass to our handler. We don't have any query parameters for now, so we'll // pass 'nil' as the third parameter. - req, err := http.NewRequest("GET", "/health-check", nil) + req, err := http.NewRequest("HEAD", "/healthcheck", nil) if err != nil { t.Fatal(err) } @@ -111,16 +111,15 @@ func TestHealthCheckHandler(t *testing.T) { handler.ServeHTTP(rr, req) // Check the status code is what we expect. 
- if status := rr.Code; status != http.StatusOK { + if status := rr.Code; status != http.StatusNoContent { t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) + status, http.StatusNoContent) } // Check the response body is what we expect. - expected := `{"version":0,"status":"ok"}` - if rr.Body.String() != expected { + if rr.Body.String() != "" { t.Errorf("handler returned unexpected body: got %v want %v", - rr.Body.String(), expected) + rr.Body.String(), "") } } @@ -296,7 +295,7 @@ func TestIncremental(t *testing.T) { func TestAuthHandlerMiddleware(t *testing.T) { - req, err := http.NewRequest("GET", "/health-check", nil) + req, err := http.NewRequest("HEAD", "/healthcheck", nil) if err != nil { t.Fatal(err) } @@ -313,9 +312,9 @@ func TestAuthHandlerMiddleware(t *testing.T) { handler.ServeHTTP(rr, req) // Check the status code is what we expect. - if status := rr.Code; status != http.StatusOK { + if status := rr.Code; status != http.StatusNoContent { t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) + status, http.StatusNoContent) } } diff --git a/balloon/balloon.go b/balloon/balloon.go index 45856ad93..f50c429ee 100644 --- a/balloon/balloon.go +++ b/balloon/balloon.go @@ -284,7 +284,7 @@ func (b Balloon) QueryDigestMembership(keyDigest hashing.Digest, version uint64) historyProof, historyErr = b.historyTree.ProveMembership(proof.ActualVersion, version) }() } else { - return nil, fmt.Errorf("query version %d is not on history tree which version is %d", version, proof.ActualVersion) + return nil, fmt.Errorf("query version %d is greater than the actual version which is %d", version, proof.ActualVersion) } } diff --git a/balloon/balloon_test.go b/balloon/balloon_test.go index f109db9ae..03e338504 100644 --- a/balloon/balloon_test.go +++ b/balloon/balloon_test.go @@ -114,7 +114,7 @@ func TestQueryConsistencyProof(t *testing.T) { balloon, err := NewBalloon(store, hashing.NewFakeXorHasher) 
require.NoError(t, err) - for j := 0; j <= int(c.addtions); j++ { + for j := 0; j <= int(c.additions); j++ { _, mutations, err := balloon.Add(util.Uint64AsBytes(uint64(j))) require.NoErrorf(t, err, "Error adding event %d", j) store.Mutate(mutations) diff --git a/client/backoff.go b/client/backoff.go new file mode 100644 index 000000000..a62cd3e12 --- /dev/null +++ b/client/backoff.go @@ -0,0 +1,120 @@ +/* + Copyright 2018 Banco Bilbao Vizcaya Argentaria, S.A. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package client + +import ( + "math" + "math/rand" + "sync" + "time" +) + +// BackoffF specifies the signature of a function that returns the +// time to wait before the next call to a resource. To stop retrying +// return false in the 2nd return value. +type BackoffF func(attempt int) (time.Duration, bool) + +// Backoff allows callers to implement their own Backoff strategy. +type Backoff interface { + // Next implements a BackoffF. + Next(attempt int) (time.Duration, bool) +} + +// StopBackoff is a fixed backoff policy that always returns false for +// Next(), meaning that the operation should never be retried. +type StopBackoff struct{} + +// NewStopBackoff returns a new StopBackoff. +func NewStopBackoff() *StopBackoff { + return &StopBackoff{} +} + +// Next implements BackoffF for StopBackoff. +func (b StopBackoff) Next(attempt int) (time.Duration, bool) { + return 0, false +} + +// ConstantBackoff is a backoff policy that always returns the same delay. 
+type ConstantBackoff struct { + interval time.Duration +} + +// NewConstantBackoff returns a new ConstantBackoff. +func NewConstantBackoff(interval time.Duration) *ConstantBackoff { + return &ConstantBackoff{interval: interval} +} + +// Next implements BackoffF for ConstantBackoff. +func (b *ConstantBackoff) Next(attempt int) (time.Duration, bool) { + return b.interval, true +} + +// SimpleBackoff takes a list of fixed values for backoff intervals. +// Each call to Next returns the next value from that fixed list. +// After each value is returned, subsequent calls to Next will only return +// the last element. +type SimpleBackoff struct { + sync.Mutex + ticks []int +} + +// NewSimpleBackoff creates a SimpleBackoff algorithm with the specified +// list of fixed intervals in milliseconds. +func NewSimpleBackoff(ticks ...int) *SimpleBackoff { + return &SimpleBackoff{ticks: ticks} +} + +// Next implements BackoffF for SimpleBackoff. +func (b *SimpleBackoff) Next(attempt int) (time.Duration, bool) { + b.Lock() + defer b.Unlock() + if attempt >= len(b.ticks) { + return 0, false + } + ms := b.ticks[attempt] + return time.Duration(ms) * time.Millisecond, true +} + +// ExponentialBackoff implements the simple exponential backoff described by +// Douglas Thain at http://dthain.blogspot.de/2009/02/exponential-backoff-in-distributed.html. +type ExponentialBackoff struct { + t float64 // initial timeout (in msec) + f float64 // exponential factor (e.g. 2) + m float64 // maximum timeout (in msec) +} + +// NewExponentialBackoff returns a ExponentialBackoff backoff policy. +// Use initialTimeout to set the first/minimal interval +// and maxTimeout to set the maximum wait interval. +func NewExponentialBackoff(initialTimeout, maxTimeout time.Duration) *ExponentialBackoff { + return &ExponentialBackoff{ + t: float64(int64(initialTimeout / time.Millisecond)), + f: 2.0, + m: float64(int64(maxTimeout / time.Millisecond)), + } +} + +// Next implements BackoffF for ExponentialBackoff. 
+func (b *ExponentialBackoff) Next(attempt int) (time.Duration, bool) { + r := 1.0 + rand.Float64() // random number in [1..2] + m := math.Min(r*b.t*math.Pow(b.f, float64(attempt)), b.m) + if m >= b.m { + return 0, false + } + d := time.Duration(int64(m)) * time.Millisecond + return d, true +} diff --git a/client/backoff_test.go b/client/backoff_test.go new file mode 100644 index 000000000..156475cd5 --- /dev/null +++ b/client/backoff_test.go @@ -0,0 +1,112 @@ +/* + Copyright 2018 Banco Bilbao Vizcaya Argentaria, S.A. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package client + +import ( + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestStopBackoff(t *testing.T) { + b := NewStopBackoff() + _, ok := b.Next(0) + require.False(t, ok) +} + +func TestConstantBackoff(t *testing.T) { + b := NewConstantBackoff(time.Second) + d, ok := b.Next(0) + require.True(t, ok) + require.Equal(t, time.Second, d) +} + +func TestSimpleBackoff(t *testing.T) { + + testCases := []struct { + Duration time.Duration + Continue bool + }{ + { + Duration: 1 * time.Millisecond, + Continue: true, + }, + { + Duration: 2 * time.Millisecond, + Continue: true, + }, + { + Duration: 7 * time.Millisecond, + Continue: true, + }, + { + Duration: 0, + Continue: false, + }, + { + Duration: 0, + Continue: false, + }, + } + + b := NewSimpleBackoff(1, 2, 7) + + for i, c := range testCases { + d, ok := b.Next(i) + require.Equalf(t, c.Continue, ok, "The continue value should match for test case %d", i) + require.Equalf(t, c.Duration, d, "The duration value should match for test case %d", i) + } +} + +func TestExponentialBackoff(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + min := time.Duration(8) * time.Millisecond + max := time.Duration(256) * time.Millisecond + b := NewExponentialBackoff(min, max) + + between := func(value time.Duration, a, b int) bool { + x := int(value / time.Millisecond) + return a <= x && x <= b + } + + d, ok := b.Next(0) + require.True(t, ok) + require.True(t, between(d, 8, 256)) + + d, ok = b.Next(1) + require.True(t, ok) + require.True(t, between(d, 8, 256)) + + d, ok = b.Next(3) + require.True(t, ok) + require.True(t, between(d, 8, 256)) + + d, ok = b.Next(4) + require.True(t, ok) + require.True(t, between(d, 8, 256)) + + _, ok = b.Next(5) + require.False(t, ok) + + _, ok = b.Next(6) + require.False(t, ok) + +} diff --git a/client/client.go b/client/client.go index cfbdbe105..e5a073e2a 100644 --- a/client/client.go +++ b/client/client.go @@ -14,20 +14,18 @@ limitations under the 
License. */ -// Package client implements the command line interface to interact with -// the REST API +// Package client implements the client to interact with QED servers. package client import ( - "bytes" - "crypto/tls" + "context" "encoding/json" + "errors" "fmt" "io/ioutil" - "math/rand" - "net" "net/http" "net/url" + "sync" "time" "github.com/bbva/qed/balloon" @@ -36,175 +34,397 @@ import ( "github.com/bbva/qed/protocol" ) -// HTTPClient ist the stuct that has the required information for the cli. +// HTTPClient is an HTTP QED client. type HTTPClient struct { - conf *Config - *http.Client - topology Topology + httpClient *http.Client + retrier RequestRetrier + topology *topology + apiKey string + readPreference ReadPref + maxRetries int + healthcheckEnabled bool + healthcheckTimeout time.Duration + discoveryEnabled bool + discoveryTimeout time.Duration + + mu sync.RWMutex // guards the next block + running bool + healthcheckStopCh chan bool // notify healthchecker to stop, and notify back + discoveryStopCh chan bool // notify sniffer to stop, and notify back } -// NewHTTPClient will return a new instance of HTTPClient. -func NewHTTPClient(conf Config) *HTTPClient { - var tlsConf *tls.Config +// NewSimpleHTTPClient creates a new short-lived client thath can be +// used in use cases where you need one client per request. +// +// All checks are disabled, including timeouts and periodic checks. +// The number of retries is set to 0. 
+func NewSimpleHTTPClient(httpClient *http.Client, urls []string) (*HTTPClient, error) { + + // defaultTransport := http.DefaultTransport.(*http.Transport) + // // Create new Transport that ignores self-signed SSL + // transportWithSelfSignedTLS := &http.Transport{ + // Proxy: defaultTransport.Proxy, + // DialContext: defaultTransport.DialContext, + // MaxIdleConns: defaultTransport.MaxIdleConns, + // IdleConnTimeout: defaultTransport.IdleConnTimeout, + // ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, + // TLSHandshakeTimeout: defaultTransport.TLSHandshakeTimeout, + // TLSClientConfig: &tls.Config{InsecureSkipVerify: insecure}, + // } + + // httpClient := &http.Client{ + // Timeout: DefaultTimeout, + // Transport: transportWithSelfSignedTLS, + // } + + if len(urls) == 0 { + return nil, errors.New("Invalid urls") + } - if conf.Insecure { - tlsConf = &tls.Config{InsecureSkipVerify: true} - } else { - tlsConf = &tls.Config{} + if httpClient == nil { + httpClient = http.DefaultClient } client := &HTTPClient{ - &conf, - &http.Client{ - Timeout: time.Second * 10, - Transport: &http.Transport{ - Dial: (&net.Dialer{ - Timeout: 5 * time.Second, - }).Dial, - TLSClientConfig: tlsConf, - TLSHandshakeTimeout: 5 * time.Second, - }, - }, - Topology{}, + httpClient: httpClient, + topology: newTopology(false), + healthcheckEnabled: false, + healthcheckTimeout: off, + discoveryEnabled: false, + discoveryTimeout: off, + readPreference: Primary, + maxRetries: 0, + retrier: NewNoRequestRetrier(httpClient), } - // Initial topology assignment - client.topology.Leader = conf.Endpoints[0] - client.topology.Endpoints = conf.Endpoints + client.topology.Update(urls[0], urls[1:]...) - var info map[string]interface{} - var err error + return client, nil +} - info, err = client.getClusterInfo() +// NewHTTPClientFromConfig initializes a client from a configuration. 
+func NewHTTPClientFromConfig(conf *Config) (*HTTPClient, error) { + options, err := configToOptions(conf) if err != nil { - log.Errorf("Failed to get raft cluster info. Error: %v", err) - return nil + return nil, err } + return NewHTTPClient(options...) +} - client.updateTopology(info) +// NewHTTPClient creates a new HTTP client to work with QED. +// +// The client, by default, is meant to be long-lived and shared across +// your application. If you need a short-lived client, e.g. for request-scope, +// consider using NewSimpleHttpClient instead. +// +func NewHTTPClient(options ...HTTPClientOptionF) (*HTTPClient, error) { - return client -} + client := &HTTPClient{ + httpClient: http.DefaultClient, + topology: newTopology(false), + healthcheckEnabled: DefaultHealthCheckEnabled, + healthcheckTimeout: DefaultHealthCheckTimeout, + discoveryEnabled: DefaultTopologyDiscoveryEnabled, + discoveryTimeout: DefaultTopologyDiscoveryTimeout, + readPreference: Primary, + maxRetries: DefaultMaxRetries, + } -func (c *HTTPClient) exponentialBackoff(req *http.Request) (*http.Response, error) { + // Run the options on the client + for _, option := range options { + if err := option(client); err != nil { + return nil, err + } + } - var retries uint + // configure retrier + client.setRetrier(client.maxRetries) - for { - resp, err := c.Do(req) - if err != nil { - if retries == 5 { - return nil, err - } - retries = retries + 1 - delay := time.Duration(10 << retries * time.Millisecond) - time.Sleep(delay) - continue + // Initial topology assignment + if client.discoveryEnabled { + // try to discover the cluster topology initially + if err := client.discover(); err != nil { + return nil, err } - return resp, err } + + if client.healthcheckEnabled { + // perform an initial healthcheck + client.healthCheck(client.healthcheckTimeout) + } + + // Ensure thath we have at least one endpoint, the primary, available + if !client.topology.HasActivePrimary() { + return nil, ErrNoPrimary + } + + // 
if t.discoveryEnabled { + // go t.startDiscoverer() // periodically update cluster information + // } + // if t.healthcheckEnabled { + // go c.startHealthChecker() // periodically ping all nodes of the cluster + // } + + client.running = true + return client, nil } -func (c HTTPClient) getClusterInfo() (map[string]interface{}, error) { - var retries uint - info := make(map[string]interface{}) +// Close stops the background processes that the client is running, +// i.e. sniffing the cluster periodically and running health checks +// on the nodes. +// +// If the background processes are not running, this is a no-op. +func (c *HTTPClient) Close() { + c.mu.RLock() + if !c.running { + c.mu.RUnlock() + return + } + c.mu.RUnlock() - for { - body, err := c.doReq("GET", "/info/shards", []byte{}) + log.Info("Closing QED client...") - if err != nil { - log.Debugf("Failed to get raft cluster info through server %s. Error: %v", - c.topology.Leader, err) - if retries == 5 { - return nil, err - } - c.topology.Leader = c.topology.Endpoints[rand.Intn(len(c.topology.Endpoints))] - retries = retries + 1 - delay := time.Duration(10 << retries * time.Millisecond) - time.Sleep(delay) - continue - } + if c.healthcheckEnabled { + c.healthcheckStopCh <- true + <-c.healthcheckStopCh + } - err = json.Unmarshal(body, &info) - if err != nil { - return nil, err - } + if c.discoveryEnabled { + c.discoveryStopCh <- true + <-c.discoveryStopCh + } + + c.mu.Lock() + if c.topology != nil { + c.topology = nil + } + c.running = false + c.mu.Unlock() + + log.Info("QED client closed") - return info, err +} + +func (c *HTTPClient) setRetrier(maxRetries int) error { + if maxRetries < 0 { + return errors.New("MaxRetries must be greater than or equal to 0") } + if maxRetries == 0 { + c.retrier = NewNoRequestRetrier(c.httpClient) + } else { + // Create a Retrier that will wait for 100ms between requests. 
+ ticks := make([]int, maxRetries) + for i := 0; i < len(ticks); i++ { + ticks[i] = 100 + } + backoff := NewSimpleBackoff(ticks...) + c.retrier = NewBackoffRequestRetrier(c.httpClient, c.maxRetries, backoff) + } + return nil } -func (c *HTTPClient) updateTopology(info map[string]interface{}) { +// startDiscoverer periodically runs discover. +func (c *HTTPClient) startDiscoverer() { + c.mu.RLock() - clusterMeta := info["meta"].(map[string]interface{}) - leaderID := info["leaderID"].(string) - scheme := info["URIScheme"].(string) +} - var leaderAddr string - var endpoints []string +func (c *HTTPClient) callPrimary(method, path string, data []byte) ([]byte, error) { - leaderMeta := clusterMeta[leaderID].(map[string]interface{}) - for k, addr := range leaderMeta { - if k == "HTTPAddr" { - leaderAddr = scheme + addr.(string) + var endpoint *endpoint + var err error + var retried bool + for { + // we always send POST requests to the primary endpoint + endpoint, err = c.topology.Primary() + if err != nil { + if !retried && c.discoveryEnabled { + c.discover() + retried = true + continue + } + return nil, err } + + if !retried && endpoint.IsDead() { + if c.healthcheckEnabled { + c.healthCheck(c.healthcheckTimeout) + } + retried = true + continue + } + break } - c.topology.Leader = leaderAddr + return c.doReq(method, endpoint, path, data) +} - for _, nodeMeta := range clusterMeta { - for k, address := range nodeMeta.(map[string]interface{}) { - if k == "HTTPAddr" { - url := scheme + address.(string) - endpoints = append(endpoints, url) +func (c *HTTPClient) callAny(method, path string, data []byte) ([]byte, error) { + + var endpoint *endpoint + var retried bool + var err error + var result []byte + for { + // check every endpoint available in a round-robin manner + endpoint, err = c.topology.NextReadEndpoint(c.readPreference) + if err != nil { + if !retried && c.discoveryEnabled { + c.discover() + retried = true + continue } + return nil, err + } + result, err = 
c.doReq(method, endpoint, path, data) + if err == nil { + break } + endpoint.MarkAsDead() } - c.topology.Endpoints = endpoints + return result, err } -func (c *HTTPClient) doReq(method, path string, data []byte) ([]byte, error) { +func (c *HTTPClient) doReq(method string, endpoint *endpoint, path string, data []byte) ([]byte, error) { - url, err := url.Parse(c.topology.Leader + path) + url, err := url.Parse(endpoint.URL() + path) if err != nil { - return nil, err //panic(err) + return nil, err } - req, err := http.NewRequest(method, fmt.Sprintf("%s", url), bytes.NewBuffer(data)) + // Build request + req, err := NewRetriableRequest(method, url.String(), data) if err != nil { - return nil, err //panic(err) + return nil, err } + // Set headers req.Header.Set("Content-Type", "application/json") - req.Header.Set("Api-Key", c.conf.APIKey) + req.Header.Set("Api-Key", c.apiKey) - resp, err := c.exponentialBackoff(req) + // Get response + resp, err := c.retrier.DoReq(req) if err != nil { + log.Infof("Request error: %v\n", err) + log.Infof("%s is dead\n", endpoint) + endpoint.MarkAsDead() return nil, err - // NetworkTransport error. Check topology info } - defer resp.Body.Close() - - bodyBytes, _ := ioutil.ReadAll(resp.Body) - if resp.StatusCode >= 500 { - return nil, fmt.Errorf("Server error: %v", string(bodyBytes)) - // Non Leader error. Check topology info. + if resp.Body != nil { + defer resp.Body.Close() } + bodyBytes, _ := ioutil.ReadAll(resp.Body) if resp.StatusCode >= 400 && resp.StatusCode < 500 { return nil, fmt.Errorf("Invalid request %v", string(bodyBytes)) } + // we successfully made a request to this endpoint + endpoint.MarkAsHealthy() + return bodyBytes, nil } -// Ping will do a request to the server -func (c HTTPClient) Ping() error { - _, err := c.doReq("GET", "/health-check", nil) +// healthCheck does a health check on all nodes in the cluster. +// Depending on the node state, it marks connections as dead, alive etc. 
+// The timeout specifies how long to wait for a response from QED. +func (c *HTTPClient) healthCheck(timeout time.Duration) { + + for _, e := range c.topology.Endpoints() { + + endpoint := e + // the goroutines execute the health-check HTTP request and sets status + go func(endpointURL string) { + + // Run a GET request against QED with a timeout + // TODO it should be a HEAD instead of a GET + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, err := http.NewRequest("HEAD", endpointURL+"/healthcheck", nil) + if err != nil { + return + } + req.Header.Set("Api-Key", c.apiKey) + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + log.Infof("%s is dead", endpoint.URL()) + endpoint.MarkAsDead() + } + if resp != nil { + status := resp.StatusCode + if resp.Body != nil { + resp.Body.Close() + } + if status >= 200 && status < 300 { + endpoint.MarkAsAlive() + } else { + log.Infof("%s is dead [status=%d]", endpoint.URL(), status) + endpoint.MarkAsDead() + } + } + }(endpoint.URL()) + + } + +} + +// discover uses the shards info API to return the list of nodes in the cluster. +// It uses the list of URLs passed on startup plus the list of URLs found +// by the preceding discovery process (if discovery is enabled). 
+func (c *HTTPClient) discover() error { + + if !c.discoveryEnabled { + return nil + } + + for { + + e, err := c.topology.NextReadEndpoint(Any) + if err != nil { + return err + } + + body, err := c.doReq("GET", e, "/info/shards", nil) + if err == nil { + info := make(map[string]interface{}) + err = json.Unmarshal(body, &info) + if err != nil { + return err + } + + clusterMeta := info["meta"].(map[string]interface{}) + primaryID := info["leaderID"].(string) + scheme := info["URIScheme"].(string) + + var prim string + secondaries := make([]string, 0) + for id, nodeMeta := range clusterMeta { + for k, address := range nodeMeta.(map[string]interface{}) { + if k == "HTTPAddr" { + if id == primaryID { + prim = scheme + address.(string) + } else { + secondaries = append(secondaries, scheme+address.(string)) + } + } + } + } + c.topology.Update(prim, secondaries...) + break + } + } + + return nil +} + +// Ping will do a healthcheck request to the primary node +func (c *HTTPClient) Ping() error { + _, err := c.callPrimary("HEAD", "/healthcheck", nil) if err != nil { return err } - return nil } @@ -212,14 +432,16 @@ func (c HTTPClient) Ping() error { func (c *HTTPClient) Add(event string) (*protocol.Snapshot, error) { data, _ := json.Marshal(&protocol.Event{Event: []byte(event)}) - - body, err := c.doReq("POST", "/events", data) + body, err := c.callPrimary("POST", "/events", data) if err != nil { return nil, err } var snapshot protocol.Snapshot - _ = json.Unmarshal(body, &snapshot) + err = json.Unmarshal(body, &snapshot) + if err != nil { + return nil, err + } return &snapshot, nil @@ -233,7 +455,7 @@ func (c *HTTPClient) Membership(key []byte, version uint64) (*protocol.Membershi Version: version, }) - body, err := c.doReq("POST", "/proofs/membership", query) + body, err := c.callAny("POST", "/proofs/membership", query) if err != nil { return nil, err } @@ -253,7 +475,7 @@ func (c *HTTPClient) MembershipDigest(keyDigest hashing.Digest, version uint64) Version: version, }) - 
body, err := c.doReq("POST", "/proofs/digest-membership", query) + body, err := c.callAny("POST", "/proofs/digest-membership", query) if err != nil { return nil, err } @@ -273,7 +495,7 @@ func (c *HTTPClient) Incremental(start, end uint64) (*protocol.IncrementalRespon End: end, }) - body, err := c.doReq("POST", "/proofs/incremental", query) + body, err := c.callAny("POST", "/proofs/incremental", query) if err != nil { return nil, err } diff --git a/client/client_test.go b/client/client_test.go index d4f015c94..b9106d731 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -19,22 +19,23 @@ package client import ( "bytes" "encoding/json" + "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" + "time" + + "github.com/stretchr/testify/require" "github.com/bbva/qed/hashing" + "github.com/pkg/errors" "github.com/bbva/qed/log" "github.com/bbva/qed/protocol" "github.com/stretchr/testify/assert" ) -func init() { - log.SetLogger("client-test", "info") -} - func setupServer(input []byte) (string, func()) { mux := http.NewServeMux() server := httptest.NewServer(mux) @@ -44,22 +45,208 @@ func setupServer(input []byte) (string, func()) { mux.HandleFunc("/proofs/membership", defaultHandler(input)) mux.HandleFunc("/proofs/incremental", defaultHandler(input)) mux.HandleFunc("/proofs/digest-membership", defaultHandler(input)) + mux.HandleFunc("/healthcheck", defaultHandler(nil)) return server.URL, func() { server.Close() } } -func setupClient(urls []string) *HTTPClient { - return NewHTTPClient(Config{ - Endpoints: urls, - APIKey: "my-awesome-api-key", - Insecure: false, +func setupClient(t *testing.T, urls []string) *HTTPClient { + httpClient := http.DefaultClient + client, err := NewHTTPClient( + SetHttpClient(httpClient), + SetAPIKey("my-awesome-api-key"), + SetURLs(urls[0], urls[1:]...), + SetRequestRetrier(NewNoRequestRetrier(httpClient)), + SetReadPreference(Primary), + SetMaxRetries(0), + SetTopologyDiscovery(false), + SetHealthchecks(false), + ) + if 
err != nil { + t.Fatal(errors.Wrap(err, "Cannot create http client")) + } + return client +} + +func TestCallPrimaryWorking(t *testing.T) { + + log.SetLogger("TestCallPrimaryWorking", log.SILENT) + + var numRequests int + httpClient := NewTestHttpClient(func(req *http.Request) (*http.Response, error) { + numRequests++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewBufferString("OK")), + Header: make(http.Header), + }, nil + }) + + client, err := NewHTTPClient( + SetHttpClient(httpClient), + SetAPIKey("my-awesome-api-key"), + SetURLs("http://primary.foo"), + SetReadPreference(PrimaryPreferred), + SetMaxRetries(1), + SetTopologyDiscovery(false), + SetHealthchecks(false), + ) + require.NoError(t, err) + + resp, err := client.callPrimary("GET", "/test", nil) + require.NoError(t, err, "The requests should not fail") + require.True(t, len(resp) > 0, "The response should not be empty") + require.Equal(t, 1, numRequests, "The number of requests should match") +} + +func TestCallPrimaryFails(t *testing.T) { + + log.SetLogger("TestCallPrimaryFails", log.SILENT) + + var numRequests int + httpClient := NewTestHttpClient(func(req *http.Request) (*http.Response, error) { + numRequests++ + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(bytes.NewBufferString("Internal server error")), + Header: make(http.Header), + }, nil + }) + + client, err := NewHTTPClient( + SetHttpClient(httpClient), + SetAPIKey("my-awesome-api-key"), + SetURLs("http://primary.foo", "http://secondary1.foo"), + SetReadPreference(PrimaryPreferred), + SetMaxRetries(1), + SetTopologyDiscovery(false), + SetHealthchecks(false), + ) + require.NoError(t, err) + + resp, err := client.callPrimary("GET", "/test", nil) + require.Error(t, err, "The requests should fail") + require.True(t, len(resp) == 0, "The response should be empty") + require.Equal(t, 2, numRequests, "The number of requests should match") +} + +func 
TestCallAnyPrimaryFails(t *testing.T) { + + log.SetLogger("TestCallAnyPrimaryFails", log.SILENT) + + var numRequests int + httpClient := NewTestHttpClient(func(req *http.Request) (*http.Response, error) { + numRequests++ + if strings.HasPrefix(req.URL.Hostname(), "primary") { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(bytes.NewBufferString("Internal server error")), + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewBufferString("OK")), + Header: make(http.Header), + }, nil + }) + + client, err := NewHTTPClient( + SetHttpClient(httpClient), + SetAPIKey("my-awesome-api-key"), + SetURLs("http://primary.foo", "http://secondary1.foo"), + SetReadPreference(PrimaryPreferred), + SetMaxRetries(1), + SetTopologyDiscovery(false), + SetHealthchecks(false), + ) + require.NoError(t, err) + + resp, err := client.callAny("GET", "/test", nil) + require.NoError(t, err, "The requests should not fail") + require.True(t, len(resp) > 0, "The response should not be empty") + require.Equal(t, 3, numRequests, "The number of requests should match") +} + +func TestCallAnyAllFail(t *testing.T) { + + log.SetLogger("TestCallAnyAllFail", log.SILENT) + + var numRequests int + httpClient := NewTestHttpClient(func(req *http.Request) (*http.Response, error) { + numRequests++ + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(bytes.NewBufferString("Internal server error")), + Header: make(http.Header), + }, nil + }) + + client, err := NewHTTPClient( + SetHttpClient(httpClient), + SetAPIKey("my-awesome-api-key"), + SetURLs("http://primary.foo", "http://secondary1.foo", "http://secondary2.foo"), + SetReadPreference(PrimaryPreferred), + SetMaxRetries(1), + SetTopologyDiscovery(false), + SetHealthchecks(false), + ) + require.NoError(t, err) + + resp, err := client.callAny("GET", "/test", nil) + require.Error(t, err, "The request 
should fail") + require.True(t, len(resp) == 0, "The response should be empty") + require.Equal(t, 6, numRequests, "The number of requests should match") +} + +func TestHealthCheck(t *testing.T) { + + log.SetLogger("TestHealthCheck", log.INFO) + + var numRequests int + httpClient := NewTestHttpClient(func(req *http.Request) (*http.Response, error) { + if req.Method == "HEAD" { + numRequests++ + return &http.Response{ + StatusCode: http.StatusNoContent, + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(bytes.NewBufferString("Internal server error")), + Header: make(http.Header), + }, nil }) + + client, err := NewHTTPClient( + SetHttpClient(httpClient), + SetAPIKey("my-awesome-api-key"), + SetURLs("http://primary.foo", "http://secondary1.foo", "http://secondary2.foo"), + SetReadPreference(PrimaryPreferred), + SetMaxRetries(1), + SetTopologyDiscovery(false), + SetHealthchecks(false), + ) + require.NoError(t, err) + + // force all endpoints to get marked as dead + _, err = client.callAny("GET", "/events", nil) + require.Error(t, err) + require.False(t, client.topology.HasActiveEndpoint()) + + // try to revive them + client.healthCheck(5 * time.Second) + time.Sleep(1 * time.Second) + require.True(t, client.topology.HasActiveEndpoint()) } func TestAddSuccess(t *testing.T) { + log.SetLogger("TestAddSuccess", log.SILENT) + event := "Hello world!" 
snap := &protocol.Snapshot{ HistoryDigest: []byte("history"), @@ -71,7 +258,7 @@ func TestAddSuccess(t *testing.T) { serverURL, tearDown := setupServer(input) defer tearDown() - client := setupClient([]string{serverURL}) + client := setupClient(t, []string{serverURL}) snapshot, err := client.Add(event) assert.NoError(t, err) @@ -79,9 +266,12 @@ func TestAddSuccess(t *testing.T) { } func TestAddWithServerFailure(t *testing.T) { + + log.SetLogger("TestAddWithServerFailure", log.SILENT) + serverURL, tearDown := setupServer(nil) defer tearDown() - client := setupClient([]string{serverURL}) + client := setupClient(t, []string{serverURL}) event := "Hello world!" _, err := client.Add(event) @@ -89,6 +279,9 @@ func TestAddWithServerFailure(t *testing.T) { } func TestMembership(t *testing.T) { + + log.SetLogger("TestMembership", log.SILENT) + event := "Hello world!" version := uint64(0) fakeResult := &protocol.MembershipResult{ @@ -105,7 +298,7 @@ func TestMembership(t *testing.T) { serverURL, tearDown := setupServer(inputJSON) defer tearDown() - client := setupClient([]string{serverURL}) + client := setupClient(t, []string{serverURL}) result, err := client.Membership([]byte(event), version) assert.NoError(t, err) @@ -114,6 +307,8 @@ func TestMembership(t *testing.T) { func TestDigestMembership(t *testing.T) { + log.SetLogger("TestDigestMembership", log.SILENT) + event := "Hello world!" 
version := uint64(0) fakeResult := &protocol.MembershipResult{ @@ -130,7 +325,7 @@ func TestDigestMembership(t *testing.T) { serverURL, tearDown := setupServer(inputJSON) defer tearDown() - client := setupClient([]string{serverURL}) + client := setupClient(t, []string{serverURL}) result, err := client.MembershipDigest([]byte("digest"), version) assert.NoError(t, err) @@ -138,9 +333,12 @@ func TestDigestMembership(t *testing.T) { } func TestMembershipWithServerFailure(t *testing.T) { + + log.SetLogger("TestMembershipWithServerFailure", log.SILENT) + serverURL, tearDown := setupServer(nil) defer tearDown() - client := setupClient([]string{serverURL}) + client := setupClient(t, []string{serverURL}) event := "Hello world!" @@ -150,6 +348,8 @@ func TestMembershipWithServerFailure(t *testing.T) { func TestIncremental(t *testing.T) { + log.SetLogger("TestIncremental", log.SILENT) + start := uint64(2) end := uint64(8) fakeResult := &protocol.IncrementalResponse{ @@ -162,7 +362,7 @@ func TestIncremental(t *testing.T) { serverURL, tearDown := setupServer(inputJSON) defer tearDown() - client := setupClient([]string{serverURL}) + client := setupClient(t, []string{serverURL}) result, err := client.Incremental(start, end) assert.NoError(t, err) @@ -170,9 +370,12 @@ func TestIncremental(t *testing.T) { } func TestIncrementalWithServerFailure(t *testing.T) { + + log.SetLogger("TestIncrementalWithServerFailure", log.SILENT) + serverURL, tearDown := setupServer(nil) defer tearDown() - client := setupClient([]string{serverURL}) + client := setupClient(t, []string{serverURL}) _, err := client.Incremental(uint64(2), uint64(8)) assert.Error(t, err) diff --git a/client/config.go b/client/config.go index b444d9d94..3078e37c9 100644 --- a/client/config.go +++ b/client/config.go @@ -16,6 +16,82 @@ package client +import ( + "time" +) + +// ReadPref specifies the preferred type of node in the cluster +// to send request to. 
type ReadPref int

const (
	// Primary restricts reads to the primary node (the leader).
	Primary ReadPref = iota

	// PrimaryPreferred reads from the primary node (the leader) while it is
	// available.
	//
	// Pick it when the application should normally read from the primary but
	// may tolerate stale reads from secondaries while the primary is down.
	// This gives the application a "read-only mode" during a failover.
	PrimaryPreferred

	// Secondary restricts reads to secondary nodes (replicas).
	Secondary

	// SecondaryPreferred reads from secondary nodes (replicas) when possible.
	//
	// In general, do not use SecondaryPreferred to gain extra read capacity:
	// all members of a cluster carry roughly the same write traffic, so
	// secondaries serve reads at roughly the same rate as the primary.
	// Moreover, although replication is synchronous, there is some delay
	// between an event being replicated to a secondary and the change being
	// applied to its balloon, so a secondary can return stale data.
	SecondaryPreferred

	// Any allows reads from every node in the cluster, the leader included.
	Any
)

const (
	// DefaultTimeout is the default time to wait for a request to QED.
	DefaultTimeout = 10 * time.Second

	// DefaultDialTimeout is the default time to wait for a connection to be
	// established.
	DefaultDialTimeout = 5 * time.Second

	// DefaultHandshakeTimeout is the default time to wait for a handshake
	// negotiation.
	DefaultHandshakeTimeout = 5 * time.Second

	// DefaultInsecure is the default value of Config.Insecure. It is false,
	// so the server's certificate chain and host name are verified unless
	// the caller opts out.
	DefaultInsecure = false

	// DefaultMaxRetries is the default maximum number of retries before
	// giving up on an HTTP request to QED.
	DefaultMaxRetries = 0

	// DefaultHealthCheckEnabled tells whether health checks are enabled by
	// default.
	DefaultHealthCheckEnabled = true

	// DefaultHealthCheckTimeout is the default time the health checker waits
	// for a response from QED.
	DefaultHealthCheckTimeout = 2 * time.Second

	// DefaultTopologyDiscoveryEnabled tells whether topology discovery is
	// enabled by default.
	DefaultTopologyDiscoveryEnabled = true

	// DefaultTopologyDiscoveryTimeout is the default time the discoverer
	// waits for a response from QED.
	DefaultTopologyDiscoveryTimeout = 2 * time.Second

	// off is used to disable timeouts.
	off = -1 * time.Second
)

// Config holds the HTTP client configuration.
type Config struct {
	// Endpoints [host:port,host:port,...] to ask for QED cluster-topology.
	Endpoints []string

	// APIKey is the key used to query the server endpoint.
	APIKey string

	// Insecure disables the verification of the server's certificate chain
	// and host name, which makes MiTM attacks possible.
	Insecure bool

	// Timeout is the time to wait for a request to QED.
	Timeout time.Duration

	// DialTimeout is the time to wait for the connection to be established.
	DialTimeout time.Duration

	// HandshakeTimeout is the time to wait for a handshake negotiation.
	HandshakeTimeout time.Duration

	// ReadPreference controls how the client routes queries to the members
	// of the cluster.
	ReadPreference ReadPref

	// MaxRetries is the maximum number of retries before giving up on an
	// HTTP request to QED.
	MaxRetries int

	// EnableTopologyDiscovery enables discovering the cluster topology when
	// requests fail.
	EnableTopologyDiscovery bool

	// DiscoveryTimeout is the time the discoverer waits for a response from
	// a QED server.
	DiscoveryTimeout time.Duration

	// EnableHealthChecks enables health checks of all endpoints in the
	// current cluster topology.
	EnableHealthChecks bool

	// HealthCheckTimeout is the time the health check waits for a response
	// from a QED server.
	HealthCheckTimeout time.Duration
}

// DefaultConfig returns a new Config populated with the package defaults.
func DefaultConfig() *Config {
	conf := new(Config)

	// Connection target and credentials.
	conf.Endpoints = []string{"127.0.0.1:8800"}
	conf.APIKey = "my-key"
	conf.Insecure = DefaultInsecure

	// Request-level timeouts.
	conf.Timeout = DefaultTimeout
	conf.DialTimeout = DefaultDialTimeout
	conf.HandshakeTimeout = DefaultHandshakeTimeout

	// Routing and resiliency.
	conf.ReadPreference = Primary
	conf.MaxRetries = DefaultMaxRetries
	conf.EnableTopologyDiscovery = DefaultTopologyDiscoveryEnabled
	conf.EnableHealthChecks = DefaultHealthCheckEnabled
	conf.DiscoveryTimeout = DefaultTopologyDiscoveryTimeout
	conf.HealthCheckTimeout = DefaultHealthCheckTimeout

	return conf
}
// nodeType tells which role a node plays in the cluster.
type nodeType int

const (
	primary nodeType = iota
	secondary
)

// endpoint holds the status information of a single node in a cluster.
// All accessors and mutators take the embedded lock, so an endpoint can be
// shared between goroutines.
type endpoint struct {
	sync.RWMutex
	url       string // [scheme://host:port]
	nodeType  nodeType
	failures  int        // failed requests since the last MarkAsHealthy
	dead      bool       // true while the endpoint must not be selected
	deadSince *time.Time // UTC time of the first failure; nil when healthy
}

// newEndpoint creates a new endpoint for the given URL [scheme://host:port].
// The endpoint starts alive, with no recorded failures.
func newEndpoint(url string, nodeType nodeType) *endpoint {
	return &endpoint{
		url:      url,
		nodeType: nodeType,
	}
}

// String returns a representation of the endpoint status.
func (e *endpoint) String() string {
	e.RLock()
	defer e.RUnlock()
	return fmt.Sprintf("%s [type=%v,dead=%v,failures=%d,deadSince=%v]", e.url, e.nodeType, e.dead, e.failures, e.deadSince)
}

// URL returns the url string of this endpoint.
func (e *endpoint) URL() string {
	e.RLock()
	defer e.RUnlock()
	return e.url
}

// IsPrimary returns true if the node type is primary.
func (e *endpoint) IsPrimary() bool {
	e.RLock()
	defer e.RUnlock()
	return e.nodeType == primary
}

// IsDead returns true if this endpoint is marked as dead, i.e. a previous
// request to the url has been unsuccessful.
func (e *endpoint) IsDead() bool {
	e.RLock()
	defer e.RUnlock()
	return e.dead
}

// MarkAsDead marks this endpoint as dead and increments the failures
// counter. The current UTC time is stored in deadSince only on the first
// failure, so repeated calls keep the original timestamp.
func (e *endpoint) MarkAsDead() {
	e.Lock()
	defer e.Unlock()
	e.dead = true
	if e.deadSince == nil {
		utcNow := time.Now().UTC()
		e.deadSince = &utcNow
	}
	e.failures++
}

// MarkAsAlive marks this endpoint as eligible to be returned from the
// pool of endpoints by the selector. The failures counter and deadSince
// are left untouched; only MarkAsHealthy resets them.
func (e *endpoint) MarkAsAlive() {
	e.Lock()
	defer e.Unlock()
	e.dead = false
}

// MarkAsHealthy marks this endpoint as healthy, i.e. a request has been
// successfully performed with it. It clears the dead flag, deadSince and
// the failures counter.
func (e *endpoint) MarkAsHealthy() {
	e.Lock()
	defer e.Unlock()
	e.dead = false
	e.deadSince = nil
	e.failures = 0
}

var (
	// ErrNoEndpoint is raised when no QED node is available.
	ErrNoEndpoint = errors.New("no QED node available")

	// ErrNoPrimary is raised when no QED primary node is available.
	ErrNoPrimary = errors.New("no QED primary node available")

	// ErrRetry is raised when a request cannot be executed after
	// the configured number of retries.
	ErrRetry = errors.New("cannot connect after several retries")

	// ErrTimeout is raised when a request timed out.
	ErrTimeout = errors.New("timeout")
)
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package client + +import ( + "crypto/tls" + "errors" + "net" + "net/http" + "time" +) + +// HTTPClientOptionF is a function that configures an HTTPClient. +type HTTPClientOptionF func(*HTTPClient) error + +func configToOptions(conf *Config) ([]HTTPClientOptionF, error) { + var options []HTTPClientOptionF + if conf != nil { + options = []HTTPClientOptionF{ + SetAPIKey(conf.APIKey), + SetReadPreference(conf.ReadPreference), + SetMaxRetries(conf.MaxRetries), + SetTopologyDiscovery(conf.EnableTopologyDiscovery), + SetDiscoveryTimeout(conf.DiscoveryTimeout), + SetHealthchecks(conf.EnableHealthChecks), + SetHealthCheckTimeout(conf.HealthCheckTimeout), + } + if len(conf.Endpoints) > 0 { + options = append(options, SetURLs(conf.Endpoints[0], conf.Endpoints[1:]...)) + } + + defaultTransport := http.DefaultTransport.(*http.Transport) + options = append(options, SetHttpClient(&http.Client{ + Timeout: conf.Timeout, + Transport: &http.Transport{ + Dial: (&net.Dialer{ + Timeout: conf.DialTimeout, + }).Dial, + Proxy: defaultTransport.Proxy, + DialContext: defaultTransport.DialContext, + MaxIdleConns: defaultTransport.MaxIdleConns, + IdleConnTimeout: defaultTransport.IdleConnTimeout, + ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, + TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Insecure}, + TLSHandshakeTimeout: conf.HandshakeTimeout, + }, + })) + } + return options, nil +} + +func SetHttpClient(client *http.Client) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.httpClient = client + return nil + } +} + +func SetURLs(primary string, 
secondaries ...string) HTTPClientOptionF { + return func(c *HTTPClient) error { + if len(primary) > 0 { + c.topology.Update(primary, secondaries...) + return nil + } + return errors.New("Cannot use empty string for the primary url") + } +} + +func SetAttemptToReviveEndpoints(value bool) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.topology.attemptToRevive = value + return nil + } +} + +func SetRequestRetrier(retrier RequestRetrier) HTTPClientOptionF { + return func(c *HTTPClient) error { + if retrier != nil { + c.retrier = retrier + return nil + } + return errors.New("The request retrier cannot be nil") + } +} + +func SetAPIKey(key string) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.apiKey = key + return nil + } +} + +func SetReadPreference(preference ReadPref) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.readPreference = preference + return nil + } +} + +func SetMaxRetries(retries int) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.maxRetries = retries + return nil + } +} + +func SetTopologyDiscovery(enable bool) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.discoveryEnabled = enable + return nil + } +} + +func SetDiscoveryTimeout(seconds time.Duration) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.discoveryTimeout = seconds + return nil + } +} + +func SetHealthchecks(enable bool) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.healthcheckEnabled = enable + return nil + } +} + +func SetHealthCheckTimeout(seconds time.Duration) HTTPClientOptionF { + return func(c *HTTPClient) error { + c.healthcheckTimeout = seconds + return nil + } +} diff --git a/client/request.go b/client/request.go new file mode 100644 index 000000000..0955da7a2 --- /dev/null +++ b/client/request.go @@ -0,0 +1,60 @@ +/* + Copyright 2018 Banco Bilbao Vizcaya Argentaria, S.A. 
// ReaderFunc returns a fresh reader positioned at the start of a request
// body. It is the type of function stored in a RetriableRequest to rewind
// the payload between attempts.
type ReaderFunc func() (io.Reader, error)

// RetriableRequest wraps the metadata needed to create HTTP requests
// and allows reusing it between retries.
type RetriableRequest struct {
	// body produces a new reader over the request payload. It is used to
	// rewind the request data in between retries and is nil when the
	// request carries no payload.
	body ReaderFunc

	// Embed an HTTP request directly. This makes a *RetriableRequest act
	// exactly like an *http.Request so that all meta methods are supported.
	*http.Request
}

// NewRetriableRequest creates a new retriable request for the given method
// and url. A nil rawBody produces a request without payload.
//
// When rawBody is not nil the embedded request is built over it, so Body,
// GetBody and ContentLength are always consistent: the request can be sent
// directly and its payload can be replayed on 307/308 redirects. An error
// is returned when http.NewRequest rejects the method or the url.
func NewRetriableRequest(method, url string, rawBody []byte) (*RetriableRequest, error) {

	var body ReaderFunc
	// reader stays a nil interface for a nil payload, so http.NewRequest
	// sees "no body" rather than a typed nil reader.
	var reader io.Reader

	if rawBody != nil {
		body = func() (io.Reader, error) {
			return bytes.NewReader(rawBody), nil
		}
		reader = bytes.NewReader(rawBody)
	}

	// For a *bytes.Reader, http.NewRequest fills in ContentLength and
	// GetBody on its own.
	httpReq, err := http.NewRequest(method, url, reader)
	if err != nil {
		return nil, err
	}

	return &RetriableRequest{body: body, Request: httpReq}, nil
}
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package client + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "time" + + "github.com/bbva/qed/log" +) + +var ( + // We need to consume response bodies to maintain http connections, but + // limit the size we consume to respReadLimit. + respReadLimit = int64(4096) +) + +// RequestRetrier decides whether to retry a failed HTTP request. +type RequestRetrier interface { + // DoReq executes the given request and if fails, it decides whether to retry + // the call, how long to wait for the next call, or whether to return an + // error (which will be returned to the service that started the HTTP + // request in the first place). + DoReq(req *RetriableRequest) (*http.Response, error) +} + +// NoRequestRetrier is an implementation that does no retries. +type NoRequestRetrier struct { + *http.Client +} + +// NewNoRequestRetrier returns a retrier that does no retries. +func NewNoRequestRetrier(httpClient *http.Client) *NoRequestRetrier { + return &NoRequestRetrier{Client: httpClient} +} + +func (r *NoRequestRetrier) DoReq(req *RetriableRequest) (*http.Response, error) { + // always rewind + if req.body != nil { + body, err := req.body() + if err != nil { + return nil, err + } + if c, ok := body.(io.ReadCloser); ok { + req.Request.Body = c + } else { + req.Request.Body = ioutil.NopCloser(body) + } + } + resp, err := r.Do(req.Request) + // Check the response code. 
We retry on 500-range responses to allow + // the server time to recover, as 500's are typically not permanent + // errors and may relate to outages on the server side. This will catch + // invalid reponse codes as well, like 0. + if err == nil && resp.StatusCode > 0 && resp.StatusCode < 500 { + return resp, nil + } + return nil, fmt.Errorf("%s %s: giving up after %d attempts", + req.Method, req.URL, 1) +} + +// BackoffRequestRetrier is an implementation that uses the given backoff strategy. +type BackoffRequestRetrier struct { + *http.Client + maxRetries int + backoff Backoff +} + +// NewBackoffRequestRetrier returns a retrier that uses the given backoff strategy. +func NewBackoffRequestRetrier(httpClient *http.Client, maxRetries int, backoff Backoff) *BackoffRequestRetrier { + return &BackoffRequestRetrier{ + Client: httpClient, + maxRetries: maxRetries, + backoff: backoff, + } +} + +func (r *BackoffRequestRetrier) DoReq(req *RetriableRequest) (*http.Response, error) { + + var resp *http.Response + var err error + + for i := 0; ; i++ { + + var code int // HTTP response status code + + // always rewind + if req.body != nil { + body, err := req.body() + if err != nil { + return resp, err + } + if c, ok := body.(io.ReadCloser); ok { + req.Request.Body = c + } else { + req.Request.Body = ioutil.NopCloser(body) + } + } + + // attempt the request + resp, err = r.Do(req.Request) + if resp != nil { + code = resp.StatusCode + } + if err != nil { + log.Infof("%s %s request failed: %v", req.Method, req.URL, err) + } + + // Check the response code. We retry on 500-range responses to allow + // the server time to recover, as 500's are typically not permanent + // errors and may relate to outages on the server side. This will catch + // invalid reponse codes as well, like 0. 
+ if err == nil && resp.StatusCode > 0 && resp.StatusCode < 500 { + return resp, nil + } + + // we decide to continue with retrying + + // We do this before drainBody beause there's no need for the I/O if + // we're breaking out + remain := r.maxRetries - i + if remain <= 0 { + break + } + + // We're going to retry, consume any response to reuse the connection. + if err == nil && resp != nil { + // drain body + _, err := io.Copy(ioutil.Discard, io.LimitReader(resp.Body, respReadLimit)) + if err != nil { + log.Infof("Error reading response body: %v", err) + } + } + + wait, goahead := r.backoff.Next(i) + if !goahead { + break + } + + desc := fmt.Sprintf("%s %s", req.Method, req.URL) + if code > 0 { + desc = fmt.Sprintf("%s (status: %d)", desc, code) + } + log.Infof("%s: retrying in %s (%d left)", desc, wait, remain) + + time.Sleep(wait) + + } + + // By default, we close the response body and return an error without + // returning the response + if resp != nil { + resp.Body.Close() + } + + return nil, fmt.Errorf("%s %s giving up after %d attempts", + req.Method, req.URL, r.maxRetries+1) + +} diff --git a/client/retrier_test.go b/client/retrier_test.go new file mode 100644 index 000000000..072bfc637 --- /dev/null +++ b/client/retrier_test.go @@ -0,0 +1,134 @@ +/* + Copyright 2018 Banco Bilbao Vizcaya Argentaria, S.A. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package client + +import ( + "bytes" + "errors" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNoRequestRetrier(t *testing.T) { + var numFailedReqs int + fail := func(req *http.Request) (*http.Response, error) { + numFailedReqs++ + return nil, errors.New("request failed") + } + + httpClient := NewTestHttpClient( + NewFailingTransport("/fail", fail, nil), + ) + retrier := NewNoRequestRetrier(httpClient) + + req, err := NewRetriableRequest("GET", "http://foo.bar/fail", nil) + require.NoError(t, err) + + resp, err := retrier.DoReq(req) + require.Error(t, err) + require.Nil(t, resp) + require.Equal(t, 1, numFailedReqs, "The expected number of failed requests does not match") + +} + +func TestBackoffRequestRetrier(t *testing.T) { + var numFailedReqs int + fail := func(req *http.Request) (*http.Response, error) { + numFailedReqs++ + return nil, errors.New("request failed") + } + + httpClient := NewTestHttpClient( + NewFailingTransport("/fail", fail, nil), + ) + maxRetries := 5 + retrier := NewBackoffRequestRetrier(httpClient, maxRetries, + NewSimpleBackoff(100, 100, 100, 100, 100)) + + req, err := NewRetriableRequest("GET", "http://foo.bar/fail", nil) + require.NoError(t, err) + + resp, err := retrier.DoReq(req) + require.Error(t, err) + require.Nil(t, resp) + require.Equal(t, maxRetries+1, numFailedReqs, "The expected number of failed requests does not match") + +} + +func TestBackoffRequestRetrierWithStatus(t *testing.T) { + var numFailedReqs int + serverFail := func(req *http.Request) (*http.Response, error) { + numFailedReqs++ + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(bytes.NewBufferString("Internal server error")), + // Must be set to non-nil value or it panics + Header: make(http.Header), + }, nil + } + + httpClient := NewTestHttpClient(serverFail) + maxRetries := 5 + retrier := NewBackoffRequestRetrier(httpClient, maxRetries, + NewSimpleBackoff(100, 100, 100, 100, 100)) + + req, err 
:= NewRetriableRequest("GET", "http://foo.bar/fail", nil) + require.NoError(t, err) + + resp, err := retrier.DoReq(req) + require.Error(t, err) + require.Nil(t, resp) + require.Equal(t, maxRetries+1, numFailedReqs, "The expected number of failed requests does not match") +} + +func TestBackoffRequestRetrierTwoRetries(t *testing.T) { + var numFailedReqs int + serverTemporaryFail := func(req *http.Request) (*http.Response, error) { + numFailedReqs++ + if numFailedReqs > 1 { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("OK")), + // Must be set to non-nil value or it panics + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(bytes.NewBufferString("Internal server error")), + // Must be set to non-nil value or it panics + Header: make(http.Header), + }, nil + } + + httpClient := NewTestHttpClient(serverTemporaryFail) + maxRetries := 5 + retrier := NewBackoffRequestRetrier(httpClient, maxRetries, + NewSimpleBackoff(100, 100, 100, 100, 100)) + + req, err := NewRetriableRequest("GET", "http://foo.bar/fail", nil) + require.NoError(t, err) + + resp, err := retrier.DoReq(req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 2, numFailedReqs, "The expected number of failed requests does not match") + +} diff --git a/client/test_util.go b/client/test_util.go new file mode 100644 index 000000000..3762e6fe9 --- /dev/null +++ b/client/test_util.go @@ -0,0 +1,48 @@ +/* + Copyright 2018 Banco Bilbao Vizcaya Argentaria, S.A. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// RoundTripFunc adapts a plain function to the http.RoundTripper interface.
type RoundTripFunc func(req *http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by calling f itself.
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}

// NewTestHttpClient returns an *http.Client whose Transport is fn, so that
// no real calls are made.
func NewTestHttpClient(fn RoundTripFunc) *http.Client {
	client := new(http.Client)
	client.Transport = fn
	return client
}

// NewFailingTransport returns a transport that runs the fail callback for
// requests whose URL path starts with path. Every other request, and every
// request when fail is nil, goes to next, or to http.DefaultTransport when
// next is nil.
func NewFailingTransport(path string, fail RoundTripFunc, next http.RoundTripper) RoundTripFunc {
	return func(req *http.Request) (*http.Response, error) {
		switch {
		case fail != nil && strings.HasPrefix(req.URL.Path, path):
			return fail(req)
		case next != nil:
			return next.RoundTrip(req)
		default:
			return http.DefaultTransport.RoundTrip(req)
		}
	}
}
+*/ + +package client + +import ( + "sync" +) + +// QED cluster encapsulation +type topology struct { + endpoints []*endpoint + primary *endpoint + cIndex int // index into endpoints (for round-robin) + attemptToRevive bool + sync.RWMutex +} + +func newTopology(attemptToRevive bool) *topology { + return &topology{ + endpoints: make([]*endpoint, 0), + cIndex: -1, + primary: nil, + attemptToRevive: attemptToRevive, + } +} + +func (t *topology) Update(primaryNode string, secondaries ...string) { + t.Lock() + defer t.Unlock() + + // Build up new endpoints. + // If we find an existing endpoint, then use it including + // the previous number of failures, etc. + var newEndpoints []*endpoint + if primaryNode != "" { + t.primary = newEndpoint(primaryNode, primary) + newEndpoints = append(newEndpoints, t.primary) + } + + for _, url := range secondaries { + var found bool + for _, oldEndpoint := range t.endpoints { + if oldEndpoint.url == url { + // Take over the old endpoint + newEndpoints = append(newEndpoints, oldEndpoint) + found = true + break + } + } + if !found && url != "" { + // New endpoint didn't exist, so add it to our list of new endpoints. 
+ newEndpoints = append(newEndpoints, newEndpoint(url, secondary)) + } + } + t.endpoints = newEndpoints + t.cIndex = -1 +} + +func (t *topology) Primary() (*endpoint, error) { + t.Lock() + defer t.Unlock() + + if t.primary == nil { + return nil, ErrNoPrimary + } + return t.primary, nil +} + +func (t *topology) Endpoints() []*endpoint { + t.Lock() + defer t.Unlock() + return t.endpoints +} + +// NextReadendpoint returns the next available endpoint to query +// in a round-robin manner, or ErrNoEndpoint +func (t *topology) NextReadEndpoint(pref ReadPref) (*endpoint, error) { + + t.Lock() + defer t.Unlock() + + switch pref { + + case PrimaryPreferred: + if t.primary != nil && !t.primary.IsDead() { + return t.primary, nil + } + fallthrough + + case Secondary: + i := 0 + numEndpoints := len(t.endpoints) + if numEndpoints > 0 { + for { + if i > numEndpoints { + break // we visited all endpoints and they all seem to be dead + } + t.cIndex++ + if t.cIndex >= numEndpoints { + t.cIndex = 0 + } + endpoint := t.endpoints[t.cIndex] + if endpoint.nodeType == secondary && !endpoint.IsDead() { + return endpoint, nil + } + i++ + } + } + break + + case SecondaryPreferred: + i := 0 + numEndpoints := len(t.endpoints) + if numEndpoints > 0 { + for { + if i > numEndpoints { + break // we visited all endpoints and they all seem to be dead + } + t.cIndex++ + if t.cIndex >= numEndpoints { + t.cIndex = 0 + } + endpoint := t.endpoints[t.cIndex] + if endpoint.nodeType == secondary && !endpoint.IsDead() { + return endpoint, nil + } + i++ + } + } + fallthrough + + case Primary: + if t.primary != nil && !t.primary.IsDead() { + return t.primary, nil + } + break + + case Any: + i := 0 + numEndpoints := len(t.endpoints) + if numEndpoints > 0 { + for { + if i > numEndpoints { + break // we visited all endpoints and they all seem to be dead + } + t.cIndex++ + if t.cIndex >= numEndpoints { + t.cIndex = 0 + } + endpoint := t.endpoints[t.cIndex] + if !endpoint.IsDead() { + return endpoint, nil + } + i++ 
+ } + } + break + } + + // Now all nodes are marked as dead. If attemptToRevive is disabled, + // endpoints will never be marked alive again, so we need to + // mark all of them as alive. This way, they will be picked up + // in the next call to performRequest. + if t.attemptToRevive { + for _, endpoint := range t.endpoints { + endpoint.MarkAsAlive() + } + } + + // we tried every endpoint but there is no one available + return nil, ErrNoEndpoint +} + +// HasActivePrimary returns true if there is an active primary endpoint. +func (t *topology) HasActivePrimary() bool { + t.Lock() + defer t.Unlock() + return t.primary != nil && !t.primary.IsDead() +} + +// HasActiveEndpoint returns true there is an active endpoint (primary +// or secondary). +func (t *topology) HasActiveEndpoint() bool { + t.Lock() + defer t.Unlock() + for _, e := range t.endpoints { + if !e.IsDead() { + return true + } + } + return false +} diff --git a/client/topology_test.go b/client/topology_test.go new file mode 100644 index 000000000..e6f0151f4 --- /dev/null +++ b/client/topology_test.go @@ -0,0 +1,262 @@ +/* + Copyright 2018 Banco Bilbao Vizcaya Argentaria, S.A. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package client + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTopologyUpdate(t *testing.T) { + topology := newTopology(false) + require.Empty(t, topology.Endpoints(), "The list of endpoints should be empty") + + topology.Update( + "http://primary:8080", + "http://secondary1:8080", + "http://secondary2:8080", + ) + + endpoints := topology.Endpoints() + expectedEndpoints := []*endpoint{ + newEndpoint("http://primary:8080", primary), + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + } + require.ElementsMatch(t, expectedEndpoints, endpoints, "The endpoints should match") +} + +func TestTopologyPrimary(t *testing.T) { + + topology := newTopology(false) + endpoint, err := topology.Primary() + require.Nil(t, endpoint) + require.Error(t, err) + + topology.Update("http://primary:8080") + endpoint, err = topology.Primary() + require.NoError(t, err) + require.Equalf(t, primary, endpoint.nodeType, "The type of node should match") + require.Equalf(t, "http://primary:8080", endpoint.URL(), "The URL should match") + +} + +func TestTopologyNextReadEndpoint(t *testing.T) { + + testCases := []struct { + primary string + secondaries []string + readPref ReadPref + expectError bool + rounds int + expectedEndpoints []*endpoint + }{ + { + // Preference=Primary with existent primary node + primary: "http://primary:8080", + secondaries: []string{ + "http://secondary1:8080", + "http://secondary2:8080", + }, + readPref: Primary, + expectError: false, + rounds: 4, + expectedEndpoints: []*endpoint{ + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + }, + }, + { + // Preference=Primary with non-existent primary node + primary: "", + secondaries: []string{ + "http://secondary1:8080", + "http://secondary2:8080", + }, + readPref: Primary, + expectError: true, + 
rounds: 4, + expectedEndpoints: []*endpoint{}, + }, + { + // Preference=PrimaryPreferred with existent primary node + primary: "http://primary:8080", + secondaries: []string{ + "http://secondary1:8080", + "http://secondary2:8080", + }, + readPref: PrimaryPreferred, + expectError: false, + rounds: 4, + expectedEndpoints: []*endpoint{ + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + }, + }, + { + // Preference=PrimaryPreferred with non-existent primary node + primary: "", + secondaries: []string{ + "http://secondary1:8080", + "http://secondary2:8080", + }, + readPref: PrimaryPreferred, + expectError: false, + rounds: 4, + expectedEndpoints: []*endpoint{ + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + }, + }, + { + // Preference=Secondary with existent secondary nodes + primary: "http://primary:8080", + secondaries: []string{ + "http://secondary1:8080", + "http://secondary2:8080", + }, + readPref: Secondary, + expectError: false, + rounds: 4, + expectedEndpoints: []*endpoint{ + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + }, + }, + { + // Preference=Secondary with non-existent secondary nodes + primary: "http://primary:8080", + secondaries: []string{}, + readPref: Secondary, + expectError: true, + rounds: 4, + expectedEndpoints: []*endpoint{}, + }, + { + // Preference=SecondaryPreferred with existent secondary nodes + primary: "http://primary:8080", + secondaries: []string{ + "http://secondary1:8080", + "http://secondary2:8080", + }, + readPref: SecondaryPreferred, + expectError: false, + rounds: 4, + 
expectedEndpoints: []*endpoint{ + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + }, + }, + { + // Preference=SecondaryPreferred with non-existent secondary nodes + primary: "http://primary:8080", + secondaries: []string{}, + readPref: SecondaryPreferred, + expectError: false, + rounds: 4, + expectedEndpoints: []*endpoint{ + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + newEndpoint("http://primary:8080", primary), + }, + }, + { + // Preference=SecondaryPreferred with nor existent primary neither secondary nodes + primary: "", + secondaries: []string{}, + readPref: SecondaryPreferred, + expectError: true, + rounds: 4, + expectedEndpoints: []*endpoint{}, + }, + { + // Preference=Any with both existent primary and secondary nodes + primary: "http://primary:8080", + secondaries: []string{ + "http://secondary1:8080", + "http://secondary2:8080", + }, + readPref: Any, + expectError: false, + rounds: 4, + expectedEndpoints: []*endpoint{ + newEndpoint("http://primary:8080", primary), + newEndpoint("http://secondary1:8080", secondary), + newEndpoint("http://secondary2:8080", secondary), + newEndpoint("http://primary:8080", primary), + }, + }, + { + // Preference=Any with non-existent nodes + primary: "", + secondaries: []string{}, + readPref: Any, + expectError: true, + rounds: 4, + expectedEndpoints: []*endpoint{}, + }, + } + + for i, c := range testCases { + topology := newTopology(false) + topology.Update(c.primary, c.secondaries...) 
+ collectedEndpoints := make([]*endpoint, 0) + round := 0 + for { + if round >= c.rounds { + break + } + round++ + endpoint, err := topology.NextReadEndpoint(c.readPref) + if c.expectError { + require.NotNil(t, err, "Should return error for test case %d", i) + break + } + collectedEndpoints = append(collectedEndpoints, endpoint) + } + require.ElementsMatch(t, c.expectedEndpoints, collectedEndpoints, "The endpoints should match for test case %d", i) + } + +} + +func TestTopologyHasActivePrimary(t *testing.T) { + topology := newTopology(false) + require.False(t, topology.HasActivePrimary()) + + topology.Update("http://primary:8080") + require.True(t, topology.HasActivePrimary()) +} + +func TestTopologyHashActiveEndpoint(t *testing.T) { + topology := newTopology(false) + require.False(t, topology.HasActiveEndpoint()) + + topology.Update("http://primary:8080", "http://secondary1:8080") + require.True(t, topology.HasActiveEndpoint()) +} diff --git a/cmd/client.go b/cmd/client.go index ca663aec0..4f801fe2b 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -17,6 +17,9 @@ package cmd import ( + "fmt" + "time" + "github.com/spf13/cobra" v "github.com/spf13/viper" @@ -37,9 +40,9 @@ func newClientCommand(ctx *cmdContext) *cobra.Command { f := cmd.PersistentFlags() f.StringSliceVarP(&clientCtx.config.Endpoints, "endpoints", "e", []string{"127.0.0.1:8800"}, "Endpoint for REST requests on (host:port)") f.BoolVar(&clientCtx.config.Insecure, "insecure", false, "Allow self signed certificates") - f.IntVar(&clientCtx.config.TimeoutSeconds, "timeout-seconds", 10, "Seconds to cut the connection") - f.IntVar(&clientCtx.config.DialTimeoutSeconds, "dial-timeout-seconds", 5, "Seconds to cut the dialing") - f.IntVar(&clientCtx.config.HandshakeTimeoutSeconds, "handshake-timeout-seconds", 5, "Seconds to cut the handshaking") + f.DurationVar(&clientCtx.config.Timeout, "timeout-seconds", 10*time.Second, "Seconds to cut the connection") + f.DurationVar(&clientCtx.config.DialTimeout, 
"dial-timeout-seconds", 5*time.Second, "Seconds to cut the dialing") + f.DurationVar(&clientCtx.config.HandshakeTimeout, "handshake-timeout-seconds", 5*time.Second, "Seconds to cut the handshaking") // Lookups v.BindPFlag("client.endpoints", f.Lookup("endpoints")) @@ -49,17 +52,25 @@ func newClientCommand(ctx *cmdContext) *cobra.Command { v.BindPFlag("client.timeout.handshake", f.Lookup("handshake-timeout-seconds")) clientPreRun := func(cmd *cobra.Command, args []string) { + log.SetLogger("QEDClient", ctx.logLevel) clientCtx.config.APIKey = ctx.apiKey clientCtx.config.Endpoints = v.GetStringSlice("client.endpoints") clientCtx.config.Insecure = v.GetBool("client.insecure") - clientCtx.config.TimeoutSeconds = v.GetInt("client.timeout.connection") - clientCtx.config.DialTimeoutSeconds = v.GetInt("client.timeout.dial") - clientCtx.config.HandshakeTimeoutSeconds = v.GetInt("client.timeout.handshake") - - clientCtx.client = client.NewHTTPClient(*clientCtx.config) - + clientCtx.config.Timeout = v.GetDuration("client.timeout.connection") + clientCtx.config.DialTimeout = v.GetDuration("client.timeout.dial") + clientCtx.config.HandshakeTimeout = v.GetDuration("client.timeout.handshake") + clientCtx.config.ReadPreference = client.Any + clientCtx.config.EnableTopologyDiscovery = false + clientCtx.config.EnableHealthChecks = false + clientCtx.config.MaxRetries = 0 + + client, err := client.NewHTTPClientFromConfig(clientCtx.config) + if err != nil { + panic(fmt.Sprintf("Unable to start http client: %v", err)) + } + clientCtx.client = client } cmd.AddCommand( diff --git a/deploy/aws/network.tf b/deploy/aws/network.tf index fdfce3cc5..0fe9ae05f 100644 --- a/deploy/aws/network.tf +++ b/deploy/aws/network.tf @@ -144,6 +144,12 @@ module "security_group" { protocol = "tcp" cidr_blocks = "${chomp(data.http.ip.body)}/32" }, + { + from_port = 7700 + to_port = 7700 + protocol = "tcp" + cidr_blocks = "${chomp(data.http.ip.body)}/32" + }, { from_port = 9100 to_port = 9100 diff --git 
a/deploy/aws/provision/main.yml b/deploy/aws/provision/main.yml index 8c04b10d6..f6b00bb8f 100644 --- a/deploy/aws/provision/main.yml +++ b/deploy/aws/provision/main.yml @@ -167,6 +167,7 @@ with_items: - /var/qed/riot-stop.sh - /var/qed/riot-start.sh + - name: wait for raised riot port wait_for: port: 7700 host: "{{ansible_hostname}}" diff --git a/deploy/aws/provision/templates/riot-start.j2 b/deploy/aws/provision/templates/riot-start.sh.j2 similarity index 100% rename from deploy/aws/provision/templates/riot-start.j2 rename to deploy/aws/provision/templates/riot-start.sh.j2 diff --git a/go.mod b/go.mod index e318a8707..4cb751501 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/bbva/qed -go 1.12 +go 1.12.1 require ( github.com/BurntSushi/toml v0.3.1 // indirect @@ -17,6 +17,7 @@ require ( github.com/imdario/mergo v0.3.7 github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/mitchellh/go-homedir v1.1.0 + github.com/pkg/errors v0.8.0 github.com/prometheus/client_golang v0.9.2 github.com/spf13/cobra v0.0.3 github.com/spf13/viper v1.3.1 diff --git a/gossip/auditor/auditor.go b/gossip/auditor/auditor.go index c6a568d1b..e30e23063 100644 --- a/gossip/auditor/auditor.go +++ b/gossip/auditor/auditor.go @@ -19,6 +19,7 @@ package auditor import ( "bytes" "context" + "crypto/tls" "fmt" "io" "io/ioutil" @@ -32,6 +33,7 @@ import ( "github.com/bbva/qed/hashing" "github.com/bbva/qed/log" "github.com/bbva/qed/protocol" + "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" ) @@ -65,13 +67,25 @@ type Auditor struct { } func NewAuditor(conf Config) (*Auditor, error) { + metrics.QedAuditorInstancesCount.Inc() + + // QED client + transport := http.DefaultTransport.(*http.Transport) + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: false} + httpClient := http.DefaultClient + httpClient.Transport = transport + client, err := client.NewHTTPClient( + client.SetHttpClient(httpClient), + client.SetURLs(conf.QEDUrls[0], 
conf.QEDUrls[1:]...), + client.SetAPIKey(conf.APIKey), + ) + if err != nil { + return nil, errors.Wrap(err, "Cannot start http client: ") + } + auditor := Auditor{ - qed: client.NewHTTPClient(client.Config{ - Endpoints: conf.QEDUrls, - APIKey: conf.APIKey, - Insecure: false, - }), + qed: client, conf: conf, taskCh: make(chan Task, 100), quitCh: make(chan bool), diff --git a/gossip/monitor/monitor.go b/gossip/monitor/monitor.go index 6629e0c24..1532214da 100644 --- a/gossip/monitor/monitor.go +++ b/gossip/monitor/monitor.go @@ -19,6 +19,7 @@ package monitor import ( "bytes" "context" + "crypto/tls" "fmt" "io" "io/ioutil" @@ -32,6 +33,7 @@ import ( "github.com/bbva/qed/hashing" "github.com/bbva/qed/log" "github.com/bbva/qed/protocol" + "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" ) @@ -68,12 +70,22 @@ func NewMonitor(conf Config) (*Monitor, error) { // Metrics metrics.QedMonitorInstancesCount.Inc() + // QED client + transport := http.DefaultTransport.(*http.Transport) + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: false} + httpClient := http.DefaultClient + httpClient.Transport = transport + client, err := client.NewHTTPClient( + client.SetHttpClient(httpClient), + client.SetURLs(conf.QEDUrls[0], conf.QEDUrls[1:]...), + client.SetAPIKey(conf.APIKey), + ) + if err != nil { + return nil, errors.Wrap(err, "Cannot start http client: ") + } + monitor := Monitor{ - client: client.NewHTTPClient(client.Config{ - Endpoints: conf.QEDUrls, - APIKey: conf.APIKey, - Insecure: false, - }), + client: client, conf: conf, taskCh: make(chan Task, 100), quitCh: make(chan bool), diff --git a/tests/e2e/add_verify_test.go b/tests/e2e/add_verify_test.go index 94e666653..d82cbe510 100644 --- a/tests/e2e/add_verify_test.go +++ b/tests/e2e/add_verify_test.go @@ -38,7 +38,7 @@ func TestAddVerify(t *testing.T) { var snapshot *protocol.Snapshot var err error - client := getClient(0) + client := getClient(t, 0) let("Add event", func(t *testing.T) { 
snapshot, err = client.Add(event) @@ -77,7 +77,7 @@ func TestAddVerify(t *testing.T) { var err error var first, last *protocol.Snapshot - client := getClient(0) + client := getClient(t, 0) first, err = client.Add("Test event 1") assert.NoError(t, err) @@ -106,7 +106,7 @@ func TestAddVerify(t *testing.T) { var s [size]*protocol.Snapshot - client := getClient(0) + client := getClient(t, 0) for i := 0; i < size; i++ { s[i], _ = client.Add(fmt.Sprintf("Test Event %d", i)) diff --git a/tests/e2e/agents_test.go b/tests/e2e/agents_test.go index d7121ae61..c6095f546 100644 --- a/tests/e2e/agents_test.go +++ b/tests/e2e/agents_test.go @@ -88,7 +88,7 @@ func TestAgentsWithoutTampering(t *testing.T) { var ss *protocol.SignedSnapshot var err error - client := getClient(0) + client := getClient(t, 0) let("Add event", func(t *testing.T) { snapshot, err = client.Add(event) @@ -139,7 +139,7 @@ func TestAgentsDeleteTampering(t *testing.T) { scenario("Add 1st event. Tamper it. Check auditor alerts correctly", func() { var err error - client := getClient(0) + client := getClient(t, 0) let("Add 1st event", func(t *testing.T) { _, err = client.Add(event) @@ -190,7 +190,7 @@ func TestAgentsPatchTampering(t *testing.T) { tampered := rand.RandomString(10) event2 := rand.RandomString(10) - client := getClient(0) + client := getClient(t, 0) let("Add 1st event", func(t *testing.T) { _, err := client.Add(event) diff --git a/tests/e2e/incremental_test.go b/tests/e2e/incremental_test.go index 0c6fadd3c..3acb4b5c5 100644 --- a/tests/e2e/incremental_test.go +++ b/tests/e2e/incremental_test.go @@ -34,7 +34,7 @@ func TestIncrementalConsistency(t *testing.T) { scenario("Add multiple events and verify consistency between two of them", func() { - client := getClient(0) + client := getClient(t, 0) events := make([]string, 10) snapshots := make([]*protocol.Snapshot, 10) diff --git a/tests/e2e/setup.go b/tests/e2e/setup.go index c0e0cfbbb..3a2816fc1 100644 --- a/tests/e2e/setup.go +++ 
b/tests/e2e/setup.go @@ -17,6 +17,7 @@ package e2e import ( + "crypto/tls" "fmt" "io/ioutil" "net/http" @@ -35,6 +36,7 @@ import ( "github.com/bbva/qed/gossip/publisher" "github.com/bbva/qed/server" "github.com/bbva/qed/testutils/scope" + "github.com/pkg/errors" ) const ( @@ -296,10 +298,21 @@ func setupServer(id int, joinAddr string, tls bool, t *testing.T) (scope.TestF, return before, after } -func getClient(id int) *client.HTTPClient { - return client.NewHTTPClient(client.Config{ - Endpoints: []string{fmt.Sprintf("http://127.0.0.1:880%d", id)}, - APIKey: APIKey, - Insecure: false, - }) +func getClient(t *testing.T, id int) *client.HTTPClient { + // QED client + transport := http.DefaultTransport.(*http.Transport) + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: false} + httpClient := http.DefaultClient + httpClient.Transport = transport + client, err := client.NewHTTPClient( + client.SetHttpClient(httpClient), + client.SetURLs(fmt.Sprintf("http://127.0.0.1:880%d", id)), + client.SetAPIKey(APIKey), + client.SetTopologyDiscovery(false), + client.SetHealthchecks(false), + ) + if err != nil { + t.Fatal(errors.Wrap(err, "Cannot start http client: ")) + } + return client } diff --git a/tests/riot.go b/tests/riot.go index 74535dbc1..66506e205 100644 --- a/tests/riot.go +++ b/tests/riot.go @@ -151,6 +151,10 @@ func newRiotCommand() *cobra.Command { }() } + if !APIMode && riot.Config.Kind == "" { + log.Fatal("Argument `kind` is required") + } + }, Run: func(cmd *cobra.Command, args []string) { riot.Start(APIMode) @@ -272,8 +276,13 @@ func newAttack(conf Config) { cConf.APIKey = conf.APIKey cConf.Insecure = conf.Insecure + client, err := client.NewHTTPClientFromConfig(cConf) + if err != nil { + panic(err) + } + attack := Attack{ - client: client.NewHTTPClient(*cConf), + client: client, config: conf, kind: kind(conf.Kind), balloonVersion: uint64(conf.NumRequests + conf.Offset - 1),