forked from gnolang/gno
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththrottle.go
110 lines (86 loc) · 2.18 KB
/
throttle.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package main
import (
"context"
"errors"
"net/netip"
"sync"
"time"
"golang.org/x/time/rate"
)
// Throttling parameters shared by the faucet throttler.
const (
	// maxRequestsPerMinute is the per-minute request budget; it is used to
	// derive defaultRateLimitInterval below.
	maxRequestsPerMinute = 5

	// defaultCleanTimeout is three minutes.
	// NOTE(review): presumably passed as the cleanupInterval argument of
	// newIPThrottler — the call site is not visible here, confirm.
	defaultCleanTimeout = 3 * time.Minute

	// defaultRateLimitInterval is one minute split into maxRequestsPerMinute
	// equal slots (12s with the current values).
	// NOTE(review): presumably passed as the rateLimitInterval argument of
	// newIPThrottler — confirm against the caller.
	defaultRateLimitInterval = time.Minute / maxRequestsPerMinute
)

// errInvalidNumberOfRequests is returned by registerNewRequest when the
// address's limiter refuses the request.
var errInvalidNumberOfRequests = errors.New("invalid number of requests")
// client is the throttling state kept for one IP address.
type client struct {
	// limiter is the address's token-bucket limiter, consulted on each request.
	limiter *rate.Limiter
	// seen is the timestamp maintained by registerNewRequest; runCleanup drops
	// the entry once it is at least cleanupInterval old.
	seen time.Time
}
// requestMap associates each client IP address with its throttling state.
type requestMap map[netip.Addr]*client

// iterate calls cb once per entry, in Go's unspecified map iteration order.
//
// It performs no locking of its own (NOT thread safe); runCleanup invokes it
// while holding the ipThrottler lock. Because deleting map entries during a
// range loop is permitted in Go, cb may remove the entry it receives.
func (r requestMap) iterate(cb func(key netip.Addr, value *client)) {
	for addr, state := range r {
		cb(addr, state)
	}
}
// ipThrottler limits how often each IP address may make requests.
//
// The embedded sync.Mutex guards requestMap. Both registerNewRequest and the
// cleanup loop in runCleanup take it before touching the map.
type ipThrottler struct {
	// cleanupInterval is both the cleanup ticker period and the idle age at
	// which an entry is evicted.
	cleanupInterval time.Duration
	// rateLimitInterval is the refill period given to rate.Every for every
	// newly created per-address limiter.
	rateLimitInterval time.Duration
	// requestMap holds the per-address state; guarded by the mutex.
	requestMap requestMap

	sync.Mutex
}
// newIPThrottler builds an ipThrottler with an empty request map.
//
// rateLimitInterval is the token refill period used for each address's
// limiter. cleanupInterval controls how often, and after how much idle time,
// addresses are evicted. The background cleanup is not running until start is
// called.
func newIPThrottler(rateLimitInterval, cleanupInterval time.Duration) *ipThrottler {
	t := new(ipThrottler)
	t.rateLimitInterval = rateLimitInterval
	t.cleanupInterval = cleanupInterval
	t.requestMap = requestMap{}

	return t
}
// start launches the cleanup loop (runCleanup) on its own goroutine and returns
// immediately. The goroutine exits once ctx is cancelled.
func (st *ipThrottler) start(ctx context.Context) {
	go func() {
		st.runCleanup(ctx)
	}()
}
// runCleanup is the throttler's eviction loop. Every st.cleanupInterval it
// removes each address whose client was last seen at least st.cleanupInterval
// ago. It blocks until ctx is cancelled.
//
// Because the tick period equals the idle threshold, an idle entry can linger
// for just under two cleanup intervals before it is removed.
//
// NOTE(review): time.NewTicker panics on a non-positive duration, and
// newIPThrottler does not validate cleanupInterval, so a positive value must
// be supplied by the caller.
func (st *ipThrottler) runCleanup(ctx context.Context) {
	ticker := time.NewTicker(st.cleanupInterval)
	// Release the ticker when the loop exits on cancellation; previously it
	// was never stopped.
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			st.Lock()

			// Use one reference time for the whole sweep.
			now := time.Now()

			// Clean up stale requests; deleting the current entry while ranging
			// over a map is permitted in Go.
			st.requestMap.iterate(func(ip netip.Addr, c *client) {
				// Keep clients seen within the last cleanup interval
				if now.Sub(c.seen) < st.cleanupInterval {
					return
				}

				delete(st.requestMap, ip)
			})

			st.Unlock()
		}
	}
}
// registerNewRequest records a request from ip and reports whether that
// address's limiter allows it.
//
// On the first request from an address, a limiter is created lazily. It
// refills one token every st.rateLimitInterval, with a burst capacity of
// maxRequestsPerMinute.
//
// It returns errInvalidNumberOfRequests when no token is available, and nil
// otherwise.
func (st *ipThrottler) registerNewRequest(ip netip.Addr) error {
	st.Lock()
	defer st.Unlock()

	now := time.Now()

	// Get the client associated with the address, creating it on first sight
	c := st.requestMap[ip]
	if c == nil {
		c = &client{
			// Reuse the shared constant instead of repeating the literal 5
			limiter: rate.NewLimiter(rate.Every(st.rateLimitInterval), maxRequestsPerMinute),
		}
		st.requestMap[ip] = c
	}

	// Refresh the last-seen time on every request, including rejected ones.
	// Otherwise the cleanup loop could evict an address that is actively being
	// throttled, and that address would get a brand-new, full limiter.
	c.seen = now

	// Check if the IP exceeded the request count
	if !c.limiter.Allow() {
		return errInvalidNumberOfRequests
	}

	return nil
}