diff --git a/errors/errors.go b/errors/errors.go index 8327ecb..c3027d0 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -20,7 +20,11 @@ package xerrors -import "bytes" +import ( + "bytes" + "errors" + "fmt" +) // FirstError will return the first non nil error func FirstError(errs ...error) error { @@ -80,6 +84,19 @@ type invalidParamsError struct { containedError } +// Wrap wraps an error with a message but preserves the type of the error. +func Wrap(err error, msg string) error { + renamed := errors.New(msg + ": " + err.Error()) + return NewRenamedError(err, renamed) +} + +// Wrapf formats according to a format specifier and uses that string to +// wrap an error while still preserving the type of the error. +func Wrapf(err error, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + return Wrap(err, msg) +} + // NewInvalidParamsError creates a new invalid params error func NewInvalidParamsError(inner error) error { return invalidParamsError{containedError{inner}} diff --git a/errors/errors_test.go b/errors/errors_test.go index 2d4e4b9..013cb5c 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -29,6 +29,48 @@ import ( "github.com/stretchr/testify/require" ) +func TestWrap(t *testing.T) { + inner := errors.New("detailed error message") + err := NewInvalidParamsError(inner) + wrappedErr := Wrap(err, "context about params error") + assert.Error(t, wrappedErr) + assert.Equal(t, "context about params error: detailed error message", wrappedErr.Error()) + assert.True(t, IsInvalidParams(wrappedErr)) + + err = NewRetryableError(inner) + wrappedErr = Wrap(err, "context about retryable error") + assert.Error(t, wrappedErr) + assert.Equal(t, "context about retryable error: detailed error message", wrappedErr.Error()) + assert.True(t, IsRetryableError(wrappedErr)) + + err = NewNonRetryableError(inner) + wrappedErr = Wrap(err, "context about nonretryable error") + assert.Error(t, wrappedErr) + assert.Equal(t, "context about nonretryable error: detailed error message", wrappedErr.Error()) + assert.True(t, IsNonRetryableError(wrappedErr)) +} + +func TestWrapf(t *testing.T) { + inner := errors.New("detailed error message") + err := NewInvalidParamsError(inner) + wrappedErr := Wrapf(err, "context about %s error", "params") + assert.Error(t, wrappedErr) + assert.Equal(t, "context about params error: detailed error message", wrappedErr.Error()) + assert.True(t, IsInvalidParams(wrappedErr)) + + err = NewRetryableError(inner) + wrappedErr = Wrapf(err, "context about %s error", "retryable") + assert.Error(t, wrappedErr) + assert.Equal(t, "context about retryable error: detailed error message", wrappedErr.Error()) + assert.True(t, IsRetryableError(wrappedErr)) + + err = NewNonRetryableError(inner) + wrappedErr = Wrapf(err, "context about %s error", "nonretryable") + assert.Error(t, wrappedErr) + assert.Equal(t, "context about nonretryable error: detailed error message", wrappedErr.Error()) + assert.True(t, IsNonRetryableError(wrappedErr)) +} + func TestMultiErrorNoError(t *testing.T) { err := NewMultiError() require.Nil(t, err.FinalError())