Skip to content

Commit

Permalink
feat: beginning of audit functionality
Browse files Browse the repository at this point in the history
Middleware for setup, writing of log, context handling.
  • Loading branch information
jamestelfer committed Oct 3, 2024
1 parent 6bbb904 commit 3fc90c9
Show file tree
Hide file tree
Showing 3 changed files with 303 additions and 0 deletions.
129 changes: 129 additions & 0 deletions internal/audit/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package audit

import (
"context"
"fmt"
"net/http"

"github.com/rs/zerolog"
)

// marker for interface implementation
var _ zerolog.LogObjectMarshaler = (*Entry)(nil)

// marker for context key
type key struct{}

// logKey is the key used to store the audit log entry in the context.
var logKey key

// Entry is an audit log entry for the current request.
type Entry struct {
Method string
Path string
Status int
SourceIP string
UserAgent string
RequestedProfile string
Identity string
Authorized bool
Error string
Repositories []string
Permissions []string
}

// MarshalZerologObject implements zerolog.LogObjectMarshaler. This avoids the
// need for reflection when logging, at the cost of requiring maintenance when
// the Entry struct changes.
func (e *Entry) MarshalZerologObject(event *zerolog.Event) {
event.Str("method", e.Method).
Str("path", e.Path).
Int("status", e.Status).
Str("sourceIP", e.SourceIP).
Str("userAgent", e.UserAgent).
Str("requestedProfile", e.RequestedProfile).
Str("identity", e.Identity).
Bool("authorized", e.Authorized).
Str("error", e.Error).
Strs("repositories", e.Repositories).
Strs("permissions", e.Permissions)
}

// Begin sets up the audit log entry for the current request with details from the request.
func (e *Entry) Begin(r *http.Request) {
e.Path = r.URL.Path
e.Method = r.Method
e.UserAgent = r.UserAgent()
e.SourceIP = r.RemoteAddr
}

// End writes the audit log entry. If the returned func is deferred, any panic
// will be recovered so the log entry can be written before the panic is
// re-raised.
func (e *Entry) End(ctx context.Context) func() {
return func() {
// recover from panic if necessary
r := recover()
if r != nil {
// record the details of the panid, attempting to avoid overwriting an
// earlier error
e.Status = http.StatusInternalServerError
err := fmt.Sprintf("panic: %v", r)
if e.Error != "" {
e.Error += "; "

Check warning on line 73 in internal/audit/log.go

View check run for this annotation

Codecov / codecov/patch

internal/audit/log.go#L73

Added line #L73 was not covered by tests
}
e.Error += err
}

zerolog.Ctx(ctx).Warn().EmbedObject(e).Msg("audit")

if r != nil {
// repanic the panic
panic(r)
}
}
}

// Middleware is an HTTP middleware that creates a new audit log entry for the
// current request and enriches it with information about the request. The log
// entry is written to the log when the request is complete.
//
// A panic during the request will be recovered and logged as an error in the
// audit entry. The HTTP status code of the response is also logged in the audit
// entry; further details may be added by the application.
func Middleware() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, entry := Context(r.Context())

// wrap the response writer to capture the status code
response := wrapResponseWriter(w, Log(ctx))

entry.Begin(r)
defer entry.End(ctx)()

next.ServeHTTP(response, r.WithContext(ctx))
})
}
}

// Get the log entry for the current request. This is safe to use even if the
// context does not create an entry.
func Log(ctx context.Context) *Entry {
_, e := Context(ctx)
return e
}

// Context returns the Entry for the current request, creating one if it
// does not exist. If the returned context is kept, the returned entry can be
// further enriched. If not, information written to the entry will be lost.
func Context(ctx context.Context) (context.Context, *Entry) {
e, ok := ctx.Value(logKey).(*Entry)
if !ok {
e = &Entry{}

ctx = context.WithValue(ctx, logKey, e)
}

return ctx, e
}
126 changes: 126 additions & 0 deletions internal/audit/log_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package audit_test

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/jamestelfer/chinmina-bridge/internal/audit"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/assert"
)

func TestMiddleware(t *testing.T) {

t.Run("captures request info and configures context", func(t *testing.T) {
testAgent := "kettle/1.0"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

entry := audit.Log(ctx)
assert.Equal(t, testAgent, entry.UserAgent)

w.WriteHeader(http.StatusTeapot)
})

middleware := audit.Middleware()(handler)

req, w := requestSetup()
req.Header.Set("User-Agent", testAgent)

middleware.ServeHTTP(w, req)

assert.Equal(t, http.StatusTeapot, w.Result().StatusCode)
})

