From 4a951fdeacf0d87b1b2d94a591ed5ab550031111 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Mon, 19 Jun 2017 20:20:44 -0400 Subject: [PATCH 1/3] Recover during a request forward. gRPC doesn't have a handler for recovering from a panic like a normal HTTP request so a panic will actually kill Vault's listener. This basically copies the net/http logic for managing this. The SSH-specific logic is removed here as the underlying issue is caused by the request forwarding mechanism. --- builtin/logical/ssh/path_sign.go | 10 ---------- vault/request_forwarding.go | 27 ++++++++++++++++++++++----- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/builtin/logical/ssh/path_sign.go b/builtin/logical/ssh/path_sign.go index d13451cd22e0..c1f133948053 100644 --- a/builtin/logical/ssh/path_sign.go +++ b/builtin/logical/ssh/path_sign.go @@ -389,16 +389,6 @@ func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.D } func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) { - defer func() { - if r := recover(); r != nil { - err, ok := r.(error) - if ok { - retCert = nil - retErr = err - } - } - }() - serialNumber, err := certutil.GenerateSerialNumber() if err != nil { return nil, err diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index 84b89afe5b80..f585cd209cbe 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/url" + "runtime" "sync" "sync/atomic" "time" @@ -352,11 +353,27 @@ func (s *forwardedRequestRPCServer) ForwardRequest(ctx context.Context, freq *fo // meets the interface requirements. w := forwarding.NewRPCResponseWriter() - s.handler.ServeHTTP(w, req) - - resp := &forwarding.Response{ - StatusCode: uint32(w.StatusCode()), - Body: w.Body().Bytes(), + resp := &forwarding.Response{} + var respSet bool + + runRequest := func() { + defer func() { + // Logic here comes mostly from the Go source code + if err := recover(); err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + resp.StatusCode = 500 + s.core.logger.Error("forwarding: panic serving request for %v: %v\n%s", req.URL.Path, err, buf) + respSet = true + } + }() + s.handler.ServeHTTP(w, req) + } + runRequest() + if !respSet { + resp.StatusCode = uint32(w.StatusCode()) + resp.Body = w.Body().Bytes() } header := w.Header() From cd9d21fa8932debc7b1bbe336981d15a413510a5 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 20 Jun 2017 19:54:10 -0400 Subject: [PATCH 2/3] Fix error message formatting and response body --- vault/request_forwarding.go | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index f585cd209cbe..7d764b7a72da 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -354,7 +354,6 @@ func (s *forwardedRequestRPCServer) ForwardRequest(ctx context.Context, freq *fo w := forwarding.NewRPCResponseWriter() resp := &forwarding.Response{} - var respSet bool runRequest := func() { defer func() { @@ -363,18 +362,14 @@ func (s *forwardedRequestRPCServer) ForwardRequest(ctx context.Context, freq *fo const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - resp.StatusCode = 500 - s.core.logger.Error("forwarding: panic serving request for %v: %v\n%s", req.URL.Path, err, buf) - respSet = true + s.core.logger.Error("forwarding: panic serving request", "path", req.URL.Path, "error", err, "stacktrace", buf) } }() s.handler.ServeHTTP(w, req) } runRequest() - if !respSet { - resp.StatusCode = uint32(w.StatusCode()) - resp.Body = w.Body().Bytes() - } + resp.StatusCode = uint32(w.StatusCode()) + resp.Body = w.Body().Bytes() header := w.Header() if header != nil { From 0ac923d38b20898f6479461c1098fe736fd981c2 Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Thu, 27 Jul 2017 21:00:31 -0400 Subject: [PATCH 3/3] fixing recovery from x/golang/crypto panics --- builtin/logical/ssh/path_sign.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/builtin/logical/ssh/path_sign.go b/builtin/logical/ssh/path_sign.go index c1f133948053..4d62f4a37539 100644 --- a/builtin/logical/ssh/path_sign.go +++ b/builtin/logical/ssh/path_sign.go @@ -389,6 +389,16 @@ func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.D } func (b *creationBundle) sign() (retCert *ssh.Certificate, retErr error) { + defer func() { + if r := recover(); r != nil { + errMsg, ok := r.(string) + if ok { + retCert = nil + retErr = errors.New(errMsg) + } + } + }() + serialNumber, err := certutil.GenerateSerialNumber() if err != nil { return nil, err