Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More descriptive errors with specific HTTP return codes #510

Merged
merged 2 commits into from
Aug 11, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ func respondError(w http.ResponseWriter, status int, err error) {
status = http.StatusServiceUnavailable
}

// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(logical.HTTPCodedError); ok {
status = t.Code()
}

w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)

Expand Down
34 changes: 34 additions & 0 deletions http/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package http

import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)

Expand Down Expand Up @@ -57,3 +60,34 @@ func TestHandler_sealed(t *testing.T) {
}
testResponseStatus(t, resp, 503)
}

func TestHandler_error(t *testing.T) {
w := httptest.NewRecorder()

respondError(w, 500, errors.New("Test Error"))

if w.Code != 500 {
t.Fatalf("expected 500, got %d", w.Code)
}

// The code inside of the error should override
// the argument to respondError
w2 := httptest.NewRecorder()
e := logical.CodedError(403, "error text")

respondError(w2, 500, e)

if w2.Code != 403 {
t.Fatalf("expected 403, got %d", w2.Code)
}

// vault.ErrSealed is a special case
w3 := httptest.NewRecorder()

respondError(w3, 400, vault.ErrSealed)

if w3.Code != 503 {
t.Fatalf("expected 503, got %d", w3.Code)
}

}
2 changes: 2 additions & 0 deletions http/sys_mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ func handleSysMount(
"description": req.Description,
},
}))

if err != nil {
respondError(w, http.StatusInternalServerError, err)
return
Expand All @@ -149,6 +150,7 @@ func handleSysUnmount(
Path: "sys/mounts/" + path,
Connection: getConnection(r),
}))

if err != nil {
respondError(w, http.StatusInternalServerError, err)
return
Expand Down
24 changes: 24 additions & 0 deletions logical/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package logical

type HTTPCodedError interface {
Error() string
Code() int
}

func CodedError(c int, s string) HTTPCodedError {
return &codedError{s,c}
}

type codedError struct {
s string
code int
}

func (e *codedError) Error() string {
return e.s
}

func (e *codedError) Code() int {
return e.code
}

2 changes: 1 addition & 1 deletion logical/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,6 @@ var (
// ErrInvalidRequest is returned if the request is invalid
ErrInvalidRequest = errors.New("invalid request")

// ErrPermissionDeneid is returned if the client is not authorized
// ErrPermissionDenied is returned if the client is not authorized
ErrPermissionDenied = errors.New("permission denied")
)
2 changes: 1 addition & 1 deletion logical/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const (
// avoided like the HTTPContentType. The value must be a byte slice.
HTTPRawBody = "http_raw_body"

// HTTPStatusCode is the response code the HTTP body that goes with the HTTPContentType.
// HTTPStatusCode is the response code of the HTTP body that goes with the HTTPContentType.
// This can only be specified for non-secrets, and should should be similarly
// avoided like the HTTPContentType. The value must be an integer.
HTTPStatusCode = "http_status_code"
Expand Down
46 changes: 29 additions & 17 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,23 @@ func (b *SystemBackend) handleMount(
// Attempt mount
if err := b.Core.mount(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

return nil, nil
}

// used to intercept an HTTPCodedError so it goes back to callee
func handleError(
err error) (*logical.Response, error) {
switch err.(type) {
case logical.HTTPCodedError:
return logical.ErrorResponse(err.Error()), err
default:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
}

// handleUnmount is used to unmount a path
func (b *SystemBackend) handleUnmount(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
Expand All @@ -387,7 +399,7 @@ func (b *SystemBackend) handleUnmount(
// Attempt unmount
if err := b.Core.unmount(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

return nil, nil
Expand All @@ -408,7 +420,7 @@ func (b *SystemBackend) handleRemount(
// Attempt remount
if err := b.Core.remount(fromPath, toPath); err != nil {
b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

return nil, nil
Expand All @@ -428,7 +440,7 @@ func (b *SystemBackend) handleRenew(
resp, err := b.Core.expiration.Renew(leaseID, increment)
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return resp, err
}
Expand All @@ -442,7 +454,7 @@ func (b *SystemBackend) handleRevoke(
// Invoke the expiration manager directly
if err := b.Core.expiration.Revoke(leaseID); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -456,7 +468,7 @@ func (b *SystemBackend) handleRevokePrefix(
// Invoke the expiration manager directly
if err := b.Core.expiration.RevokePrefix(prefix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand Down Expand Up @@ -504,7 +516,7 @@ func (b *SystemBackend) handleEnableAuth(
// Attempt enabling
if err := b.Core.enableCredential(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -520,7 +532,7 @@ func (b *SystemBackend) handleDisableAuth(
// Attempt disable
if err := b.Core.disableCredential(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -543,7 +555,7 @@ func (b *SystemBackend) handlePolicyRead(

policy, err := b.Core.policy.GetPolicy(name)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

if policy == nil {
Expand All @@ -567,15 +579,15 @@ func (b *SystemBackend) handlePolicySet(
// Validate the rules parse
parse, err := Parse(rules)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}

// Override the name
parse.Name = name

// Update the policy
if err := b.Core.policy.SetPolicy(parse); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -585,7 +597,7 @@ func (b *SystemBackend) handlePolicyDelete(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if err := b.Core.policy.DeletePolicy(name); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand Down Expand Up @@ -640,7 +652,7 @@ func (b *SystemBackend) handleEnableAudit(
// Attempt enabling
if err := b.Core.enableAudit(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -653,7 +665,7 @@ func (b *SystemBackend) handleDisableAudit(
// Attempt disable
if err := b.Core.disableAudit(path); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand All @@ -673,7 +685,7 @@ func (b *SystemBackend) handleRawRead(

entry, err := b.Core.barrier.Get(path)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
if entry == nil {
return nil, nil
Expand Down Expand Up @@ -724,7 +736,7 @@ func (b *SystemBackend) handleRawDelete(
}

if err := b.Core.barrier.Delete(path); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
Expand Down Expand Up @@ -754,7 +766,7 @@ func (b *SystemBackend) handleRotate(
newTerm, err := b.Core.barrier.Rotate()
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
b.Backend.Logger().Printf("[INFO] sys: installed new encryption key")

Expand Down
8 changes: 4 additions & 4 deletions vault/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
// barrier view for the backends.
backendBarrierPrefix = "logical/"

// systemBarrierPrefix is sthe prefix used for the
// systemBarrierPrefix is the prefix used for the
// system logical backend.
systemBarrierPrefix = "sys/"
)
Expand Down Expand Up @@ -139,16 +139,16 @@ func (c *Core) mount(me *MountEntry) error {
me.Path += "/"
}

// Prevent protected paths from being unmounted
// Prevent protected paths from being mounted
for _, p := range protectedMounts {
if strings.HasPrefix(me.Path, p) {
return fmt.Errorf("cannot mount '%s'", me.Path)
return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
}
}

// Verify there is no conflicting mount
if match := c.router.MatchingMount(me.Path); match != "" {
return fmt.Errorf("existing mount at '%s'", match)
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
}

// Generate a new UUID and view
Expand Down