limiter.go (3320B)
1 package limiter 2 3 import ( 4 "context" 5 "golang.org/x/time/rate" 6 "time" 7 ) 8 9 type TimedRateLimiter struct { 10 // periodic forgetting of identifiers that have been seen & assigned a rate limiter to prevent bloat over time 11 timers map[string]*time.Timer 12 // buckets of access tokens, refreshing over time 13 limiters map[string]*rate.Limiter 14 // routes that are rate limited 15 routes map[string]bool 16 refreshPeriod time.Duration 17 timeToRemember time.Duration 18 burst int 19 } 20 21 func NewTimedRateLimiter(limitedRoutes []string, refresh, remember time.Duration) *TimedRateLimiter { 22 rl := TimedRateLimiter{} 23 rl.timers = make(map[string]*time.Timer) 24 rl.limiters = make(map[string]*rate.Limiter) 25 rl.routes = make(map[string]bool) 26 for _, route := range limitedRoutes { 27 rl.routes[route] = true 28 } 29 rl.refreshPeriod = refresh 30 rl.timeToRemember = remember 31 rl.burst = 15 /* default value, use rl.SetBurstAllowance to change */ 32 return &rl 33 } 34 35 // amount of accesses allowed ~concurrently, before needing to wait for a rl.refreshPeriod 36 func (rl *TimedRateLimiter) SetBurstAllowance(burst int) { 37 if burst >= 1 { 38 rl.burst = burst 39 } 40 } 41 42 // find out if resource access is allowed or not: calling consumes a rate limit token 43 func (rl *TimedRateLimiter) IsLimited(identifier, route string) bool { 44 // route isn't rate limited 45 if _, exists := rl.routes[route]; !exists { 46 return false 47 } 48 // route is designated to be rate limited, try the limiter to see if we can access it 49 ret := !rl.access(identifier) 50 return ret 51 } 52 53 func (rl *TimedRateLimiter) BlockUntilAllowed(identifier, route string, ctx context.Context) error { 54 // route isn't rate limited 55 if _, exists := rl.routes[route]; !exists { 56 return nil 57 } 58 limiter := rl.getLimiter(identifier) 59 err := limiter.Wait(ctx) 60 if err != nil { 61 return err 62 } 63 return nil 64 } 65 66 func (rl *TimedRateLimiter) getLimiter(identifier string) *rate.Limiter { 67 // limiter doesn't yet exist for this identifier 68 if _, exists := rl.limiters[identifier]; !exists { 69 // create a rate limit for it 70 rl.createRateLimit(identifier) 71 // remember this identifier (remote ip) for rl.timeToRemember before forgetting 72 rl.rememberIdentifier(identifier) 73 } 74 limiter := rl.limiters[identifier] 75 return limiter 76 } 77 78 // returns true if identifier currently allowed to access the resource 79 func (rl *TimedRateLimiter) access(identifier string) bool { 80 limiter := rl.getLimiter(identifier) 81 // consumes one token from the rate limiter bucket 82 allowed := limiter.Allow() 83 return allowed 84 } 85 86 func (rl *TimedRateLimiter) createRateLimit(identifier string) { 87 accessRate := rate.Every(rl.refreshPeriod) 88 limit := rate.NewLimiter(accessRate, rl.burst) 89 rl.limiters[identifier] = limit 90 } 91 92 func (rl *TimedRateLimiter) rememberIdentifier(identifier string) { 93 // timer already exists; refresh it 94 if timer, exists := rl.timers[identifier]; exists { 95 timer.Reset(rl.timeToRemember) 96 return 97 } 98 // new timer 99 timer := time.AfterFunc(rl.timeToRemember, func() { 100 rl.forgetLimiter(identifier) 101 }) 102 // map timer to its identifier 103 rl.timers[identifier] = timer 104 } 105 106 // forget the rate limiter associated for this identifier (to prevent memory growth over time) 107 func (rl *TimedRateLimiter) forgetLimiter(identifier string) { 108 if _, exists := rl.limiters[identifier]; exists { 109 delete(rl.limiters, identifier) 110 } 111 }