-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: beginning of audit functionality
Middleware for setup, writing of log, context handling.
- Loading branch information
1 parent
6bbb904
commit 3fc90c9
Showing
3 changed files
with
303 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
package audit | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
|
||
"github.com/rs/zerolog" | ||
) | ||
|
||
// marker for interface implementation | ||
var _ zerolog.LogObjectMarshaler = (*Entry)(nil) | ||
|
||
// marker for context key | ||
type key struct{} | ||
|
||
// logKey is the key used to store the audit log entry in the context. | ||
var logKey key | ||
|
||
// Entry is an audit log entry for the current request. | ||
type Entry struct { | ||
Method string | ||
Path string | ||
Status int | ||
SourceIP string | ||
UserAgent string | ||
RequestedProfile string | ||
Identity string | ||
Authorized bool | ||
Error string | ||
Repositories []string | ||
Permissions []string | ||
} | ||
|
||
// MarshalZerologObject implements zerolog.LogObjectMarshaler. This avoids the | ||
// need for reflection when logging, at the cost of requiring maintenance when | ||
// the Entry struct changes. | ||
func (e *Entry) MarshalZerologObject(event *zerolog.Event) { | ||
event.Str("method", e.Method). | ||
Str("path", e.Path). | ||
Int("status", e.Status). | ||
Str("sourceIP", e.SourceIP). | ||
Str("userAgent", e.UserAgent). | ||
Str("requestedProfile", e.RequestedProfile). | ||
Str("identity", e.Identity). | ||
Bool("authorized", e.Authorized). | ||
Str("error", e.Error). | ||
Strs("repositories", e.Repositories). | ||
Strs("permissions", e.Permissions) | ||
} | ||
|
||
// Begin sets up the audit log entry for the current request with details from the request. | ||
func (e *Entry) Begin(r *http.Request) { | ||
e.Path = r.URL.Path | ||
e.Method = r.Method | ||
e.UserAgent = r.UserAgent() | ||
e.SourceIP = r.RemoteAddr | ||
} | ||
|
||
// End writes the audit log entry. If the returned func is deferred, any panic | ||
// will be recovered so the log entry can be written before the panic is | ||
// re-raised. | ||
func (e *Entry) End(ctx context.Context) func() { | ||
return func() { | ||
// recover from panic if necessary | ||
r := recover() | ||
if r != nil { | ||
// record the details of the panid, attempting to avoid overwriting an | ||
// earlier error | ||
e.Status = http.StatusInternalServerError | ||
err := fmt.Sprintf("panic: %v", r) | ||
if e.Error != "" { | ||
e.Error += "; " | ||
} | ||
e.Error += err | ||
} | ||
|
||
zerolog.Ctx(ctx).Warn().EmbedObject(e).Msg("audit") | ||
|
||
if r != nil { | ||
// repanic the panic | ||
panic(r) | ||
} | ||
} | ||
} | ||
|
||
// Middleware is an HTTP middleware that creates a new audit log entry for the | ||
// current request and enriches it with information about the request. The log | ||
// entry is written to the log when the request is complete. | ||
// | ||
// A panic during the request will be recovered and logged as an error in the | ||
// audit entry. The HTTP status code of the response is also logged in the audit | ||
// entry; further details may be added by the application. | ||
func Middleware() func(next http.Handler) http.Handler { | ||
return func(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
ctx, entry := Context(r.Context()) | ||
|
||
// wrap the response writer to capture the status code | ||
response := wrapResponseWriter(w, Log(ctx)) | ||
|
||
entry.Begin(r) | ||
defer entry.End(ctx)() | ||
|
||
next.ServeHTTP(response, r.WithContext(ctx)) | ||
}) | ||
} | ||
} | ||
|
||
// Get the log entry for the current request. This is safe to use even if the | ||
// context does not create an entry. | ||
func Log(ctx context.Context) *Entry { | ||
_, e := Context(ctx) | ||
return e | ||
} | ||
|
||
// Context returns the Entry for the current request, creating one if it | ||
// does not exist. If the returned context is kept, the returned entry can be | ||
// further enriched. If not, information written to the entry will be lost. | ||
func Context(ctx context.Context) (context.Context, *Entry) { | ||
e, ok := ctx.Value(logKey).(*Entry) | ||
if !ok { | ||
e = &Entry{} | ||
|
||
ctx = context.WithValue(ctx, logKey, e) | ||
} | ||
|
||
return ctx, e | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package audit_test | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/jamestelfer/chinmina-bridge/internal/audit" | ||
"github.com/rs/zerolog" | ||
"github.com/rs/zerolog/log" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestMiddleware(t *testing.T) { | ||
|
||
t.Run("captures request info and configures context", func(t *testing.T) { | ||
testAgent := "kettle/1.0" | ||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
ctx := r.Context() | ||
|
||
entry := audit.Log(ctx) | ||
assert.Equal(t, testAgent, entry.UserAgent) | ||
|
||
w.WriteHeader(http.StatusTeapot) | ||
}) | ||
|
||
middleware := audit.Middleware()(handler) | ||
|
||
req, w := requestSetup() | ||
req.Header.Set("User-Agent", testAgent) | ||
|
||
middleware.ServeHTTP(w, req) | ||
|
||
assert.Equal(t, http.StatusTeapot, w.Result().StatusCode) | ||
}) | ||
|
||
t.Run("captures status code", func(t *testing.T) { | ||
var capturedContext context.Context | ||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
capturedContext = r.Context() | ||
w.WriteHeader(http.StatusTeapot) | ||
}) | ||
|
||
req, w := requestSetup() | ||
|
||
middleware := audit.Middleware()(handler) | ||
|
||
middleware.ServeHTTP(w, req) | ||
|
||
entry := audit.Log(capturedContext) | ||
|
||
assert.Equal(t, http.StatusTeapot, w.Result().StatusCode) | ||
assert.Equal(t, http.StatusTeapot, entry.Status) | ||
}) | ||
|
||
t.Run("log written", func(t *testing.T) { | ||
auditWritten := false | ||
|
||
hook := zerolog.HookFunc(func(e *zerolog.Event, level zerolog.Level, msg string) { | ||
if msg == "audit" { | ||
auditWritten = true | ||
} | ||
}) | ||
testLog := log.Logger.With().Bool("test", true).Logger().Hook(hook) | ||
ctx := testLog.WithContext(context.Background()) | ||
|
||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusTeapot) | ||
}) | ||
|
||
middleware := audit.Middleware()(handler) | ||
|
||
req, w := requestSetup() | ||
|
||
middleware.ServeHTTP(w, req.WithContext(ctx)) | ||
|
||
assert.True(t, auditWritten, "audit log entry should be written") | ||
}) | ||
|
||
t.Run("log written on panic", func(t *testing.T) { | ||
auditWritten := false | ||
|
||
hook := zerolog.HookFunc(func(e *zerolog.Event, level zerolog.Level, msg string) { | ||
if msg == "audit" { | ||
auditWritten = true | ||
} | ||
}) | ||
testLog := log.Logger.With().Bool("test", true).Logger().Hook(hook) | ||
ctx := testLog.WithContext(context.Background()) | ||
|
||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
panic("not a teapot") | ||
}) | ||
|
||
middleware := audit.Middleware()(handler) | ||
|
||
req, w := requestSetup() | ||
|
||
assert.PanicsWithValue(t, "not a teapot", func() { | ||
middleware.ServeHTTP(w, req.WithContext(ctx)) | ||
// this will panic as it's expected that the middleware will re-panic | ||
}) | ||
|
||
assert.True(t, auditWritten, "audit log entry should be written") | ||
}) | ||
} | ||
|
||
func TestAuditing(t *testing.T) { | ||
ctx := context.Background() | ||
r, _ := requestSetup() | ||
|
||
_, e := audit.Context(ctx) | ||
e.Begin(r) | ||
|
||
assert.NotEmpty(t, e.SourceIP) | ||
e.SourceIP = "" // clear IP as it will change between tests | ||
|
||
assert.Equal(t, &audit.Entry{Method: "GET", Path: "/foo"}, e) | ||
} | ||
|
||
func requestSetup() (*http.Request, *httptest.ResponseRecorder) { | ||
req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil) | ||
w := httptest.NewRecorder() | ||
return req, w | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package audit | ||
|
||
import ( | ||
"bufio" | ||
"net" | ||
"net/http" | ||
) | ||
|
||
func wrapResponseWriter(w http.ResponseWriter, e *Entry) http.ResponseWriter { | ||
wrapped := &responseWrapper{responseWriter: w, entry: e} | ||
if _, ok := w.(http.Hijacker); ok { | ||
return &hijackWrapper{*wrapped} | ||
} | ||
return wrapped | ||
} | ||
|
||
type responseWrapper struct { | ||
responseWriter http.ResponseWriter | ||
entry *Entry | ||
} | ||
|
||
func (w *responseWrapper) Header() http.Header { | ||
return w.responseWriter.Header() | ||
} | ||
|
||
func (w *responseWrapper) Write(buf []byte) (int, error) { | ||
return w.responseWriter.Write(buf) | ||
} | ||
|
||
func (w *responseWrapper) WriteHeader(code int) { | ||
w.entry.Status = code | ||
w.responseWriter.WriteHeader(code) | ||
} | ||
|
||
func (w *responseWrapper) Flush() { | ||
if flusher, ok := w.responseWriter.(http.Flusher); ok { | ||
flusher.Flush() | ||
} | ||
} | ||
|
||
// hijackWrapper wraps a response writer that supports hijacking. | ||
type hijackWrapper struct { | ||
responseWrapper | ||
} | ||
|
||
func (h *hijackWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { | ||
return h.responseWriter.(http.Hijacker).Hijack() | ||
} |