Skip to content

Commit

Permalink
Add options lambdaurl.WithDetectContentType and lambda.WithContextVal…
Browse files Browse the repository at this point in the history
…ue (#516)
bmoffatt authored Dec 1, 2023
1 parent 1dca084 commit 752114b
Showing 6 changed files with 252 additions and 31 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ jobs:
name: run tests
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
go:
- "1.21"
22 changes: 22 additions & 0 deletions lambda/handler.go
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ type Handler interface {
type handlerOptions struct {
handlerFunc
baseContext context.Context
contextValues map[interface{}]interface{}
jsonRequestUseNumber bool
jsonRequestDisallowUnknownFields bool
jsonResponseEscapeHTML bool
@@ -50,6 +51,23 @@ func WithContext(ctx context.Context) Option {
})
}

// WithContextValue adds a value to the handler context.
// If a base context was set using WithContext, that base is used as the parent.
//
// Usage:
//
// lambda.StartWithOptions(
// func (ctx context.Context) (string, error) {
// return ctx.Value("foo"), nil
// },
// lambda.WithContextValue("foo", "bar")
// )
func WithContextValue(key interface{}, value interface{}) Option {
return Option(func(h *handlerOptions) {
h.contextValues[key] = value
})
}

// WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder
//
// Usage:
@@ -211,13 +229,17 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
}
h := &handlerOptions{
baseContext: context.Background(),
contextValues: map[interface{}]interface{}{},
jsonResponseEscapeHTML: false,
jsonResponseIndentPrefix: "",
jsonResponseIndentValue: "",
}
for _, option := range options {
option(h)
}
for k, v := range h.contextValues {
h.baseContext = context.WithValue(h.baseContext, k, v)
}
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
12 changes: 7 additions & 5 deletions lambda/sigterm_test.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"path"
"strconv"
"strings"
"testing"
"time"
@@ -17,10 +18,6 @@ import (
"github.com/stretchr/testify/require"
)

const (
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
)

