-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathspeedbump.go
152 lines (129 loc) · 3.8 KB
/
speedbump.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Package speedbump provides a Redis-backed rate limiter.
package speedbump
import (
"strconv"
"time"
"gopkg.in/redis.v5"
)
// RateLimiter is a Redis-backed rate limiter. It keeps one counter per id in
// Redis, under the key returned by its RateHasher, and refuses attempts once
// that counter has reached max.
type RateLimiter struct {
	// redisClient is the client that will be used to talk to the Redis server.
	redisClient *redis.Client
	// hasher is used to generate keys for each counter and to set their
	// expiration time.
	hasher RateHasher
	// max defines the maximum number of attempts that can occur during a
	// period. Attempt returns false once a counter's value is >= max, and
	// Left reports max minus the recorded attempts (never below zero).
	max int64
}
// RateHasher is an object capable of generating a hash that uniquely
// identifies a counter to track the number of requests for an id over a
// certain time interval. The input of the Hash function can be any unique id,
// such as an IP address.
type RateHasher interface {
	// Hash returns the Redis key of the counter for id. Each RateLimiter
	// method calls Hash independently, so for Has, Attempted, Left and
	// Attempt to see the same counter, Hash must return the same key for an
	// id throughout a period.
	Hash(id string) string
	// Duration returns the duration of each period. RateLimiter passes it to
	// Redis EXPIRE on every successful Attempt to determine when each counter
	// key expires. Other libraries can also use it to generate messages that
	// estimate when the limit will expire.
	Duration() time.Duration
}
// NewLimiter builds a RateLimiter that keeps its counters in Redis through
// client, derives counter keys and their lifetimes from hasher, and permits
// at most max attempts per counter.
func NewLimiter(
	client *redis.Client,
	hasher RateHasher,
	max int64,
) *RateLimiter {
	limiter := new(RateLimiter)
	limiter.redisClient = client
	limiter.hasher = hasher
	limiter.max = max
	return limiter
}
// Has reports whether the counter key for id currently exists in Redis,
// i.e. whether the limiter has seen a request for id during the current
// period.
func (r *RateLimiter) Has(id string) (bool, error) {
	exists, err := r.redisClient.Exists(r.hasher.Hash(id)).Result()
	return exists, err
}
// Attempted returns the number of attempted requests for an id in the current
// period. Attempts refused by Attempt are not recorded, so the count does not
// grow past max because of refused requests. A missing counter (no attempts
// yet, or an expired period) yields 0.
//
// An error is returned if Redis fails or if the stored value is not a base-10
// int64; in that case the returned count is always 0.
func (r *RateLimiter) Attempted(id string) (int64, error) {
	hash := r.hasher.Hash(id)
	val, err := r.redisClient.Get(hash).Result()
	if err == redis.Nil {
		// Key does not exist. See: http://redis.io/commands/GET
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	count, err := strconv.ParseInt(val, 10, 64)
	if err != nil {
		// ParseInt may return a clamped value together with a range error;
		// don't hand that meaningless number back to the caller.
		return 0, err
	}
	return count, nil
}
// Left reports how many more requests id may make during the current period:
// max minus the attempts recorded so far, floored at zero. Errors from
// Attempted are propagated with a zero count.
func (r *RateLimiter) Left(id string) (int64, error) {
	used, err := r.Attempted(id)
	if err != nil {
		return 0, err
	}
	if remaining := r.max - used; remaining > 0 {
		return remaining, nil
	}
	return 0, nil
}
// Attempt attempts to perform a request for an id and returns whether it was
// successful or not. A successful attempt increments the id's counter and
// (re)sets its expiry to hasher.Duration(). Once the counter has reached max,
// further attempts return false without modifying Redis.
//
// The read/compare/increment sequence runs under WATCH on the counter key, so
// concurrent callers cannot push the counter past max. If another client
// modifies the key between our GET and EXEC, Redis aborts the transaction
// (redis.TxFailedErr) and the attempt is retried. If the retry cap is
// exhausted, Attempt returns false with redis.TxFailedErr.
//
// See: http://redis.io/commands/INCR
// See: http://redis.io/commands/INCR#pattern-rate-limiter-1
// See: https://redis.io/topics/transactions (optimistic locking using WATCH)
func (r *RateLimiter) Attempt(id string) (bool, error) {
	// Create hash from id
	hash := r.hasher.Hash(id)

	// Each aborted transaction means some other writer committed a change to
	// this key, so contention is self-limiting; the cap is only a safety net.
	const maxTxRetries = 100

	for i := 0; i < maxTxRetries; i++ {
		allowed := false
		err := r.redisClient.Watch(func(tx *redis.Tx) error {
			// Read the counter while the key is watched, so the comparison
			// below stays valid up to EXEC.
			val, err := tx.Get(hash).Result()
			switch {
			case err == redis.Nil:
				// Key does not exist. See: http://redis.io/commands/GET
			case err != nil:
				return err
			default:
				count, err := strconv.ParseInt(val, 10, 64)
				if err != nil {
					return err
				}
				if count >= r.max {
					// Limit reached: refuse without touching the key.
					return nil
				}
			}

			// Increment and expire key for hasher.Duration() atomically
			// (MULTI/EXEC). Expire is issued even when the key already
			// existed, so a key that expired between GET and EXEC is
			// recreated by INCR with a TTL rather than living forever. This
			// also pushes back the expiry on every successful attempt.
			_, err = tx.Pipelined(func(pipe *redis.Pipeline) error {
				if err := pipe.Incr(hash).Err(); err != nil {
					return err
				}
				return pipe.Expire(hash, r.hasher.Duration()).Err()
			})
			if err != nil {
				return err
			}
			allowed = true
			return nil
		}, hash)

		if err == redis.TxFailedErr {
			// The watched key changed under us; re-read and try again.
			continue
		}
		if err != nil {
			return false, err
		}
		return allowed, nil
	}
	return false, redis.TxFailedErr
}