diff --git a/p2p/net/swarm/dial_error.go b/p2p/net/swarm/dial_error.go index 54260e6398..46e991a0df 100644 --- a/p2p/net/swarm/dial_error.go +++ b/p2p/net/swarm/dial_error.go @@ -49,15 +49,11 @@ func (e *DialError) Error() string { return builder.String() } -func (e *DialError) Unwrap() []error { +func (e *DialError) Unwrap() error { if e == nil || len(e.DialErrors) == 0 { return nil } - errs := make([]error, len(e.DialErrors)) - for i := 0; i < len(e.DialErrors); i++ { - errs[i] = &e.DialErrors[i] - } - return errs + return chainError(e.DialErrors) } func (e *DialError) Is(target error) bool { @@ -84,3 +80,32 @@ func (e *TransportError) Unwrap() error { } var _ error = (*TransportError)(nil) + +// chainError is used to implement Unwrap for DialError +// errors.Is and errors.As only support `interface { Unwrap() []error }` from go1.20 +// +// The implementation is a modified version of multierror.chain from: +// https://github.com/hashicorp/go-multierror/blob/main/multierror.go#L96 +type chainError []TransportError + +func (c chainError) Error() string { + // The actual value is not important. Only want to implement error. + if len(c) == 0 { + return "chainError: []" + } + return c[0].Error() +} + +func (c chainError) Unwrap() error { + if len(c) == 1 { + return nil + } + return c[1:] +} + +func (c chainError) Is(target error) bool { + if len(c) == 0 { + return false + } + return errors.Is(&c[0], target) +}