t.Run("captures status code", func(t *testing.T) {
var capturedContext context.Context
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedContext = r.Context()
w.WriteHeader(http.StatusTeapot)
})

req, w := requestSetup()

middleware := audit.Middleware()(handler)

middleware.ServeHTTP(w, req)

entry := audit.Log(capturedContext)

assert.Equal(t, http.StatusTeapot, w.Result().StatusCode)
assert.Equal(t, http.StatusTeapot, entry.Status)
})

t.Run("log written", func(t *testing.T) {
auditWritten := false

hook := zerolog.HookFunc(func(e *zerolog.Event, level zerolog.Level, msg string) {
if msg == "audit" {
auditWritten = true
}
})
testLog := log.Logger.With().Bool("test", true).Logger().Hook(hook)
ctx := testLog.WithContext(context.Background())

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
})

middleware := audit.Middleware()(handler)

req, w := requestSetup()

middleware.ServeHTTP(w, req.WithContext(ctx))

assert.True(t, auditWritten, "audit log entry should be written")
})

t.Run("log written on panic", func(t *testing.T) {
auditWritten := false

hook := zerolog.HookFunc(func(e *zerolog.Event, level zerolog.Level, msg string) {
if msg == "audit" {
auditWritten = true
}
})
testLog := log.Logger.With().Bool("test", true).Logger().Hook(hook)
ctx := testLog.WithContext(context.Background())

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("not a teapot")
})

middleware := audit.Middleware()(handler)

req, w := requestSetup()

assert.PanicsWithValue(t, "not a teapot", func() {
middleware.ServeHTTP(w, req.WithContext(ctx))
// this will panic as it's expected that the middleware will re-panic
})

assert.True(t, auditWritten, "audit log entry should be written")
})
}

func TestAuditing(t *testing.T) {
ctx := context.Background()
r, _ := requestSetup()

_, e := audit.Context(ctx)
e.Begin(r)

assert.NotEmpty(t, e.SourceIP)
e.SourceIP = "" // clear IP as it will change between tests

assert.Equal(t, &audit.Entry{Method: "GET", Path: "/foo"}, e)
}

func requestSetup() (*http.Request, *httptest.ResponseRecorder) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
w := httptest.NewRecorder()
return req, w
}
48 changes: 48 additions & 0 deletions internal/audit/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package audit

import (
"bufio"
"net"
"net/http"
)

func wrapResponseWriter(w http.ResponseWriter, e *Entry) http.ResponseWriter {
wrapped := &responseWrapper{responseWriter: w, entry: e}
if _, ok := w.(http.Hijacker); ok {
return &hijackWrapper{*wrapped}

Check warning on line 12 in internal/audit/response.go

View check run for this annotation

Codecov / codecov/patch

internal/audit/response.go#L12

Added line #L12 was not covered by tests
}
return wrapped
}

type responseWrapper struct {
responseWriter http.ResponseWriter
entry *Entry
}

func (w *responseWrapper) Header() http.Header {
return w.responseWriter.Header()

Check warning on line 23 in internal/audit/response.go

View check run for this annotation

Codecov / codecov/patch

internal/audit/response.go#L22-L23

Added lines #L22 - L23 were not covered by tests
}

func (w *responseWrapper) Write(buf []byte) (int, error) {
return w.responseWriter.Write(buf)

Check warning on line 27 in internal/audit/response.go

View check run for this annotation

Codecov / codecov/patch

internal/audit/response.go#L26-L27

Added lines #L26 - L27 were not covered by tests
}

func (w *responseWrapper) WriteHeader(code int) {
w.entry.Status = code
w.responseWriter.WriteHeader(code)
}

func (w *responseWrapper) Flush() {
if flusher, ok := w.responseWriter.(http.Flusher); ok {
flusher.Flush()

Check warning on line 37 in internal/audit/response.go

View check run for this annotation

Codecov / codecov/patch

internal/audit/response.go#L35-L37

Added lines #L35 - L37 were not covered by tests
}
}

// hijackWrapper wraps a response writer that supports hijacking.
type hijackWrapper struct {
responseWrapper
}

func (h *hijackWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return h.responseWriter.(http.Hijacker).Hijack()

Check warning on line 47 in internal/audit/response.go

View check run for this annotation

Codecov / codecov/patch

internal/audit/response.go#L46-L47

Added lines #L46 - L47 were not covered by tests
}

0 comments on commit 3fc90c9

Please sign in to comment.