Skip to content

Commit

Permalink
[beater] Improve Response Writing (#2494)
Browse files Browse the repository at this point in the history
Change response write handling to set a response per context and then call write on the context. Introduce a shared beatertest package with testing helper structs and methods.

Related to #2489
  • Loading branch information
simitt authored Aug 5, 2019
1 parent 88513e3 commit e0136d5
Show file tree
Hide file tree
Showing 52 changed files with 2,106 additions and 866 deletions.
165 changes: 101 additions & 64 deletions beater/agent_config_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ import (
const (
errMaxAgeDuration = 5 * time.Minute

errMsgInvalidQuery = "invalid query"
errMsgKibanaDisabled = "disabled Kibana configuration"
errMsgKibanaVersionNotCompatible = "not a compatible Kibana version"
errMsgMethodUnsupported = "method not supported"
errMsgNoKibanaConnection = "unable to retrieve connection to Kibana"
errMsgServiceUnavailable = "service unavailable"
msgInvalidQuery = "invalid query"
msgKibanaDisabled = "disabled Kibana configuration"
msgKibanaVersionNotCompatible = "not a compatible Kibana version"
msgMethodUnsupported = "method not supported"
msgNoKibanaConnection = "unable to retrieve connection to Kibana"
msgServiceUnavailable = "service unavailable"
)

var (
errMsgKibanaDisabled = errors.New(msgKibanaDisabled)
errMsgNoKibanaConnection = errors.New(msgNoKibanaConnection)

minKibanaVersion = common.MustNewVersion("7.3.0")
errCacheControl = fmt.Sprintf("max-age=%v, must-revalidate", errMaxAgeDuration.Seconds())

Expand All @@ -59,76 +62,95 @@ var (
//}

// reflects current behavior
countRequest = intakeResultIDToMonitoringInt(request.IDRequestCount)
countRequest = IntakeResultIDToMonitoringInt(request.IDRequestCount)

mapping = map[request.ResultID]*monitoring.Int{
request.IDRequestCount: countRequest,
}
)

func acmResultIDToMonitoringInt(id request.ResultID) *monitoring.Int {
// ACMResultIDToMonitoringInt takes a request.ResultID and maps it to a monitoring counter. If no mapping is found,
// nil is returned.
func ACMResultIDToMonitoringInt(id request.ResultID) *monitoring.Int {
if i, ok := mapping[id]; ok {
return i
}
return nil
}

func agentConfigHandler(kbClient kibana.Client, config *agentConfig, secretToken string) request.Handler {
// AgentConfigHandler returns a request.Handler for managing ACM requests.
func AgentConfigHandler(kbClient kibana.Client, config *AgentConfig) request.Handler {
cacheControl := fmt.Sprintf("max-age=%v, must-revalidate", config.Cache.Expiration.Seconds())
fetcher := agentcfg.NewFetcher(kbClient, config.Cache.Expiration)

return func(c *request.Context) {
sendResp := wrap(c)
sendErr := wrapErr(c, secretToken)
// error handling
c.Header().Set(headers.CacheControl, errCacheControl)

if valid, shortMsg, detailMsg := validateKbClient(kbClient); !valid {
sendErr(http.StatusServiceUnavailable, shortMsg, detailMsg)
if valid := validateKbClient(c, kbClient, c.TokenSet); !valid {
c.Write()
return
}

query, requestErr := buildQuery(c.Request)
if requestErr != nil {
if strings.Contains(requestErr.Error(), errMsgMethodUnsupported) {
sendErr(http.StatusMethodNotAllowed, errMsgMethodUnsupported, requestErr.Error())
return
}
sendErr(http.StatusBadRequest, errMsgInvalidQuery, requestErr.Error())
query, queryErr := buildQuery(c.Request)
if queryErr != nil {
extractQueryError(c, queryErr, c.TokenSet)
c.Write()
return
}

cfg, upstreamEtag, internalErr := fetcher.Fetch(query, nil)
if internalErr != nil {
sendErr(http.StatusServiceUnavailable, internalErrMsg(internalErr.Error()), internalErr.Error())
cfg, upstreamEtag, err := fetcher.Fetch(query, nil)
if err != nil {
extractInternalError(c, err, c.TokenSet)
c.Write()
return
}

// configuration successfully fetched
c.Header().Set(headers.CacheControl, cacheControl)
etag := fmt.Sprintf("\"%s\"", upstreamEtag)
c.Header().Set(headers.Etag, etag)
if etag == c.Request.Header.Get(headers.IfNoneMatch) {
sendResp(nil, http.StatusNotModified, cacheControl)
c.Result.SetDefault(request.IDResponseValidNotModified)
} else {
sendResp(cfg, http.StatusOK, cacheControl)
c.Result.SetWithBody(request.IDResponseValidOK, cfg)
}
c.Write()
}
}

func validateKbClient(client kibana.Client) (bool, string, string) {
func validateKbClient(c *request.Context, client kibana.Client, withAuth bool) bool {
if client == nil {
return false, errMsgKibanaDisabled, errMsgKibanaDisabled
c.Result.Set(request.IDResponseErrorsServiceUnavailable,
http.StatusServiceUnavailable,
msgKibanaDisabled,
msgKibanaDisabled,
errMsgKibanaDisabled)
return false
}
if !client.Connected() {
return false, errMsgNoKibanaConnection, errMsgNoKibanaConnection
c.Result.Set(request.IDResponseErrorsServiceUnavailable,
http.StatusServiceUnavailable,
msgNoKibanaConnection,
msgNoKibanaConnection,
errMsgNoKibanaConnection)
return false
}
if supported, _ := client.SupportsVersion(minKibanaVersion); !supported {
version, _ := client.GetVersion()
return false, errMsgKibanaVersionNotCompatible, fmt.Sprintf("min required Kibana version %+v, "+
"configured Kibana version %+v", minKibanaVersion, version)

errMsg := fmt.Sprintf("%s: min version %+v, configured version %+v",
msgKibanaVersionNotCompatible, minKibanaVersion, version.String())
body := authErrMsg(errMsg, msgKibanaVersionNotCompatible, withAuth)
c.Result.Set(request.IDResponseErrorsServiceUnavailable,
http.StatusServiceUnavailable,
msgKibanaVersionNotCompatible,
body,
errors.New(errMsg))
return false
}
return true, "", ""
return true
}

// Returns (zero, error) if request body can't be unmarshalled or service.name is missing
// Returns (zero, zero) if request method is not GET or POST
func buildQuery(r *http.Request) (query agentcfg.Query, err error) {
switch r.Method {
case http.MethodPost:
Expand All @@ -140,7 +162,7 @@ func buildQuery(r *http.Request) (query agentcfg.Query, err error) {
params.Get(agentcfg.ServiceEnv),
)
default:
err = errors.Errorf("%s: %s", errMsgMethodUnsupported, r.Method)
err = errors.Errorf("%s: %s", msgMethodUnsupported, r.Method)
}

if err == nil && query.Service.Name == "" {
Expand All @@ -149,40 +171,55 @@ func buildQuery(r *http.Request) (query agentcfg.Query, err error) {
return
}

func wrap(c *request.Context) func(interface{}, int, string) {
return func(body interface{}, code int, cacheControl string) {
c.Header().Set(headers.CacheControl, cacheControl)
if body == nil {
c.WriteHeader(code)
return
}
c.Send(body, code)
}
}
func extractInternalError(c *request.Context, err error, withAuth bool) {
msg := err.Error()
var body interface{}
var keyword string
switch {
case strings.Contains(msg, agentcfg.ErrMsgSendToKibanaFailed):
body = authErrMsg(msg, agentcfg.ErrMsgSendToKibanaFailed, withAuth)
keyword = agentcfg.ErrMsgSendToKibanaFailed

func wrapErr(c *request.Context, token string) func(int, string, string) {
authErrMsg := func(errMsg, logMsg string) map[string]string {
if token == "" {
return map[string]string{"error": errMsg}
}
return map[string]string{"error": logMsg}
case strings.Contains(msg, agentcfg.ErrMsgMultipleChoices):
body = authErrMsg(msg, agentcfg.ErrMsgMultipleChoices, withAuth)
keyword = agentcfg.ErrMsgMultipleChoices

case strings.Contains(msg, agentcfg.ErrMsgReadKibanaResponse):
body = authErrMsg(msg, agentcfg.ErrMsgReadKibanaResponse, withAuth)
keyword = agentcfg.ErrMsgReadKibanaResponse

default:
body = authErrMsg(msg, msgServiceUnavailable, withAuth)
keyword = msgServiceUnavailable
}

return func(status int, errMsg, logMsg string) {
c.Header().Set(headers.CacheControl, errCacheControl)
body := authErrMsg(errMsg, logMsg)
c.SendError(body, logMsg, status)
c.Result.Set(request.IDResponseErrorsServiceUnavailable,
http.StatusServiceUnavailable,
keyword,
body,
err)
}

func extractQueryError(c *request.Context, err error, withAuth bool) {
msg := err.Error()
if strings.Contains(msg, msgMethodUnsupported) {
c.Result.Set(request.IDResponseErrorsMethodNotAllowed,
http.StatusMethodNotAllowed,
msgMethodUnsupported,
authErrMsg(msg, msgMethodUnsupported, withAuth),
err)
return
}
c.Result.Set(request.IDResponseErrorsInvalidQuery,
http.StatusBadRequest,
msgInvalidQuery,
authErrMsg(msg, msgInvalidQuery, withAuth),
err)
}

func internalErrMsg(msg string) string {
switch {
case strings.Contains(msg, agentcfg.ErrMsgSendToKibanaFailed):
return agentcfg.ErrMsgSendToKibanaFailed
case strings.Contains(msg, agentcfg.ErrMsgMultipleChoices):
return agentcfg.ErrMsgMultipleChoices
case strings.Contains(msg, agentcfg.ErrMsgReadKibanaResponse):
return agentcfg.ErrMsgReadKibanaResponse
func authErrMsg(fullMsg, shortMsg string, withAuth bool) string {
if withAuth {
return fullMsg
}
return errMsgServiceUnavailable
return shortMsg
}
113 changes: 113 additions & 0 deletions beater/agent_config_handler_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package beater

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/elastic/beats/libbeat/common"

"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"

"github.com/elastic/apm-server/beater/beatertest"
"github.com/elastic/apm-server/beater/headers"
"github.com/elastic/apm-server/beater/request"
"github.com/elastic/apm-server/tests"
)

func TestAgentConfigHandler_RequireAuthorizationMiddleware(t *testing.T) {
t.Run("Unauthorized", func(t *testing.T) {
cfg := cfgEnabledACM()
cfg.SecretToken = "1234"
rec := requestToACMHandler(t, cfg)

assert.Equal(t, http.StatusUnauthorized, rec.Code)
tests.AssertApproveResult(t, acmApprovalPath(t.Name()), rec.Body.Bytes())
})

t.Run("Authorized", func(t *testing.T) {
cfg := cfgEnabledACM()
cfg.SecretToken = "1234"
h, err := agentConfigHandler(cfg, beatertest.NilReporter)
require.NoError(t, err)
c, rec := beatertest.DefaultContextWithResponseRecorder()
c.Request.Header.Set(headers.Authorization, "Bearer 1234")
h(c)

assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
tests.AssertApproveResult(t, acmApprovalPath(t.Name()), rec.Body.Bytes())
})
}

func TestAgentConfigHandler_KillSwitchMiddleware(t *testing.T) {
t.Run("Off", func(t *testing.T) {
rec := requestToACMHandler(t, DefaultConfig(beatertest.MockBeatVersion()))

assert.Equal(t, http.StatusForbidden, rec.Code)
tests.AssertApproveResult(t, acmApprovalPath(t.Name()), rec.Body.Bytes())

})

t.Run("On", func(t *testing.T) {
rec := requestToACMHandler(t, cfgEnabledACM())

assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
tests.AssertApproveResult(t, acmApprovalPath(t.Name()), rec.Body.Bytes())
})
}

func TestAgentConfigHandler_PanicMiddleware(t *testing.T) {
h, err := agentConfigHandler(DefaultConfig(beatertest.MockBeatVersion()), beatertest.NilReporter)
require.NoError(t, err)
rec := &beatertest.WriterPanicOnce{}
c := &request.Context{}
c.Reset(rec, httptest.NewRequest(http.MethodGet, "/", nil))
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.StatusCode)
tests.AssertApproveResult(t, acmApprovalPath(t.Name()), rec.Body.Bytes())
}

func TestAgentConfigHandler_MonitoringMiddleware(t *testing.T) {
h, err := agentConfigHandler(DefaultConfig(beatertest.MockBeatVersion()), beatertest.NilReporter)
require.NoError(t, err)
c, _ := beatertest.DefaultContextWithResponseRecorder()

expected := map[request.ResultID]int{request.IDRequestCount: 1}
equal, result := beatertest.CompareMonitoringInt(h, c, expected, serverMetrics, ACMResultIDToMonitoringInt)
assert.True(t, equal, result)
}

func requestToACMHandler(t *testing.T, cfg *Config) *httptest.ResponseRecorder {
h, err := agentConfigHandler(cfg, beatertest.NilReporter)
require.NoError(t, err)
c, rec := beatertest.DefaultContextWithResponseRecorder()
h(c)
return rec
}

func cfgEnabledACM() *Config {
cfg := DefaultConfig(beatertest.MockBeatVersion())
cfg.Kibana = common.MustNewConfigFrom(map[string]interface{}{"enabled": "true"})
return cfg
}

func acmApprovalPath(f string) string { return "test_integration/acm/" + f }
Loading

0 comments on commit e0136d5

Please sign in to comment.