diff --git a/action.go b/action.go index 0c5655f..2147c3d 100644 --- a/action.go +++ b/action.go @@ -19,17 +19,19 @@ const ( type Event struct { Type EventType Container *Container - Domain string - Path string - Result chan *Container + Endpoint *Endpoint + Result chan struct { + Container *Container + Endpoint *Endpoint + } } type ActionRunner struct { pingerCallback func() addCallback func(*Container) - updateCallback func(*Container, string, string) + updateCallback func(*Container, *Endpoint) removeCallback func(*Container) - getCallback func(string, string) *Container + getCallback func(string, string) (*Container, *Endpoint) events chan *Event close chan struct{} // using this to make sure pushing to events stops when Close() is called @@ -45,31 +47,33 @@ func (ar *ActionRunner) Add(container *Container) { ar.push(&Event{Type: addEvent, Container: container}) } -func (ar *ActionRunner) Update(container *Container, domain string, path string) { - ar.push(&Event{Type: updateEvent, Container: container, Domain: domain, Path: path}) +func (ar *ActionRunner) Update(container *Container, endpoint *Endpoint) { + ar.push(&Event{Type: updateEvent, Container: container, Endpoint: endpoint}) } func (ar *ActionRunner) Remove(container *Container) { ar.push(&Event{Type: removeEvent, Container: container}) } -func (ar *ActionRunner) Get(ctx context.Context, domain, path string) *Container { +func (ar *ActionRunner) Get(ctx context.Context, endpoint *Endpoint) (*Container, *Endpoint) { evt := &Event{ - Type: getEvent, - Domain: domain, - Path: path, - Result: make(chan *Container, 1), + Type: getEvent, + Endpoint: endpoint, + Result: make(chan struct { + Container *Container + Endpoint *Endpoint + }, 1), } ar.push(evt) select { - case container := <-evt.Result: - return container + case r := <-evt.Result: + return r.Container, r.Endpoint case <-ctx.Done(): - return nil + return nil, nil case <-ar.close: - return nil + return nil, nil } } @@ -85,7 +89,6 @@ func (ar *ActionRunner) push(event *Event) { func (ar *ActionRunner) Close() { close(ar.close) - close(ar.events) } func WithPingerCallback(callback func()) func(*ActionRunner) { @@ -100,7 +103,7 @@ func WithAddCallback(callback func(*Container)) func(*ActionRunner) { } } -func WithUpdateCallback(callback func(*Container, string, string)) func(*ActionRunner) { +func WithUpdateCallback(callback func(*Container, *Endpoint)) func(*ActionRunner) { return func(ar *ActionRunner) { ar.updateCallback = callback } @@ -112,7 +115,7 @@ func WithRemoveCallback(callback func(*Container)) func(*ActionRunner) { } } -func WithGetCallback(callback func(string, string) *Container) func(*ActionRunner) { +func WithGetCallback(callback func(string, string) (*Container, *Endpoint)) func(*ActionRunner) { return func(ar *ActionRunner) { ar.getCallback = callback } @@ -133,20 +136,33 @@ func NewActionRunner(bufferSize int, cbs ...ActionCallback) *ActionRunner { go func() { defer slog.Debug("ActionRunner: stopped") - for event := range ar.events { - switch event.Type { - case pingerEvent: - ar.pingerCallback() - case addEvent: - ar.addCallback(event.Container) - case updateEvent: - ar.updateCallback(event.Container, event.Domain, event.Path) - case removeEvent: - ar.removeCallback(event.Container) - case getEvent: - event.Result <- ar.getCallback(event.Domain, event.Path) - default: - continue + for { + select { + case <-ar.close: + return + case event, ok := <-ar.events: + if !ok { + return + } + + switch event.Type { + case pingerEvent: + ar.pingerCallback() + case addEvent: + ar.addCallback(event.Container) + case updateEvent: + ar.updateCallback(event.Container, event.Endpoint) + case removeEvent: + ar.removeCallback(event.Container) + case getEvent: + container, endpoint := ar.getCallback(event.Endpoint.Domain, event.Endpoint.Path) + event.Result <- struct { + Container *Container + Endpoint *Endpoint + }{container, endpoint} + default: + continue + } } } }() diff --git a/data.go b/data.go index 05d65c0..5d686e5 100644 --- a/data.go +++ b/data.go @@ -1,6 +1,7 @@ package baker import ( + "encoding/json" "net/netip" ) @@ -13,12 +14,23 @@ type Container struct { type Endpoint struct { Domain string `json:"domain"` Path string `json:"path"` + Rules []Rule `json:"rules"` +} + +type Rule struct { + Type string `json:"type"` + Args json.RawMessage `json:"args"` } type Config struct { Endpoints []Endpoint `json:"endpoints"` } +type Service struct { + Containers []*Container + Endpoint *Endpoint +} + type Driver interface { Add(*Container) Remove(*Container) diff --git a/go.mod b/go.mod index fbde1c1..b1de141 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module ella.to/baker go 1.22.0 require ( + github.com/alinz/baker.go v1.2.0 + github.com/cespare/xxhash/v2 v2.3.0 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.23.0 ) diff --git a/go.sum b/go.sum index 608633e..a21dbf2 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/alinz/baker.go v1.2.0 h1:XHdn58jMGLTupj2+qtUJuN3mImDOC/kABk7iRG+JXfk= +github.com/alinz/baker.go v1.2.0/go.mod h1:W7xTX8eE5v0ddkzDvg3tvCNdjZGjWk5s9S0sPIDuo/c= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/rule/internal/rate/LICENSE b/rule/internal/rate/LICENSE new file mode 100644 index 0000000..0ee2979 --- /dev/null +++ b/rule/internal/rate/LICENSE @@ -0,0 +1 @@ +This package is copied and modified from https://github.com/go-chi/httprate diff --git a/rule/internal/rate/context.go b/rule/internal/rate/context.go new file mode 100644 index 0000000..ba94de0 --- /dev/null +++ b/rule/internal/rate/context.go @@ -0,0 +1,16 @@ +package rate + +import "context" + +var incrementKey = &struct{}{} + +func WithIncrement(ctx context.Context, value int) context.Context { + return context.WithValue(ctx, incrementKey, value) +} + +func getIncrement(ctx context.Context) int { + if value, ok := ctx.Value(incrementKey).(int); ok { + return value + } + return 1 +} diff --git a/rule/internal/rate/limiter.go b/rule/internal/rate/limiter.go new file mode 100644 index 0000000..62b7941 --- /dev/null +++ b/rule/internal/rate/limiter.go @@ -0,0 +1,217 @@ +package rate + +import ( + "fmt" + "math" + "net/http" + "sync" + "time" + + "github.com/cespare/xxhash/v2" +) + +type LimitCounter interface { + Config(requestLimit int, windowLength time.Duration) + Increment(key string, currentWindow time.Time) error + IncrementBy(key string, currentWindow time.Time, amount int) error + Get(key string, currentWindow, previousWindow time.Time) (int, int, error) +} + +func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter { + return newRateLimiter(requestLimit, windowLength, options...) +} + +func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter { + rl := &rateLimiter{ + requestLimit: requestLimit, + windowLength: windowLength, + } + + for _, opt := range options { + opt(rl) + } + + if rl.keyFn == nil { + rl.keyFn = func(r *http.Request) (string, error) { + return "*", nil + } + } + + if rl.limitCounter == nil { + rl.limitCounter = &localCounter{ + counters: make(map[uint64]*count), + windowLength: windowLength, + } + } + rl.limitCounter.Config(requestLimit, windowLength) + + if rl.onRequestLimit == nil { + rl.onRequestLimit = func(w http.ResponseWriter, r *http.Request) { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + } + } + + return rl +} + +type rateLimiter struct { + requestLimit int + windowLength time.Duration + keyFn KeyFunc + limitCounter LimitCounter + onRequestLimit http.HandlerFunc + mu sync.Mutex +} + +func (l *rateLimiter) Counter() LimitCounter { + return l.limitCounter +} + +func (l *rateLimiter) Status(key string) (bool, float64, error) { + t := time.Now().UTC() + currentWindow := t.Truncate(l.windowLength) + previousWindow := currentWindow.Add(-l.windowLength) + + currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow) + if err != nil { + return false, 0, err + } + + diff := t.Sub(currentWindow) + rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount) + + if rate > float64(l.requestLimit) { + return false, rate, nil + } + return true, rate, nil +} + +func (l *rateLimiter) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key, err := l.keyFn(r) + if err != nil { + http.Error(w, err.Error(), http.StatusPreconditionRequired) + return + } + + currentWindow := time.Now().UTC().Truncate(l.windowLength) + + w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", l.requestLimit)) + w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", 0)) + w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) + + l.mu.Lock() + _, rate, err := l.Status(key) + if err != nil { + l.mu.Unlock() + http.Error(w, err.Error(), http.StatusPreconditionRequired) + return + } + nrate := int(math.Round(rate)) + + if l.requestLimit > nrate { + w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", l.requestLimit-nrate)) + } + + if nrate >= l.requestLimit { + l.mu.Unlock() + w.Header().Set("Retry-After", fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 + l.onRequestLimit(w, r) + return + } + + err = l.limitCounter.IncrementBy(key, currentWindow, getIncrement(r.Context())) + if err != nil { + l.mu.Unlock() + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + l.mu.Unlock() + + next.ServeHTTP(w, r) + }) +} + +type localCounter struct { + counters map[uint64]*count + windowLength time.Duration + lastEvict time.Time + mu sync.Mutex +} + +var _ LimitCounter = &localCounter{} + +type count struct { + value int + updatedAt time.Time +} + +func (c *localCounter) Config(requestLimit int, windowLength time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.windowLength = windowLength +} + +func (c *localCounter) Increment(key string, currentWindow time.Time) error { + return c.IncrementBy(key, currentWindow, 1) +} + +func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error { + c.evict() + + c.mu.Lock() + defer c.mu.Unlock() + + hkey := LimitCounterKey(key, currentWindow) + + v, ok := c.counters[hkey] + if !ok { + v = &count{} + c.counters[hkey] = v + } + v.value += amount + v.updatedAt = time.Now() + + return nil +} + +func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) { + c.mu.Lock() + defer c.mu.Unlock() + + curr, ok := c.counters[LimitCounterKey(key, currentWindow)] + if !ok { + curr = &count{value: 0, updatedAt: time.Now()} + } + prev, ok := c.counters[LimitCounterKey(key, previousWindow)] + if !ok { + prev = &count{value: 0, updatedAt: time.Now()} + } + + return curr.value, prev.value, nil +} + +func (c *localCounter) evict() { + c.mu.Lock() + defer c.mu.Unlock() + + d := c.windowLength * 3 + + if time.Since(c.lastEvict) < d { + return + } + c.lastEvict = time.Now() + + for k, v := range c.counters { + if time.Since(v.updatedAt) >= d { + delete(c.counters, k) + } + } +} + +func LimitCounterKey(key string, window time.Time) uint64 { + h := xxhash.New() + h.WriteString(key) + h.WriteString(fmt.Sprintf("%d", window.Unix())) + return h.Sum64() +} diff --git a/rule/internal/rate/limiter_test.go b/rule/internal/rate/limiter_test.go new file mode 100644 index 0000000..ea3bf42 --- /dev/null +++ b/rule/internal/rate/limiter_test.go @@ -0,0 +1,208 @@ +package rate_test + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "ella.to/baker/rule/internal/rate" +) + +func TestLimit(t *testing.T) { + type test struct { + name string + requestsLimit int + windowLength time.Duration + respCodes []int + } + tests := []test{ + { + name: "no-block", + requestsLimit: 3, + windowLength: 4 * time.Second, + respCodes: []int{200, 200, 200}, + }, + { + name: "block", + requestsLimit: 3, + windowLength: 2 * time.Second, + respCodes: []int{200, 200, 200, 429}, + }, + } + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + router := rate.LimitAll(tt.requestsLimit, tt.windowLength)(h) + + for _, code := range tt.respCodes { + req := httptest.NewRequest("GET", "/", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + if respCode := recorder.Result().StatusCode; respCode != code { + t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code) + } + } + }) + } +} + +func TestWithIncrement(t *testing.T) { + type test struct { + name string + requestsLimit int + windowLength time.Duration + respCodes []int + } + tests := []test{ + { + name: "no-block", + requestsLimit: 3, + windowLength: 4 * time.Second, + respCodes: []int{200, 200, 429}, + }, + { + name: "block", + requestsLimit: 3, + windowLength: 2 * time.Second, + respCodes: []int{200, 200, 429, 429}, + }, + } + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + router := rate.LimitAll(tt.requestsLimit, tt.windowLength)(h) + + for _, code := range tt.respCodes { + req := httptest.NewRequest("GET", "/", nil) + req = req.WithContext(rate.WithIncrement(req.Context(), 2)) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + if respCode := recorder.Result().StatusCode; respCode != code { + t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code) + } + } + }) + } +} + +func TestLimitHandler(t *testing.T) { + type test struct { + name string + requestsLimit int + windowLength time.Duration + responses []struct { + Body string + StatusCode int + } + } + tests := []test{ + { + name: "no-block", + requestsLimit: 3, + windowLength: 4 * time.Second, + responses: []struct { + Body string + StatusCode int + }{ + {Body: "", StatusCode: 200}, + {Body: "", StatusCode: 200}, + {Body: "", StatusCode: 200}, + }, + }, + { + name: "block", + requestsLimit: 3, + windowLength: 2 * time.Second, + responses: []struct { + Body string + StatusCode int + }{ + {Body: "", StatusCode: 200}, + {Body: "", StatusCode: 200}, + {Body: "", StatusCode: 200}, + {Body: "Wow Slow Down Kiddo", StatusCode: 429}, + }, + }, + } + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + router := rate.Limit( + tt.requestsLimit, + tt.windowLength, + rate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Wow Slow Down Kiddo", 429) + }), + )(h) + + for _, expected := range tt.responses { + req := httptest.NewRequest("GET", "/", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + result := recorder.Result() + if respStatus := result.StatusCode; respStatus != expected.StatusCode { + t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, expected.StatusCode) + } + buf := new(bytes.Buffer) + buf.ReadFrom(result.Body) + respBody := strings.TrimSuffix(buf.String(), "\n") + + if respBody != expected.Body { + t.Errorf("resp.Body(%v) = %v, want %v", i, respBody, expected.Body) + } + } + }) + } +} + +func TestLimitIP(t *testing.T) { + type test struct { + name string + requestsLimit int + windowLength time.Duration + reqIp []string + respCodes []int + } + tests := []test{ + { + name: "no-block", + requestsLimit: 3, + windowLength: 2 * time.Second, + reqIp: []string{"1.1.1.1:100", "2.2.2.2:200"}, + respCodes: []int{200, 200}, + }, + { + name: "block-ip", + requestsLimit: 1, + windowLength: 2 * time.Second, + reqIp: []string{"1.1.1.1:100", "1.1.1.1:100", "2.2.2.2:200"}, + respCodes: []int{200, 429, 200}, + }, + { + name: "block-ipv6", + requestsLimit: 1, + windowLength: 2 * time.Second, + reqIp: []string{"2001:DB8::21f:5bff:febf:ce22:1111", "2001:DB8::21f:5bff:febf:ce22:2222", "2002:DB8::21f:5bff:febf:ce22:1111"}, + respCodes: []int{200, 429, 200}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + router := rate.LimitByIP(tt.requestsLimit, tt.windowLength)(h) + + for i, code := range tt.respCodes { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = tt.reqIp[i] + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + if respCode := recorder.Result().StatusCode; respCode != code { + t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code) + } + } + }) + } +} diff --git a/rule/internal/rate/rate.go b/rule/internal/rate/rate.go new file mode 100644 index 0000000..781a0a9 --- /dev/null +++ b/rule/internal/rate/rate.go @@ -0,0 +1,141 @@ +package rate + +import ( + "net" + "net/http" + "strings" + "time" +) + +func Limit(requestLimit int, windowLength time.Duration, options ...Option) func(next http.Handler) http.Handler { + return NewRateLimiter(requestLimit, windowLength, options...).Handler +} + +type KeyFunc func(r *http.Request) (string, error) +type Option func(rl *rateLimiter) + +func LimitAll(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { + return Limit(requestLimit, windowLength) +} + +func LimitByIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { + return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByIP)) +} + +func LimitByRealIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { + return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByRealIP)) +} + +func KeyByIP(r *http.Request) (string, error) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + return canonicalizeIP(ip), nil +} + +func KeyByRealIP(r *http.Request) (string, error) { + var ip string + + if tcip := r.Header.Get("True-Client-IP"); tcip != "" { + ip = tcip + } else if xrip := r.Header.Get("X-Real-IP"); xrip != "" { + ip = xrip + } else if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + i := strings.Index(xff, ", ") + if i == -1 { + i = len(xff) + } + ip = xff[:i] + } else { + var err error + ip, _, err = net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + } + + return canonicalizeIP(ip), nil +} + +func KeyByEndpoint(r *http.Request) (string, error) { + return r.URL.Path, nil +} + +func WithKeyFuncs(keyFuncs ...KeyFunc) Option { + return func(rl *rateLimiter) { + if len(keyFuncs) > 0 { + rl.keyFn = composedKeyFunc(keyFuncs...) + } + } +} + +func WithKeyByIP() Option { + return WithKeyFuncs(KeyByIP) +} + +func WithKeyByRealIP() Option { + return WithKeyFuncs(KeyByRealIP) +} + +func WithLimitHandler(h http.HandlerFunc) Option { + return func(rl *rateLimiter) { + rl.onRequestLimit = h + } +} + +func WithLimitCounter(c LimitCounter) Option { + return func(rl *rateLimiter) { + rl.limitCounter = c + } +} + +func WithNoop() Option { + return func(rl *rateLimiter) {} +} + +func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc { + return func(r *http.Request) (string, error) { + var key strings.Builder + for i := 0; i < len(keyFuncs); i++ { + k, err := keyFuncs[i](r) + if err != nil { + return "", err + } + key.WriteString(k) + key.WriteRune(':') + } + return key.String(), nil + } +} + +// canonicalizeIP returns a form of ip suitable for comparison to other IPs. +// For IPv4 addresses, this is simply the whole string. +// For IPv6 addresses, this is the /64 prefix. +func canonicalizeIP(ip string) string { + isIPv6 := false + // This is how net.ParseIP decides if an address is IPv6 + // https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704 + for i := 0; !isIPv6 && i < len(ip); i++ { + switch ip[i] { + case '.': + // IPv4 + return ip + case ':': + // IPv6 + isIPv6 = true + break + } + } + if !isIPv6 { + // Not an IP address at all + return ip + } + + ipv6 := net.ParseIP(ip) + if ipv6 == nil { + return ip + } + + return ipv6.Mask(net.CIDRMask(64, 128)).String() +} diff --git a/rule/internal/rate/rate_test.go b/rule/internal/rate/rate_test.go new file mode 100644 index 0000000..197886b --- /dev/null +++ b/rule/internal/rate/rate_test.go @@ -0,0 +1,59 @@ +package rate + +import "testing" + +func Test_canonicalizeIP(t *testing.T) { + tests := []struct { + name string + ip string + want string + }{ + { + name: "IPv4 unchanged", + ip: "1.2.3.4", + want: "1.2.3.4", + }, + { + name: "bad IP unchanged", + ip: "not an IP", + want: "not an IP", + }, + { + name: "bad IPv6 unchanged", + ip: "not:an:IP", + want: "not:an:IP", + }, + { + name: "empty string unchanged", + ip: "", + want: "", + }, + { + name: "IPv6 test 1", + ip: "2001:DB8::21f:5bff:febf:ce22:8a2e", + want: "2001:db8:0:21f::", + }, + { + name: "IPv6 test 2", + ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + want: "2001:db8:85a3::", + }, + { + name: "IPv6 test 3", + ip: "fe80::1ff:fe23:4567:890a", + want: "fe80::", + }, + { + name: "IPv6 test 4", + ip: "f:f:f:f:f:f:f:f", + want: "f:f:f:f::", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := canonicalizeIP(tt.ip); got != tt.want { + t.Errorf("canonicalizeIP() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/rule/middleware.go b/rule/middleware.go new file mode 100644 index 0000000..b67bee7 --- /dev/null +++ b/rule/middleware.go @@ -0,0 +1,25 @@ +package rule + +import ( + "encoding/json" + "net/http" +) + +type Middleware interface { + Process(next http.Handler) http.Handler + IsCachable() bool + UpdateMiddelware(newImpl Middleware) Middleware +} + +type BuilderFunc func(raw json.RawMessage) (Middleware, error) +type RegisterFunc func(map[string]BuilderFunc) error + +var Empty = []Middleware{} + +func Chain(next http.Handler, rules ...Middleware) http.Handler { + for i := len(rules) - 1; i >= 0; i-- { + next = rules[i].Process(next) + } + + return next +} diff --git a/rule/path.go b/rule/path.go new file mode 100644 index 0000000..bd482a6 --- /dev/null +++ b/rule/path.go @@ -0,0 +1,126 @@ +package rule + +import ( + "encoding/json" + "net/http" + "strings" +) + +const AppendPathName = "AppendPath" + +type AppendPath struct { + Begin string `json:"begin"` + End string `json:"end"` +} + +var _ Middleware = (*AppendPath)(nil) + +func (a *AppendPath) IsCachable() bool { + return false +} + +func (a *AppendPath) UpdateMiddelware(newImpl Middleware) Middleware { + return nil +} + +func (a *AppendPath) Process(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var bs strings.Builder + + bs.WriteString(a.Begin) + bs.WriteString(r.URL.Path) + bs.WriteString(a.End) + + r.URL.Path = bs.String() + + next.ServeHTTP(w, r) + }) +} + +func NewAppendPath(begin, end string) struct { + Type string `json:"type"` + Args any `json:"args"` +} { + return struct { + Type string `json:"type"` + Args any `json:"args"` + }{ + Type: AppendPathName, + Args: AppendPath{ + Begin: begin, + End: end, + }, + } +} + +func RegisterAppendPath() RegisterFunc { + return func(m map[string]BuilderFunc) error { + m[AppendPathName] = func(raw json.RawMessage) (Middleware, error) { + AppendPath := &AppendPath{} + err := json.Unmarshal(raw, AppendPath) + if err != nil { + return nil, err + } + return AppendPath, nil + } + + return nil + } +} + +const ReplacePathName = "ReplacePath" + +type ReplacePath struct { + Search string `json:"search"` + Replace string `json:"replace"` + Times int `json:"times"` +} + +var _ Middleware = (*ReplacePath)(nil) + +func (p *ReplacePath) IsCachable() bool { + return false +} + +func (p *ReplacePath) UpdateMiddelware(newImpl Middleware) Middleware { + return nil +} + +func (p *ReplacePath) Process(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.Replace(r.URL.Path, p.Search, p.Replace, p.Times) + next.ServeHTTP(w, r) + }) +} + +func NewReplacePath(search string, replace string, times int) struct { + Type string `json:"type"` + Args any `json:"args"` +} { + return struct { + Type string `json:"type"` + Args any `json:"args"` + }{ + Type: ReplacePathName, + Args: ReplacePath{ + Search: search, + Replace: replace, + Times: times, + }, + } +} + +func RegisterReplacePath() RegisterFunc { + return func(m map[string]BuilderFunc) error { + m[ReplacePathName] = func(raw json.RawMessage) (Middleware, error) { + ReplacePath := &ReplacePath{} + err := json.Unmarshal(raw, ReplacePath) + if err != nil { + return nil, err + } + return ReplacePath, nil + } + + return nil + } +} diff --git a/rule/ratelimiter.go b/rule/ratelimiter.go new file mode 100644 index 0000000..1780eeb --- /dev/null +++ b/rule/ratelimiter.go @@ -0,0 +1,125 @@ +package rule + +import ( + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "ella.to/baker/rule/internal/rate" +) + +type WindowDuration struct { + time.Duration +} + +// MarshalJSON implements the json.Marshaler interface for WindowDuration. +func (d WindowDuration) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"%s"`, d.String())), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface for WindowDuration. +func (d *WindowDuration) UnmarshalJSON(data []byte) error { + if len(data) < 2 { + d.Duration = 0 + return nil + } + + duration, err := time.ParseDuration(string(data[1 : len(data)-1])) + if err != nil { + return err + } + + d.Duration = duration + return nil +} + +type RateLimiter struct { + RequestLimit int `json:"request_limit"` + WindowDuration WindowDuration `json:"window_duration"` + middle func(next http.Handler) http.Handler +} + +var _ Middleware = (*RateLimiter)(nil) + +func (r *RateLimiter) IsCachable() bool { + return true +} + +func (r *RateLimiter) UpdateMiddelware(newImpl Middleware) Middleware { + if newImpl == nil { + slog.Debug( + "initializing for the first time", + "type", "RateLimiter", + "request_limit", r.RequestLimit, + "window_duration", r.WindowDuration.Duration, + ) + + r.middle = rate.LimitByIP(r.RequestLimit, r.WindowDuration.Duration) + return r + } + + newR, ok := newImpl.(*RateLimiter) + if !ok { + slog.Error("failed to update middleware", "type", "RateLimiter") + return r + } + + if r.RequestLimit == newR.RequestLimit && + r.WindowDuration == newR.WindowDuration && + r.middle != nil { + return r + } + + slog.Debug( + "updating middleware", + "type", "RateLimiter", + "request_limit", newR.RequestLimit, + "window_duration", newR.WindowDuration.Duration, + ) + + r.RequestLimit = newR.RequestLimit + r.WindowDuration = newR.WindowDuration + + r.middle = rate.LimitByIP(r.RequestLimit, r.WindowDuration.Duration) + + return r +} + +func (r *RateLimiter) Process(next http.Handler) http.Handler { + return r.middle(next) +} + +func NewRateLimiter(requestLimit int, windowDuration time.Duration) struct { + Type string `json:"type"` + Args any `json:"args"` +} { + return struct { + Type string `json:"type"` + Args any `json:"args"` + }{ + Type: "RateLimiter", + Args: RateLimiter{ + RequestLimit: requestLimit, + WindowDuration: WindowDuration{ + Duration: windowDuration, + }, + }, + } +} + +func RegisterRateLimiter() RegisterFunc { + return func(m map[string]BuilderFunc) error { + m["RateLimiter"] = func(raw json.RawMessage) (Middleware, error) { + rateLimiter := &RateLimiter{} + err := json.Unmarshal(raw, rateLimiter) + if err != nil { + return nil, err + } + return rateLimiter, nil + } + + return nil + } +} diff --git a/server.go b/server.go index 7bfca2d..b14f320 100644 --- a/server.go +++ b/server.go @@ -13,6 +13,7 @@ import ( "ella.to/baker/internal/httpclient" "ella.to/baker/internal/trie" + "ella.to/baker/rule" ) type containerInfo struct { @@ -25,8 +26,9 @@ type containerInfo struct { type Server struct { bufferSize int pingDuration time.Duration - containersMap map[string]*containerInfo // containerID -> containerInfo - domainsMap map[string]*trie.Node[[]*Container] // domain -> path -> containers + containersMap map[string]*containerInfo // containerID -> containerInfo + domainsMap map[string]*trie.Node[*Service] // domain -> path -> containers + rules map[string]rule.BuilderFunc runner *ActionRunner close chan struct{} } @@ -37,7 +39,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { domain := r.Host path := r.URL.Path - container := s.runner.Get(r.Context(), domain, path) + var container *Container + endpoint := &Endpoint{ + Domain: domain, + Path: path, + } + + container, endpoint = s.runner.Get(r.Context(), endpoint) if container == nil { http.NotFound(w, r) return @@ -57,7 +65,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { }, } - proxy.ServeHTTP(w, r) + middlewares, err := s.getMiddlewares(endpoint) + if err != nil { + slog.Error("failed to get middlewares", "error", err) + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + rule.Chain(proxy, middlewares...).ServeHTTP(w, r) } func (s *Server) Close() { @@ -69,6 +84,30 @@ func (s *Server) RegisterDriver(fn func(Driver)) { fn(s.runner) } +func (s *Server) getMiddlewares(endpoint *Endpoint) ([]rule.Middleware, error) { + if len(endpoint.Rules) == 0 { + return rule.Empty, nil + } + + middlewares := make([]rule.Middleware, 0) + + for _, r := range endpoint.Rules { + builder, ok := s.rules[r.Type] + if !ok { + return nil, fmt.Errorf("failed to find rule builder for %s", r.Type) + } + + middleware, err := builder(r.Args) + if err != nil { + return nil, fmt.Errorf("failed to parse args for rule %s: %w", r.Type, err) + } + + middlewares = append(middlewares, middleware) + } + + return middlewares, nil +} + func (s *Server) pingContainers() { // make a copy of the containers map containers := make([]*containerInfo, 0, len(s.containersMap)) @@ -121,7 +160,7 @@ func (s *Server) pingContainers() { } for _, endpoint := range config.Endpoints { - s.runner.Update(c, endpoint.Domain, endpoint.Path) + s.runner.Update(c, &endpoint) } }(c, url, pingCount) } @@ -143,38 +182,41 @@ func (s *Server) addContainer(container *Container) { } } -func (s *Server) updateContainer(container *Container, domain, path string) { +func (s *Server) updateContainer(container *Container, endpoint *Endpoint) { cInfo, ok := s.containersMap[container.Id] - if ok && cInfo.domain == domain && cInfo.path == path { + if ok && cInfo.domain == endpoint.Domain && cInfo.path == endpoint.Path { // if the container is already in the correct domain and path, we don't need to do anything // we can just return to avoid unnecessary work return } - paths, ok := s.domainsMap[domain] + paths, ok := s.domainsMap[endpoint.Domain] if !ok { - paths = trie.New[[]*Container]() - s.domainsMap[domain] = paths + paths = trie.New[*Service]() + s.domainsMap[endpoint.Domain] = paths } - containers := paths.Get([]rune(path)) - if containers == nil { - containers = []*Container{container} + service := paths.Get([]rune(endpoint.Path)) + if service == nil { + service = &Service{ + Containers: []*Container{container}, + Endpoint: endpoint, + } } else { // we don't need to check if the container is already in the list, because we already checked that // in the beginning of this function - containers = append(containers, container) + service.Containers = append(service.Containers, container) } - paths.Put([]rune(path), containers) + paths.Put([]rune(endpoint.Path), service) // One thing to note that cInfo is not nil here // because we have intitalized it during the addContainer call // if it was nil, it should be a panic situation - cInfo.domain = domain - cInfo.path = path + cInfo.domain = endpoint.Domain + cInfo.path = endpoint.Path - slog.Debug("container updated", "container_id", container.Id, "domain", domain, "path", path) + slog.Debug("container updated", "container_id", container.Id, "domain", endpoint.Domain, "path", endpoint.Path) s.containersMap[container.Id] = cInfo } @@ -194,27 +236,27 @@ func (s *Server) removeContainer(container *Container) { return } - containers := paths.Get([]rune(containerInfo.path)) - if containers == nil { + service := paths.Get([]rune(containerInfo.path)) + if service == nil { return } - for i, c := range containers { + for i, c := range service.Containers { if c.Id != container.Id { continue } - containers = append(containers[:i], containers[i+1:]...) - if len(containers) == 0 { + service.Containers = append(service.Containers[:i], service.Containers[i+1:]...) + if len(service.Containers) == 0 { paths.Del([]rune(containerInfo.path)) } else { - paths.Put([]rune(containerInfo.path), containers) + paths.Put([]rune(containerInfo.path), service) } break } } -func (s *Server) getContainer(domain, path string) (container *Container) { +func (s *Server) getContainer(domain, path string) (container *Container, endpoint *Endpoint) { defer func() { if container != nil { slog.Debug("found container", "container_id", container.Id, "domain", domain, "path", path) @@ -225,19 +267,19 @@ func (s *Server) getContainer(domain, path string) (container *Container) { paths, ok := s.domainsMap[domain] if !ok { - return nil + return nil, nil } - containers := paths.Get([]rune(path)) - if len(containers) == 0 { - return nil + service := paths.Get([]rune(path)) + if service == nil || len(service.Containers) == 0 { + return nil, nil } // randomly select a container from the list // this is not the best way to do this, but it's good enough for now - pos := rand.Int31n(int32(len(containers))) + pos := rand.Int31n(int32(len(service.Containers))) - return containers[pos] + return service.Containers[pos], service.Endpoint } type serverOpt interface { @@ -267,7 +309,7 @@ func NewServer(opts ...serverOpt) *Server { bufferSize: 100, pingDuration: 10 * time.Second, containersMap: make(map[string]*containerInfo), - domainsMap: make(map[string]*trie.Node[[]*Container]), + domainsMap: make(map[string]*trie.Node[*Service]), close: make(chan struct{}), }