func TestEnableSigterm(t *testing.T) {
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
@@ -34,6 +31,7 @@ func TestEnableSigterm(t *testing.T) {
handlerBuild.Stdout = os.Stderr
require.NoError(t, handlerBuild.Run())

portI := 0
for name, opts := range map[string]struct {
envVars []string
assertLogs func(t *testing.T, logs string)
@@ -53,8 +51,12 @@ func TestEnableSigterm(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
portI += 1
addr1 := "localhost:" + strconv.Itoa(8000+portI)
addr2 := "localhost:" + strconv.Itoa(9000+portI)
rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations"
// run the runtime interface emulator, capture the logs for assertion
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "sigterm.handler")
cmd.Env = append([]string{
"PATH=" + testDir,
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
91 changes: 76 additions & 15 deletions lambdaurl/http_handler.go
Original file line number Diff line number Diff line change
@@ -18,24 +18,76 @@ import (
"github.com/aws/aws-lambda-go/lambda"
)

type detectContentTypeContextKey struct{}

// WithDetectContentType sets the behavior of content type detection when the Content-Type header is not already provided.
// When true, the first Write call will pass the intial bytes to http.DetectContentType.
// When false, and if no Content-Type is provided, no Content-Type will be sent back to Lambda,
// and the Lambda Function URL will fallback to it's default.
//
// Note: The http.ResponseWriter passed to the handler is unbuffered.
// This may result in different Content-Type headers in the Function URL response when compared to http.ListenAndServe.
//
// Usage:
//
// lambdaurl.Start(
// http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
// w.Write("<!DOCTYPE html><html></html>")
// }),
// lambdaurl.WithDetectContentType(true)
// )
func WithDetectContentType(detectContentType bool) lambda.Option {
return lambda.WithContextValue(detectContentTypeContextKey{}, detectContentType)
}

type httpResponseWriter struct {
detectContentType bool
header http.Header
writer io.Writer
once sync.Once
ready chan<- header
}

type header struct {
code int
header http.Header
writer io.Writer
once sync.Once
status chan<- int
}

func (w *httpResponseWriter) Header() http.Header {
if w.header == nil {
w.header = http.Header{}
}
return w.header
}

func (w *httpResponseWriter) Write(p []byte) (int, error) {
w.once.Do(func() { w.status <- http.StatusOK })
w.writeHeader(http.StatusOK, p)
return w.writer.Write(p)
}

func (w *httpResponseWriter) WriteHeader(statusCode int) {
w.once.Do(func() { w.status <- statusCode })
w.writeHeader(statusCode, nil)
}

func (w *httpResponseWriter) writeHeader(statusCode int, initialPayload []byte) {
w.once.Do(func() {
if w.detectContentType {
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", detectContentType(initialPayload))
}
}
w.ready <- header{code: statusCode, header: w.header}
})
}

func detectContentType(p []byte) string {
// http.DetectContentType returns "text/plain; charset=utf-8" for nil and zero-length byte slices.
// This is a weird behavior, since otherwise it defaults to "application/octet-stream"! So we'll do that.
// This differs from http.ListenAndServe, which set no Content-Type when the initial Flush body is empty.
if len(p) == 0 {
return "application/octet-stream"
}
return http.DetectContentType(p)
}

type requestContextKey struct{}
@@ -46,11 +98,13 @@ func RequestFromContext(ctx context.Context) (*events.LambdaFunctionURLRequest,
return req, ok
}

// Wrap converts an http.Handler into a lambda request handler.
// Wrap converts an http.Handler into a Lambda request handler.
//
// Only Lambda Function URLs configured with `InvokeMode: RESPONSE_STREAM` are supported with the returned handler.
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`.
func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {
return func(ctx context.Context, request *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {

var body io.Reader = strings.NewReader(request.Body)
if request.IsBase64Encoded {
body = base64.NewDecoder(base64.StdEncoding, body)
@@ -67,21 +121,28 @@ func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLR
for k, v := range request.Headers {
httpRequest.Header.Add(k, v)
}
status := make(chan int) // Signals when it's OK to start returning the response body to Lambda
header := http.Header{}

ready := make(chan header) // Signals when it's OK to start returning the response body to Lambda
r, w := io.Pipe()
responseWriter := &httpResponseWriter{writer: w, ready: ready}
if detectContentType, ok := ctx.Value(detectContentTypeContextKey{}).(bool); ok {
responseWriter.detectContentType = detectContentType
}
go func() {
defer close(status)
defer close(ready)
defer w.Close() // TODO: recover and CloseWithError the any panic value once the runtime API client supports plumbing fatal errors through the reader
handler.ServeHTTP(&httpResponseWriter{writer: w, header: header, status: status}, httpRequest)
//nolint:errcheck
defer responseWriter.Write(nil) // force default status, headers, content type detection, if none occured during the execution of the handler
handler.ServeHTTP(responseWriter, httpRequest)
}()
header := <-ready
response := &events.LambdaFunctionURLStreamingResponse{
Body: r,
StatusCode: <-status,
StatusCode: header.code,
}
if len(header) > 0 {
response.Headers = make(map[string]string, len(header))
for k, v := range header {
if len(header.header) > 0 {
response.Headers = make(map[string]string, len(header.header))
for k, v := range header.header {
if k == "Set-Cookie" {
response.Cookies = v
} else {
132 changes: 121 additions & 11 deletions lambdaurl/http_handler_test.go
Original file line number Diff line number Diff line change
@@ -13,6 +13,11 @@ import (
"io/ioutil"
"log"
"net/http"
"os"
"os/exec"
"path"
"strconv"
"strings"
"testing"
"time"

@@ -35,12 +40,13 @@ var base64EncodedBodyRequest []byte

func TestWrap(t *testing.T) {
for name, params := range map[string]struct {
input []byte
handler http.HandlerFunc
expectStatus int
expectBody string
expectHeaders map[string]string
expectCookies []string
input []byte
handler http.HandlerFunc
detectContentType bool
expectStatus int
expectBody string
expectHeaders map[string]string
expectCookies []string
}{
"hello": {
input: helloRequest,
@@ -58,10 +64,8 @@ func TestWrap(t *testing.T) {
encoder := json.NewEncoder(w)
_ = encoder.Encode(struct{ RequestQueryParams, Method any }{r.URL.Query(), r.Method})
},
expectStatus: http.StatusTeapot,
expectHeaders: map[string]string{
"Hello": "world1,world2",
},
expectStatus: http.StatusTeapot,
expectHeaders: map[string]string{"Hello": "world1,world2"},
expectCookies: []string{
"yummy=cookie",
"yummy=cake",
@@ -110,6 +114,13 @@ func TestWrap(t *testing.T) {
handler: func(w http.ResponseWriter, r *http.Request) {},
expectStatus: http.StatusOK,
},
"write status code only": {
input: helloRequest,
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
},
expectStatus: http.StatusAccepted,
},
"base64request": {
input: base64EncodedBodyRequest,
handler: func(w http.ResponseWriter, r *http.Request) {
@@ -118,12 +129,58 @@ func TestWrap(t *testing.T) {
expectStatus: http.StatusOK,
expectBody: "<idk/>",
},
"detect content type: write status code only": {
input: helloRequest,
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
},
detectContentType: true,
expectStatus: http.StatusAccepted,
expectHeaders: map[string]string{
"Content-Type": "application/octet-stream",
},
},
"detect content type: empty handler": {
input: helloRequest,
handler: func(w http.ResponseWriter, r *http.Request) {
},
detectContentType: true,
expectStatus: http.StatusOK,
expectHeaders: map[string]string{
"Content-Type": "application/octet-stream",
},
},
"detect content type: writes html": {
input: helloRequest,
handler: func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("<!DOCTYPE HTML><html></html>"))
},
detectContentType: true,
expectBody: "<!DOCTYPE HTML><html></html>",
expectStatus: http.StatusOK,
expectHeaders: map[string]string{
"Content-Type": "text/html; charset=utf-8",
},
},
"detect content type: writes zeros": {
input: helloRequest,
handler: func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte{0, 0, 0, 0, 0})
},
detectContentType: true,
expectBody: "\x00\x00\x00\x00\x00",
expectStatus: http.StatusOK,
expectHeaders: map[string]string{
"Content-Type": "application/octet-stream",
},
},
} {
t.Run(name, func(t *testing.T) {
handler := Wrap(params.handler)
var req events.LambdaFunctionURLRequest
require.NoError(t, json.Unmarshal(params.input, &req))
res, err := handler(context.Background(), &req)
ctx := context.WithValue(context.Background(), detectContentTypeContextKey{}, params.detectContentType)
res, err := handler(ctx, &req)
require.NoError(t, err)
resultBodyBytes, err := ioutil.ReadAll(res)
require.NoError(t, err)
@@ -155,3 +212,56 @@ func TestRequestContext(t *testing.T) {
_, err := handler(context.Background(), req)
require.NoError(t, err)
}

func TestStartViaEmulator(t *testing.T) {
addr1 := "localhost:" + strconv.Itoa(6001)
addr2 := "localhost:" + strconv.Itoa(7001)
rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations"
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
}

// compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie
testDir := t.TempDir()
handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "lambdaurl.handler"), "./testdata/lambdaurl.go")
handlerBuild.Stderr = os.Stderr
handlerBuild.Stdout = os.Stderr
require.NoError(t, handlerBuild.Run())

// run the runtime interface emulator, capture the logs for assertion
cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "lambdaurl.handler")
cmd.Env = []string{
"PATH=" + testDir,
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
}
cmd.Stderr = os.Stderr
stdout, err := cmd.StdoutPipe()
require.NoError(t, err)
var logs string
done := make(chan interface{}) // closed on completion of log flush
go func() {
logBytes, err := ioutil.ReadAll(stdout)
require.NoError(t, err)
logs = string(logBytes)
close(done)
}()
require.NoError(t, cmd.Start())
t.Cleanup(func() { _ = cmd.Process.Kill() })

// give a moment for the port to bind
time.Sleep(500 * time.Millisecond)

client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie
resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}"))
require.NoError(t, err)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
assert.NoError(t, err)

expected := "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"text/html; charset=utf-8\"}}\x00\x00\x00\x00\x00\x00\x00\x00<!DOCTYPE HTML>\n<html>\n<body>\nHello World!\n</body>\n</html>\n"
assert.Equal(t, expected, string(body))

require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained
<-done
t.Logf("stdout:\n%s", logs)
}
25 changes: 25 additions & 0 deletions lambdaurl/testdata/lambdaurl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package main

import (
"io"
"net/http"
"strings"

"github.com/aws/aws-lambda-go/lambdaurl"
)

const content = `<!DOCTYPE HTML>
<html>
<body>
Hello World!
</body>
</html>
`

func main() {
lambdaurl.Start(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(w, strings.NewReader(content))
}),
lambdaurl.WithDetectContentType(true),
)
}

0 comments on commit 752114b

Please sign in to comment.