Skip to content

Commit

Permalink
Merge pull request #510 from ctennis/more_descriptive_errors
Browse files Browse the repository at this point in the history
More descriptive errors with specific HTTP return codes
  • Loading branch information
armon committed Aug 11, 2015
2 parents f4c7846 + 1621f5e commit be88630
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 23 deletions.
5 changes: 5 additions & 0 deletions http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ func respondError(w http.ResponseWriter, status int, err error) {
status = http.StatusServiceUnavailable
}

// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(logical.HTTPCodedError); ok {
status = t.Code()
}

w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)

Expand Down
34 changes: 34 additions & 0 deletions http/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package http

import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)

Expand Down Expand Up @@ -57,3 +60,34 @@ func TestHandler_sealed(t *testing.T) {
}
testResponseStatus(t, resp, 503)
}

func TestHandler_error(t *testing.T) {
w := httptest.NewRecorder()

respondError(w, 500, errors.New("Test Error"))

if w.Code != 500 {
t.Fatalf("expected 500, got %d", w.Code)
}

// The code inside of the error should override
// the argument to respondError
w2 := httptest.NewRecorder()
e := logical.CodedError(403, "error text")

respondError(w2, 500, e)

if w2.Code != 403 {
t.Fatalf("expected 403, got %d", w2.Code)
}

// vault.ErrSealed is a special case
w3 := httptest.NewRecorder()

respondError(w3, 400, vault.ErrSealed)

if w3.Code != 503 {
t.Fatalf("expected 503, got %d", w3.Code)
}

}
2 changes: 2 additions & 0 deletions http/sys_mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ func handleSysMount(
"description": req.Description,
},
}))

if err != nil {
respondError(w, http.StatusInternalServerError, err)
return
Expand All @@ -149,6 +150,7 @@ func handleSysUnmount(
Path: "sys/mounts/" + path,
Connection: getConnection(r),
}))

if err != nil {
respondError(w, http.StatusInternalServerError, err)
return
Expand Down
24 changes: 24 additions & 0 deletions logical/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package logical

type HTTPCodedError interface {
Error() string
Code() int
}

func CodedError(c int, s string) HTTPCodedError {
return &codedError{s,c}
}

type codedError struct {
s string
code int
}

func (e *codedError) Error() string {
return e.s
}

func (e *codedError) Code() int {
return e.code
}

2 changes: 1 addition & 1 deletion logical/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,6 @@ var (
// ErrInvalidRequest is returned if the request is invalid
ErrInvalidRequest = errors.New("invalid request")

// ErrPermissionDeneid is returned if the client is not authorized
// ErrPermissionDenied is returned if the client is not authorized
ErrPermissionDenied = errors.New("permission denied")
)
2 changes: 1 addition & 1 deletion logical/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const (
// avoided like the HTTPContentType. The value must be a byte slice.
HTTPRawBody = "http_raw_body"

// HTTPStatusCode is the response code the HTTP body that goes with the HTTPContentType.
// HTTPStatusCode is the response code of the HTTP body that goes with the HTTPContentType.
// This can only be specified for non-secrets, and should should be similarly
// avoided like the HTTPContentType. The value must be an integer.
HTTPStatusCode = "http_status_code"
Expand Down
46 changes: 29 additions & 17 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,23 @@ func (b *SystemBackend) handleMount(
// Attempt mount
if err := b.Core.mount(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

return nil, nil
}

// used to intercept an HTTPCodedError so it goes back to callee
func handleError(
err error) (*logical.Response, error) {
switch err.(type) {
case logical.HTTPCodedError:
return logical.ErrorResponse(err.Error()), err
default:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
}

// handleUnmount is used to unmount a path
func (b *SystemBackend) handleUnmount(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
Expand All @@ -387,7 +399,7 @@ func (b *SystemBackend) handleUnmount(
// Attempt unmount
if err := b.Core.unmount(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

return nil, nil
Expand All @@ -408,7 +420,7 @@ func (b *SystemBackend) handleRemount(
// Attempt remount
if err := b.Core.remount(fromPath, toPath); err != nil {
b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

return nil, nil
Expand All @@ -428,7 +440,7 @@ func (b *SystemBackend) handleRenew(
resp, err := b.Core.expiration.Renew(leaseID, increment)
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return resp, err
}
Expand All @@ -442,7 +454,7 @@ func (b *SystemBackend) handleRevoke(
// Invoke the expiration manager directly
if err := b.Core.expiration.Revoke(leaseID); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -456,7 +468,7 @@ func (b *SystemBackend) handleRevokePrefix(
// Invoke the expiration manager directly
if err := b.Core.expiration.RevokePrefix(prefix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand Down Expand Up @@ -504,7 +516,7 @@ func (b *SystemBackend) handleEnableAuth(
// Attempt enabling
if err := b.Core.enableCredential(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -520,7 +532,7 @@ func (b *SystemBackend) handleDisableAuth(
// Attempt disable
if err := b.Core.disableCredential(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -543,7 +555,7 @@ func (b *SystemBackend) handlePolicyRead(

policy, err := b.Core.policy.GetPolicy(name)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

if policy == nil {
Expand All @@ -567,15 +579,15 @@ func (b *SystemBackend) handlePolicySet(
// Validate the rules parse
parse, err := Parse(rules)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

// Override the name
parse.Name = name

// Update the policy
if err := b.Core.policy.SetPolicy(parse); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -585,7 +597,7 @@ func (b *SystemBackend) handlePolicyDelete(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if err := b.Core.policy.DeletePolicy(name); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand Down Expand Up @@ -640,7 +652,7 @@ func (b *SystemBackend) handleEnableAudit(
// Attempt enabling
if err := b.Core.enableAudit(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -653,7 +665,7 @@ func (b *SystemBackend) handleDisableAudit(
// Attempt disable
if err := b.Core.disableAudit(path); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -673,7 +685,7 @@ func (b *SystemBackend) handleRawRead(

entry, err := b.Core.barrier.Get(path)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
if entry == nil {
return nil, nil
Expand Down Expand Up @@ -724,7 +736,7 @@ func (b *SystemBackend) handleRawDelete(
}

if err := b.Core.barrier.Delete(path); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand Down Expand Up @@ -754,7 +766,7 @@ func (b *SystemBackend) handleRotate(
newTerm, err := b.Core.barrier.Rotate()
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
b.Backend.Logger().Printf("[INFO] sys: installed new encryption key")

Expand Down
8 changes: 4 additions & 4 deletions vault/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
// barrier view for the backends.
backendBarrierPrefix = "logical/"

// systemBarrierPrefix is sthe prefix used for the
// systemBarrierPrefix is the prefix used for the
// system logical backend.
systemBarrierPrefix = "sys/"
)
Expand Down Expand Up @@ -139,16 +139,16 @@ func (c *Core) mount(me *MountEntry) error {
me.Path += "/"
}

// Prevent protected paths from being unmounted
// Prevent protected paths from being mounted
for _, p := range protectedMounts {
if strings.HasPrefix(me.Path, p) {
return fmt.Errorf("cannot mount '%s'", me.Path)
return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
}
}

// Verify there is no conflicting mount
if match := c.router.MatchingMount(me.Path); match != "" {
return fmt.Errorf("existing mount at '%s'", match)
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
}

// Generate a new UUID and view
Expand Down

0 comments on commit be88630

Please sign in to comment.