forked from mediocregopher/radix
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conn.go
354 lines (307 loc) · 9.49 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
package radix
import (
"bufio"
"crypto/tls"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/mediocregopher/radix/v3/resp"
)
// Conn is a Client wrapping a single network connection which synchronously
// reads/writes data using the redis resp protocol.
//
// A Conn can be used directly as a Client, but in general you probably want to
// use a *Pool instead.
type Conn interface {
// The Do method of a Conn is _not_ expected to be thread-safe with the
// other methods of Conn, and merely calls the Action's Run method with
// itself as the argument.
Client
// Encode and Decode may be called at the same time by two different
// go-routines, but each should only be called once at a time (i.e. two
// routines shouldn't call Encode at the same time, same with Decode).
//
// Encode and Decode should _not_ be called at the same time as Do.
//
// If either Encode or Decode encounter a net.Error the Conn will be
// automatically closed.
//
// Encode is expected to encode an entire resp message, not a partial one.
// In other words, when sending commands to redis, Encode should only be
// called once per command. Similarly, Decode is expected to decode an
// entire resp response.
Encode(resp.Marshaler) error
Decode(resp.Unmarshaler) error
// Returns the underlying network connection, as-is. Read, Write, and Close
// should not be called on the returned Conn.
NetConn() net.Conn
}
// ConnFunc is a function which returns an initialized, ready-to-be-used Conn.
// Functions like NewPool or NewCluster take in a ConnFunc in order to allow for
// things like calls to AUTH on each new connection, setting timeouts, custom
// Conn implementations, etc... See the package docs for more details.
type ConnFunc func(network, addr string) (Conn, error)
// DefaultConnFunc is a ConnFunc which will return a Conn for a redis instance
// using sane defaults.
var DefaultConnFunc = func(network, addr string) (Conn, error) {
return Dial(network, addr)
}
func wrapDefaultConnFunc(addr string) ConnFunc {
_, opts := parseRedisURL(addr)
return func(network, addr string) (Conn, error) {
return Dial(network, addr, opts...)
}
}
type connWrap struct {
net.Conn
brw *bufio.ReadWriter
}
// NewConn takes an existing net.Conn and wraps it to support the Conn interface
// of this package. The Read and Write methods on the original net.Conn should
// not be used after calling this method.
func NewConn(conn net.Conn) Conn {
return &connWrap{
Conn: conn,
brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
}
}
func (cw *connWrap) Do(a Action) error {
return a.Run(cw)
}
func (cw *connWrap) Encode(m resp.Marshaler) error {
if err := m.MarshalRESP(cw.brw); err != nil {
return err
}
return cw.brw.Flush()
}
func (cw *connWrap) Decode(u resp.Unmarshaler) error {
return u.UnmarshalRESP(cw.brw.Reader)
}
func (cw *connWrap) NetConn() net.Conn {
return cw.Conn
}
type dialOpts struct {
connectTimeout, readTimeout, writeTimeout time.Duration
authUser, authPass string
selectDB string
useTLSConfig bool
tlsConfig *tls.Config
}
// DialOpt is an optional behavior which can be applied to the Dial function to
// effect its behavior, or the behavior of the Conn it creates.
type DialOpt func(*dialOpts)
// DialConnectTimeout determines the timeout value to pass into net.DialTimeout
// when creating the connection. If not set then net.Dial is called instead.
func DialConnectTimeout(d time.Duration) DialOpt {
return func(do *dialOpts) {
do.connectTimeout = d
}
}
// DialReadTimeout determines the deadline to set when reading from a dialed
// connection. If not set then SetReadDeadline is never called.
func DialReadTimeout(d time.Duration) DialOpt {
return func(do *dialOpts) {
do.readTimeout = d
}
}
// DialWriteTimeout determines the deadline to set when writing to a dialed
// connection. If not set then SetWriteDeadline is never called.
func DialWriteTimeout(d time.Duration) DialOpt {
return func(do *dialOpts) {
do.writeTimeout = d
}
}
// DialTimeout is the equivalent to using DialConnectTimeout, DialReadTimeout,
// and DialWriteTimeout all with the same value.
func DialTimeout(d time.Duration) DialOpt {
return func(do *dialOpts) {
DialConnectTimeout(d)(do)
DialReadTimeout(d)(do)
DialWriteTimeout(d)(do)
}
}
const defaultAuthUser = "default"
// DialAuthPass will cause Dial to perform an AUTH command once the connection
// is created, using the given pass.
//
// If this is set and a redis URI is passed to Dial which also has a password
// set, this takes precedence.
//
// Using DialAuthPass is equivalent to calling DialAuthUser with user "default"
// and is kept for compatibility with older package versions.
func DialAuthPass(pass string) DialOpt {
return DialAuthUser(defaultAuthUser, pass)
}
// DialAuthUser will cause Dial to perform an AUTH command once the connection
// is created, using the given user and pass.
//
// If this is set and a redis URI is passed to Dial which also has a username
// and password set, this takes precedence.
func DialAuthUser(user, pass string) DialOpt {
return func(do *dialOpts) {
do.authUser = user
do.authPass = pass
}
}
// DialSelectDB will cause Dial to perform a SELECT command once the connection
// is created, using the given database index.
//
// If this is set and a redis URI is passed to Dial which also has a database
// index set, this takes precedence.
func DialSelectDB(db int) DialOpt {
return func(do *dialOpts) {
do.selectDB = strconv.Itoa(db)
}
}
// DialUseTLS will cause Dial to perform a TLS handshake using the provided
// config. If config is nil the config is interpreted as equivalent to the zero
// configuration. See https://golang.org/pkg/crypto/tls/#Config
func DialUseTLS(config *tls.Config) DialOpt {
return func(do *dialOpts) {
do.tlsConfig = config
do.useTLSConfig = true
}
}
type timeoutConn struct {
net.Conn
readTimeout, writeTimeout time.Duration
}
func (tc *timeoutConn) Read(b []byte) (int, error) {
if tc.readTimeout > 0 {
err := tc.Conn.SetReadDeadline(time.Now().Add(tc.readTimeout))
if err != nil {
return 0, err
}
}
return tc.Conn.Read(b)
}
func (tc *timeoutConn) Write(b []byte) (int, error) {
if tc.writeTimeout > 0 {
err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout))
if err != nil {
return 0, err
}
}
return tc.Conn.Write(b)
}
var defaultDialOpts = []DialOpt{
DialTimeout(10 * time.Second),
}
func parseRedisURL(urlStr string) (string, []DialOpt) {
// do a quick check before we bust out url.Parse, in case that is very
// unperformant
if !strings.HasPrefix(urlStr, "redis://") {
return urlStr, nil
}
u, err := url.Parse(urlStr)
if err != nil {
return urlStr, nil
}
q := u.Query()
username := defaultAuthUser
if n := u.User.Username(); n != "" {
username = n
} else if n := q.Get("username"); n != "" {
username = n
}
password := q.Get("password")
if p, ok := u.User.Password(); ok {
password = p
}
opts := []DialOpt{
DialAuthUser(username, password),
}
dbStr := q.Get("db")
if u.Path != "" && u.Path != "/" {
dbStr = u.Path[1:]
}
if dbStr, err := strconv.Atoi(dbStr); err == nil {
opts = append(opts, DialSelectDB(dbStr))
}
return u.Host, opts
}
// Dial is a ConnFunc which creates a Conn using net.Dial and NewConn. It takes
// in a number of options which can overwrite its default behavior as well.
//
// In place of a host:port address, Dial also accepts a URI, as per:
// https://www.iana.org/assignments/uri-schemes/prov/redis
// If the URI has an AUTH password or db specified Dial will attempt to perform
// the AUTH and/or SELECT as well.
//
// If either DialAuthPass or DialSelectDB is used it overwrites the associated
// value passed in by the URI.
//
// The default options Dial uses are:
//
// DialTimeout(10 * time.Second)
//
func Dial(network, addr string, opts ...DialOpt) (Conn, error) {
var do dialOpts
for _, opt := range defaultDialOpts {
opt(&do)
}
addr, addrOpts := parseRedisURL(addr)
for _, opt := range addrOpts {
opt(&do)
}
for _, opt := range opts {
opt(&do)
}
var netConn net.Conn
var err error
dialer := net.Dialer{}
if do.connectTimeout > 0 {
dialer.Timeout = do.connectTimeout
}
if do.useTLSConfig {
netConn, err = tls.DialWithDialer(&dialer, network, addr, do.tlsConfig)
} else {
netConn, err = dialer.Dial(network, addr)
}
if err != nil {
return nil, err
}
// If the netConn is a net.TCPConn (or some wrapper for it) and so can have
// keepalive enabled, do so with a sane (though slightly aggressive)
// default.
{
type keepaliveConn interface {
SetKeepAlive(bool) error
SetKeepAlivePeriod(time.Duration) error
}
if kaConn, ok := netConn.(keepaliveConn); ok {
if err = kaConn.SetKeepAlive(true); err != nil {
netConn.Close()
return nil, err
} else if err = kaConn.SetKeepAlivePeriod(10 * time.Second); err != nil {
netConn.Close()
return nil, err
}
}
}
conn := NewConn(&timeoutConn{
readTimeout: do.readTimeout,
writeTimeout: do.writeTimeout,
Conn: netConn,
})
if do.authUser != "" && do.authUser != defaultAuthUser {
if err := conn.Do(Cmd(nil, "AUTH", do.authUser, do.authPass)); err != nil {
conn.Close()
return nil, err
}
} else if do.authPass != "" {
if err := conn.Do(Cmd(nil, "AUTH", do.authPass)); err != nil {
conn.Close()
return nil, err
}
}
if do.selectDB != "" {
if err := conn.Do(Cmd(nil, "SELECT", do.selectDB)); err != nil {
conn.Close()
return nil, err
}
}
return conn, nil
}