Skip to content

Commit

Permalink
fix: spurious cancelation of async webhooks, better tracing (#2969)
Browse files Browse the repository at this point in the history
Previously, async webhooks (response.ignore=true) would be canceled
early once the incoming Kratos request was served and its associated
context released. We now dissociate the cancellation of async hooks
from the normal request processing flow.
  • Loading branch information
alnr authored Dec 20, 2022
1 parent 3e06c99 commit 72de640
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 29 deletions.
85 changes: 57 additions & 28 deletions selfservice/hook/web_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/pkg/errors"
"github.com/tidwall/gjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
"go.opentelemetry.io/otel/trace"

"github.com/ory/kratos/ui/node"
Expand All @@ -29,7 +32,6 @@ import (
"github.com/ory/kratos/session"
"github.com/ory/kratos/text"
"github.com/ory/kratos/x"
"github.com/ory/x/otelx"
)

var (
Expand Down Expand Up @@ -253,22 +255,6 @@ func (e *WebHook) ExecuteSettingsPrePersistHook(_ http.ResponseWriter, req *http
}

func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
span := trace.SpanFromContext(ctx)
attrs := map[string]string{
"webhook.http.method": data.RequestMethod,
"webhook.http.url": data.RequestURL,
"webhook.http.headers": fmt.Sprintf("%#v", data.RequestHeaders),
}

if data.Identity != nil {
attrs["webhook.identity.id"] = data.Identity.ID.String()
} else {
attrs["webhook.identity.id"] = ""
}

span.SetAttributes(otelx.StringAttrs(attrs)...)
defer span.End()

builder, err := request.NewBuilder(e.conf, e.deps)
if err != nil {
return err
Expand All @@ -281,35 +267,78 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
return err
}

errChan := make(chan error, 1)
attrs := semconv.HTTPClientAttributesFromHTTPRequest(req.Request)
if data.Identity != nil {
attrs = append(attrs,
attribute.String("webhook.identity.id", data.Identity.ID.String()),
attribute.String("webhook.identity.nid", data.Identity.NID.String()),
)
}

var (
httpClient = e.deps.HTTPClient(ctx)
ignoreResponse = gjson.GetBytes(e.conf, "response.ignore").Bool()
canInterrupt = gjson.GetBytes(e.conf, "can_interrupt").Bool()
tracer = trace.SpanFromContext(ctx).TracerProvider().Tracer("kratos-webhooks")
spanOpts = []trace.SpanStartOption{trace.WithAttributes(attrs...)}
errChan = make(chan error, 1)
)

ctx, span := tracer.Start(ctx, "Webhook", spanOpts...)
e.deps.Logger().WithRequest(req.Request).Info("Dispatching webhook")

req = req.WithContext(ctx)
if ignoreResponse {
// This is one of the few places where spawning a context.Background() is ok. We need to do this
// because the function runs asynchronously and we don't want to cancel the request if the
// incoming request context is cancelled.
//
// The webhook will still cancel after 30 seconds as that is the configured timeout for the HTTP client.
req = req.WithContext(context.Background())
// spanOpts = append(spanOpts, trace.WithNewRoot())
}

startTime := time.Now()
go func() {
defer close(errChan)
defer span.End()

resp, err := e.deps.HTTPClient(ctx).Do(req.WithContext(ctx))
resp, err := httpClient.Do(req)
if err != nil {
span.SetStatus(codes.Error, err.Error())
errChan <- errors.WithStack(err)
return
}
defer resp.Body.Close()
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(resp.StatusCode)...)

if resp.StatusCode >= http.StatusBadRequest {
if gjson.GetBytes(e.conf, "can_interrupt").Bool() {
span.SetStatus(codes.Error, "HTTP status code >= 400")
if canInterrupt {
if err := parseWebhookResponse(resp); err != nil {
span.SetStatus(codes.Error, err.Error())
errChan <- err
}
}
errChan <- fmt.Errorf("web hook failed with status code %v", resp.StatusCode)
span.SetStatus(codes.Error, fmt.Sprintf("web hook failed with status code %v", resp.StatusCode))
errChan <- fmt.Errorf("webhook failed with status code %v", resp.StatusCode)
return
}

errChan <- nil
}()

if gjson.GetBytes(e.conf, "response.ignore").Bool() {
if ignoreResponse {
traceID, spanID := span.SpanContext().TraceID(), span.SpanContext().SpanID()
logger := e.deps.Logger().WithField("otel", map[string]string{
"trace_id": traceID.String(),
"span_id": spanID.String(),
})
go func() {
err := <-errChan
e.deps.Logger().WithError(err).Warning("A web hook request failed but the error was ignored because the configuration indicated that the upstream response should be ignored.")
if err := <-errChan; err != nil {
logger.WithField("duration", time.Since(startTime)).WithError(err).Warning("Webhook request failed but the error was ignored because the configuration indicated that the upstream response should be ignored.")
} else {
logger.WithField("duration", time.Since(startTime)).Info("Webhook request succeeded")
}
}()
return nil
}
Expand All @@ -323,7 +352,7 @@ func parseWebhookResponse(resp *http.Response) (err error) {
}
var hookResponse rawHookResponse
if err := json.NewDecoder(resp.Body).Decode(&hookResponse); err != nil {
return errors.Wrap(err, "hook response could not be unmarshalled properly from JSON")
return errors.Wrap(err, "webhook response could not be unmarshalled properly from JSON")
}

var validationErrs []*schema.ValidationError
Expand All @@ -343,11 +372,11 @@ func parseWebhookResponse(resp *http.Response) (err error) {
Context: detail.Context,
})
}
validationErrs = append(validationErrs, schema.NewHookValidationError(msg.InstancePtr, "a web-hook target returned an error", messages))
validationErrs = append(validationErrs, schema.NewHookValidationError(msg.InstancePtr, "a webhook target returned an error", messages))
}

