Skip to content

Commit

Permalink
Refactor response handling
Browse files Browse the repository at this point in the history
Related to elastic#2489
  • Loading branch information
simitt committed Jul 24, 2019
1 parent 4c24607 commit 7c9eaa5
Show file tree
Hide file tree
Showing 15 changed files with 215 additions and 240 deletions.
95 changes: 43 additions & 52 deletions beater/agent_config_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,56 +55,54 @@ func agentConfigHandler(kbClient kibana.Client, config *agentConfig, secretToken
fetcher := agentcfg.NewFetcher(kbClient, config.Cache.Expiration)

var handler = func(c *request.Context) {
sendResp := wrap(c)
sendErr := wrapErr(c, secretToken)

if valid, shortMsg, detailMsg := validateKbClient(kbClient); !valid {
sendErr(http.StatusServiceUnavailable, shortMsg, detailMsg)
// error handling
c.Header().Set(headers.CacheControl, errCacheControl)
if valid, fullMsg := validateKbClient(kbClient); !valid {
c.WriteWithError(extractInternalError(fullMsg, secretToken))
return
}

query, requestErr := buildQuery(c.Req)
if requestErr != nil {
if strings.Contains(requestErr.Error(), errMsgMethodUnsupported) {
sendErr(http.StatusMethodNotAllowed, errMsgMethodUnsupported, requestErr.Error())
return
}
sendErr(http.StatusBadRequest, errMsgInvalidQuery, requestErr.Error())
query, queryErr := buildQuery(c.Req)
if queryErr != nil {
c.WriteWithError(extractQueryError(queryErr.Error(), secretToken))
return
}

cfg, upstreamEtag, internalErr := fetcher.Fetch(query, nil)
if internalErr != nil {
sendErr(http.StatusServiceUnavailable, internalErrMsg(internalErr.Error()), internalErr.Error())
cfg, upstreamEtag, err := fetcher.Fetch(query, nil)
if err != nil {
c.WriteWithError(extractInternalError(err.Error(), secretToken))
return
}

// configuration successfully fetched
c.Header().Set(headers.CacheControl, cacheControl)
etag := fmt.Sprintf("\"%s\"", upstreamEtag)
c.Header().Set(headers.Etag, etag)
if etag == c.Req.Header.Get(headers.IfNoneMatch) {
sendResp(nil, http.StatusNotModified, cacheControl)
c.Write(nil, http.StatusNotModified)
} else {
sendResp(cfg, http.StatusOK, cacheControl)
c.Write(cfg, http.StatusOK)
}
}

return killSwitchHandler(kbClient != nil,
authHandler(secretToken, handler))
}

func validateKbClient(client kibana.Client) (bool, string, string) {
func validateKbClient(client kibana.Client) (bool, string) {
if client == nil {
return false, errMsgKibanaDisabled, errMsgKibanaDisabled
return false, errMsgKibanaDisabled
}
if !client.Connected() {
return false, errMsgNoKibanaConnection, errMsgNoKibanaConnection
return false, errMsgNoKibanaConnection
}
if supported, _ := client.SupportsVersion(minKibanaVersion); !supported {
version, _ := client.GetVersion()
return false, errMsgKibanaVersionNotCompatible, fmt.Sprintf("min required Kibana version %+v, "+
"configured Kibana version %+v", minKibanaVersion, version)

return false, fmt.Sprintf("%s: min version %+v, configured version %+v",
errMsgKibanaVersionNotCompatible, minKibanaVersion, version.String())
}
return true, "", ""
return true, ""
}

// Returns (zero, error) if request body can't be unmarshalled or service.name is missing
Expand All @@ -129,40 +127,33 @@ func buildQuery(r *http.Request) (query agentcfg.Query, err error) {
return
}

func wrap(c *request.Context) func(interface{}, int, string) {
return func(body interface{}, code int, cacheControl string) {
c.Header().Set(headers.CacheControl, cacheControl)
if body == nil {
c.WriteHeader(code)
return
}
c.Send(body, code)
func extractInternalError(msg string, token string) (string, string, int) {
var shortMsg = errMsgServiceUnavailable
switch {
case msg == errMsgKibanaDisabled || msg == errMsgNoKibanaConnection:
shortMsg = msg
case strings.Contains(msg, errMsgKibanaVersionNotCompatible):
shortMsg = errMsgKibanaVersionNotCompatible
case strings.Contains(msg, agentcfg.ErrMsgSendToKibanaFailed):
shortMsg = agentcfg.ErrMsgSendToKibanaFailed
case strings.Contains(msg, agentcfg.ErrMsgMultipleChoices):
shortMsg = agentcfg.ErrMsgMultipleChoices
case strings.Contains(msg, agentcfg.ErrMsgReadKibanaResponse):
shortMsg = agentcfg.ErrMsgReadKibanaResponse
}
return authErrMsg(msg, shortMsg, token), msg, http.StatusServiceUnavailable
}

func wrapErr(c *request.Context, token string) func(int, string, string) {
authErrMsg := func(errMsg, logMsg string) map[string]string {
if token == "" {
return map[string]string{"error": errMsg}
}
return map[string]string{"error": logMsg}
}

return func(status int, errMsg, logMsg string) {
c.Header().Set(headers.CacheControl, errCacheControl)
body := authErrMsg(errMsg, logMsg)
c.SendError(body, logMsg, status)
func extractQueryError(msg string, token string) (string, string, int) {
if strings.Contains(msg, errMsgMethodUnsupported) {
return authErrMsg(msg, errMsgMethodUnsupported, token), msg, http.StatusMethodNotAllowed
}
return authErrMsg(msg, errMsgInvalidQuery, token), msg, http.StatusBadRequest
}

func internalErrMsg(msg string) string {
switch {
case strings.Contains(msg, agentcfg.ErrMsgSendToKibanaFailed):
return agentcfg.ErrMsgSendToKibanaFailed
case strings.Contains(msg, agentcfg.ErrMsgMultipleChoices):
return agentcfg.ErrMsgMultipleChoices
case strings.Contains(msg, agentcfg.ErrMsgReadKibanaResponse):
return agentcfg.ErrMsgReadKibanaResponse
func authErrMsg(fullMsg, shortMsg string, token string) string {
if token == "" {
return shortMsg
}
return errMsgServiceUnavailable
return fullMsg
}
7 changes: 4 additions & 3 deletions beater/agent_config_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,14 @@ var (
},

"InvalidVersion": {
kbClient: tests.MockKibana(http.StatusServiceUnavailable, m{}, *common.MustNewVersion("7.2.0"), true),
kbClient: tests.MockKibana(http.StatusServiceUnavailable, m{},
*common.MustNewVersion("7.2.0"), true),
method: http.MethodGet,
respStatus: http.StatusServiceUnavailable,
respCacheControlHeader: "max-age=300, must-revalidate",
respBody: errWrap(errMsgKibanaVersionNotCompatible),
respBodyToken: errWrap("min required Kibana version 7.3.0," +
" configured Kibana version {version:7.2.0 Major:7 Minor:2 Bugfix:0 Meta:}"),
respBodyToken: errWrap(fmt.Sprintf("%s: min version 7.3.0, configured version 7.2.0",
errMsgKibanaVersionNotCompatible)),
},

"NoService": {
Expand Down
3 changes: 1 addition & 2 deletions beater/asset_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ type assetHandler struct {

func (h *assetHandler) Handle(beaterConfig *Config, report publish.Reporter) Handler {
return func(c *request.Context) {
res := h.processRequest(c.Req, report)
sendStatus(c, res)
h.processRequest(c.Req, report).writeTo(c)
}
}

Expand Down
49 changes: 21 additions & 28 deletions beater/common_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ var (
supportedMethods = fmt.Sprintf("%s, %s", http.MethodPost, http.MethodOptions)
)

type serverResponse struct {
err error
code int
counter *monitoring.Int
body map[string]interface{}
}

var (
serverMetrics = monitoring.Default.NewRegistry("apm-server.server", monitoring.PublishExpvar)
counter = func(s string) *monitoring.Int {
Expand Down Expand Up @@ -140,6 +133,23 @@ var (
}
)

type serverResponse struct {
code int
counter *monitoring.Int
err error
body interface{}
}

func (r serverResponse) writeTo(c *request.Context) {
if r.code >= http.StatusBadRequest || r.err != nil {
//TODO: remove extra handling when changing logs
err := map[string]string{"error": r.err.Error()}
c.WriteWithError(r.err.Error(), err, r.code)
return
}
c.Write(r.body, r.code)
}

func requestTimeHandler(h Handler) Handler {
return func(c *request.Context) {
c.Req = c.Req.WithContext(utility.ContextWithRequestTime(c.Req.Context(), time.Now()))
Expand All @@ -152,15 +162,15 @@ func killSwitchHandler(killSwitch bool, h Handler) Handler {
if killSwitch {
h(c)
} else {
sendStatus(c, forbiddenResponse(errors.New("endpoint is disabled")))
forbiddenResponse(errors.New("endpoint is disabled")).writeTo(c)
}
}
}

func authHandler(secretToken string, h Handler) Handler {
return func(c *request.Context) {
if !isAuthorized(c.Req, secretToken) {
sendStatus(c, unauthorizedResponse)
unauthorizedResponse.writeTo(c)
return
}
h(c)
Expand Down Expand Up @@ -219,32 +229,15 @@ func corsHandler(allowedOrigins []string, h Handler) Handler {

c.Header().Set(headers.ContentLength, "0")

sendStatus(c, okResponse)
okResponse.writeTo(c)

} else if validOrigin {
// we need to check the origin and set the ACAO header in both the OPTIONS preflight and the actual request
c.Header().Set(headers.AccessControlAllowOrigin, origin)
h(c)

} else {
sendStatus(c, forbiddenResponse(errors.New("origin: '"+origin+"' is not allowed")))
forbiddenResponse(errors.New("origin: '" + origin + "' is not allowed")).writeTo(c)
}
}
}

//TODO: move to Context when reworking response handling.
func sendStatus(c *request.Context, res serverResponse) {
if res.err != nil {
body := map[string]interface{}{"error": res.err.Error()}
//TODO: refactor response handling: get rid of additional `error` and just pass in error
c.SendError(body, body, res.code)
return
}

if res.body == nil {
c.WriteHeader(res.code)
return
}

c.Send(res.body, res.code)
}
14 changes: 6 additions & 8 deletions beater/common_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ func TestIncCounter(t *testing.T) {
c.Reset(httptest.NewRecorder(), req)
monitoringHandler(func(c *request.Context) {
c.AddMonitoringCt(res.counter)
if res.err != nil {
c.SendError(nil, res.err, res.code)
}
res.writeTo(c)
})(c)
assert.Equal(t, int64(i), res.counter.Get(), string(res.code))
}
Expand Down Expand Up @@ -110,11 +108,11 @@ func TestOkBody(t *testing.T) {
w := httptest.NewRecorder()
c := &request.Context{}
c.Reset(w, req)
sendStatus(c, serverResponse{
serverResponse{
code: http.StatusNonAuthoritativeInfo,
counter: requestCounter,
body: map[string]interface{}{"some": "body"},
})
}.writeTo(c)
rsp := w.Result()
got := body(t, rsp)
assert.Equal(t, "{\"some\":\"body\"}\n", string(got))
Expand All @@ -128,11 +126,11 @@ func TestOkBodyJson(t *testing.T) {
w := httptest.NewRecorder()
c := &request.Context{}
c.Reset(w, req)
sendStatus(c, serverResponse{
serverResponse{
code: http.StatusNonAuthoritativeInfo,
counter: requestCounter,
body: map[string]interface{}{"version": "1.0"},
})
}.writeTo(c)
rsp := w.Result()
got := body(t, rsp)
assert.Equal(t,
Expand Down Expand Up @@ -167,7 +165,7 @@ func TestAccept(t *testing.T) {
w := httptest.NewRecorder()
c := &request.Context{}
c.Reset(w, req)
sendStatus(c, cannotValidateResponse(errors.New("error message")))
cannotValidateResponse(errors.New("error message")).writeTo(c)
rsp := w.Result()
got := body(t, rsp)
assert.Equal(t, 400, w.Code)
Expand Down
Loading

0 comments on commit 7c9eaa5

Please sign in to comment.