if len(validationErrs) == 0 {
return errors.New("error while parsing hook response: got no validation errors")
return errors.New("error while parsing webhook response: got no validation errors")
}

return schema.NewValidationListError(validationErrs)
Expand Down
84 changes: 83 additions & 1 deletion selfservice/hook/web_hook_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"testing"
"time"

"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"

"github.com/ory/kratos/schema"
Expand Down Expand Up @@ -365,7 +366,7 @@ func TestWebHooks(t *testing.T) {
}`,
)

webhookError := schema.NewValidationListError([]*schema.ValidationError{schema.NewHookValidationError("#/traits/username", "a web-hook target returned an error", text.Messages{{ID: 1234, Type: "info", Text: "error message"}})})
webhookError := schema.NewValidationListError([]*schema.ValidationError{schema.NewHookValidationError("#/traits/username", "a webhook target returned an error", text.Messages{{ID: 1234, Type: "info", Text: "error message"}})})
for _, tc := range []struct {
uc string
callWebHook func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error
Expand Down Expand Up @@ -839,3 +840,84 @@ func TestDisallowPrivateIPRanges(t *testing.T) {
require.Contains(t, err.Error(), "192.168.178.0 is not a public IP address")
})
}

// TestAsyncWebhook checks that a webhook configured with "response.ignore": true
// keeps running after the context of the incoming request has been canceled.
//
// The test blocks the webhook receiver inside its handler, cancels the incoming
// request context while the outgoing webhook call is still in flight, then
// unblocks the handler and waits for the "Webhook request succeeded" log entry.
func TestAsyncWebhook(t *testing.T) {
	_, reg := internal.NewFastRegistryWithMocks(t)
	logger := logrusx.New("kratos", "test")
	// Capture log entries so the test can observe the asynchronous outcome.
	logHook := new(test.Hook)
	logger.Logger.Hooks.Add(logHook)
	whDeps := struct {
		x.SimpleLoggerWithClient
		*jsonnetsecure.TestProvider
	}{
		x.SimpleLoggerWithClient{L: logger, C: reg.HTTPClient(context.Background()), T: otelx.NewNoop(logger, &otelx.Config{ServiceName: "kratos"})},
		jsonnetsecure.NewTestProvider(t),
	}

	req := &http.Request{
		Header: map[string][]string{"Some-Header": {"Some-Value"}},
		Host:   "www.ory.sh",
		TLS:    new(tls.ConnectionState),
		URL:    &url.URL{Path: "/some_end_point"},
		Method: http.MethodPost,
	}

	incomingCtx, incomingCancel := context.WithCancel(context.Background())
	// Release the context even if the test fails before the explicit cancel
	// below; calling a CancelFunc more than once is a no-op.
	defer incomingCancel()
	if deadline, ok := t.Deadline(); ok {
		// cancel this context one second before test timeout for clean shutdown
		var cleanup context.CancelFunc
		incomingCtx, cleanup = context.WithDeadline(incomingCtx, deadline.Add(-time.Second))
		defer cleanup()
	}

	req = req.WithContext(incomingCtx)
	s := &session.Session{ID: x.NewUUID(), Identity: &identity.Identity{ID: x.NewUUID()}}
	f := &login.Flow{ID: x.NewUUID()}

	// handlerEntered signals that the webhook request reached the receiver;
	// blockHandlerOnExit keeps the handler from responding until it is closed.
	handlerEntered, blockHandlerOnExit := make(chan struct{}), make(chan struct{})
	webhookReceiver := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		close(handlerEntered)
		<-blockHandlerOnExit
		_, _ = w.Write([]byte("ok")) // a write error is irrelevant to this test
	}))
	t.Cleanup(webhookReceiver.Close)

	wh := hook.NewWebHook(&whDeps, json.RawMessage(fmt.Sprintf(`
		{
			"url": %q,
			"method": "GET",
			"body": "file://stub/test_body.jsonnet",
			"response": {
				"ignore": true
			}
		}`, webhookReceiver.URL)))
	err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s)
	require.NoError(t, err) // execution returns immediately for async webhook
	select {
	case <-time.After(200 * time.Millisecond):
		t.Fatal("timed out waiting for webhook request to reach test handler")
	case <-handlerEntered:
		// ok
	}
	// at this point, a goroutine is in the middle of the call to our test handler and waiting for a response
	incomingCancel() // simulate the incoming Kratos request having finished
	close(blockHandlerOnExit)

	// Poll the captured log entries until the success message shows up or the
	// timeout elapses.
	timeout := time.After(200 * time.Millisecond)
	var found bool
	for !found {
		for _, entry := range logHook.AllEntries() {
			if entry.Message == "Webhook request succeeded" {
				found = true
				break
			}
		}
		if found {
			// Leave before the select: otherwise a timeout that fires right
			// after the entry was found would still fail the test.
			break
		}

		select {
		case <-timeout:
			t.Fatal("timed out waiting for successful webhook completion")
		case <-time.After(50 * time.Millisecond):
			// continue loop
		}
	}
	require.True(t, found)
}

0 comments on commit 72de640

Please sign in to comment.