-
Notifications
You must be signed in to change notification settings - Fork 12
/
buffer.go
579 lines (502 loc) · 14.1 KB
/
buffer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
package tds
// All packet and message encapsulation goes here.
// No protocol logic, only byte shuffling.
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync/atomic"
"time"
bin "github.com/thda/tds/binary"
)
// headerSize is the size in bytes of a netlib packet header: token (1),
// status (1), packet size (2), spid (2), packet number (1) and padding (1).
const headerSize = 1 + 1 + 2 + 2 + 1 + 1
// Packet header status values, tested against header.status with masks
// such as status&eom.
//
// NOTE(review): iota yields eom=1, cancelAck=2, cancel=3, so cancel is
// eom|cancelAck rather than a bit of its own: `status |= cancel` sets both
// bits, and `status&cancel == cancel` also holds for any status carrying
// eom and cancelAck. If these are meant to be independent header status
// flags, cancel probably needs its own bit — confirm against the TDS
// specification before changing the wire value.
const (
_ = iota
eom // end of message: set on the last packet of a message
cancelAck // cancel acknowledgement, checked on incoming packets by processCancel
cancel // set on outgoing packets by sendPkt while a cancel is pending
)
// header is the header for netlib packets which enclose all messages.
// It is headerSize bytes on the wire. read and write (de)serialize the fields
// in declaration order; buf uses its big-endian header encoder for this.
type header struct {
token packetType // packet type, e.g. normalPacket, replyPacket or cancelPacket
status uint8 // status values such as eom / cancelAck / cancel
packetSize uint16 // total packet length in bytes, header included
spid uint16 // not interpreted in this file
packetNo uint8 // not interpreted in this file
pad uint8 // not interpreted in this file
}
// read fills h from the next headerSize bytes decoded by e, in wire order.
// Individual decode errors are not checked field by field; the encoder's
// accumulated error is returned once all fields have been read.
func (h *header) read(e *bin.Encoder) error {
	h.token = packetType(e.ReadByte())
	// Operands are evaluated left to right, so fields are consumed in wire order.
	h.status, h.packetSize, h.spid = e.Uint8(), e.Uint16(), e.Uint16()
	h.packetNo, h.pad = e.Uint8(), e.Uint8()
	return e.Err()
}
// write encodes h through e in wire order (token, status, packet size, spid,
// packet number, padding) and returns the encoder's accumulated error.
func (h header) write(e *bin.Encoder) error {
	e.WriteByte(byte(h.token))
	e.WriteByte(h.status)
	e.WriteUint16(h.packetSize)
	e.WriteUint16(h.spid)
	for _, v := range [...]uint8{h.packetNo, h.pad} {
		e.WriteUint8(v)
	}
	return e.Err()
}
const (
	// maxMsgBufSize is the message buffer capacity, in bytes, above which
	// writeMsg drops the buffer instead of keeping it for reuse.
	maxMsgBufSize = 25000
	// defaultCancelTimeout is the default for buf.CancelTimeout: the number of
	// seconds to wait for the cancel to be sent.
	defaultCancelTimeout = 10
)
// buf reads and writes netlib packets with proper header and size.
// Outgoing data written through Write (or the packet encoder pe) is split into
// packets of at most PacketSize bytes; incoming packets are read one at a time
// into pb and served through Read.
// The cancellation fields (cancelCh, inCancel) are shared with the goroutine
// started by watchCancel; inCancel is accessed with sync/atomic in this file.
type buf struct {
rw io.ReadWriter // underlying connection
h header // packet header
pb bytes.Buffer // packet buffer
d [50]byte // discard buffer: scratch space used by skip
mb bytes.Buffer // message buffer, used to easily compute message length
me bin.Encoder // message encoder. Writes into mb so a message can be measured before it is sent
he bin.Encoder // header encoder. Reads from the network, writes to the packet buffer pb
// packet encoder. Reads/Writes go through this structure's Read/Write methods to split into TDS packets.
pe bin.Encoder
debug bool // not used in this file
PacketSize int // maximum outgoing packet size in bytes, header included (see Write)
// Timeouts/context variables
cancelCh chan error // channel on which cancel reports completion to processCancel
inCancel int32 // set to 1 if a cancel query is pending
WriteTimeout int // seconds; bounds send when no context is given (0 disables)
ReadTimeout int // seconds; bounds receive when no context is given (0 disables)
CancelTimeout int // seconds allowed for a cancel: connection deadline in cancel, wait in processCancel
defaultMessageMap map[token]messageReader // messages receive decodes for every state, looked up before state.msg
}
// newBuf allocates a buf around rw that splits outgoing data into packets of
// packetSize bytes. Message and packet encoders start out little-endian, the
// header encoder is big-endian, and CancelTimeout defaults to
// defaultCancelTimeout seconds.
func newBuf(packetSize int, rw io.ReadWriter) *buf {
	b := &buf{
		rw:            rw,
		PacketSize:    packetSize,
		CancelTimeout: defaultCancelTimeout,
		// capacity 1: cancel can deliver its result before processCancel receives it
		cancelCh: make(chan error, 1),
	}
	// Message encoder: writes into mb so a message's length is known before sending.
	b.me = bin.NewEncoder(&b.mb, binary.LittleEndian)
	// Packet encoder: goes through b.Read / b.Write, which handle packet splitting.
	b.pe = bin.NewEncoder(b, binary.LittleEndian)
	// Header encoder: reads headers straight from the connection and writes
	// them into the packet buffer.
	headerRW := struct {
		io.Reader
		io.Writer
	}{b.rw, &b.pb}
	b.he = bin.NewEncoder(&headerRW, binary.BigEndian)
	return b
}
// SetEndianness switches the byte order used by the packet encoder and the
// message encoder. The header encoder is not changed here.
func (b *buf) SetEndianness(endianness binary.ByteOrder) {
	for _, enc := range []*bin.Encoder{&b.pe, &b.me} {
		enc.SetEndianness(endianness)
	}
}
// SetCharset switches the character set used by the packet encoder and the
// message encoder. c is resolved with getEncoding. If the lookup fails, the
// encoders are left unchanged and the returned error wraps the lookup error,
// so callers can inspect it with errors.Is / errors.As.
func (b *buf) SetCharset(c string) error {
	e, err := getEncoding(c)
	if err != nil {
		return fmt.Errorf("netlib: could not find encoder for %s: %w", c, err)
	}
	b.pe.SetCharset(e)
	b.me.SetCharset(e)
	return nil
}
// initPkt starts a new outgoing packet of type t: it empties the packet
// buffer, clears the header status and serializes the header into the packet
// buffer (the header encoder writes to pb). It is called whenever the packet
// type changes, and after a packet is sent when more data follows.
func (b *buf) initPkt(t packetType) {
	b.pb.Reset()
	b.h.token = t
	b.h.status = 0
	// The encoder error is deliberately not checked here.
	_ = b.h.write(&b.he)
}
// readPkt reads the next tds packet from the connection: the header into b.h
// and the payload (packetSize minus headerSize bytes) into the packet buffer.
// Unless ignoreCan is set, a pending cancel is then processed and its result
// returned.
// A header announcing fewer than headerSize bytes is rejected. Previously this
// produced a negative copy length, which io.CopyN treats as "copy nothing",
// so a corrupt header was silently accepted as an empty packet.
func (b *buf) readPkt(ignoreCan bool) (err error) {
	b.pb.Reset()
	if err = b.h.read(&b.he); err != nil {
		return err
	}
	if b.h.packetSize < headerSize {
		return fmt.Errorf("netlib: invalid packet size %d, smaller than the %d-byte header", b.h.packetSize, headerSize)
	}
	if _, err = io.CopyN(&b.pb, b.rw, int64(b.h.packetSize)-headerSize); err != nil {
		return err
	}
	// check for cancel signal
	if !ignoreCan && b.cancelling() {
		err = b.processCancel()
	}
	return err
}
// sendPkt sends the current packet (header and payload held in pb) to the
// underlying writer with a single write.
// status is stored in the header status byte. Without the eom bit the packet
// is an intermediate one, and a fresh header for the same packet type is
// started with initPkt. With the eom bit it ends the message.
// While a cancel is pending (see cancelling), the packet goes out with status
// eom|cancel instead. Once the eom packet has been sent, the cancel
// acknowledgement is processed and processCancel's result is returned; in
// that case the write error is not returned.
//
// NOTE(review): the "discard packets after the cancel packet" branch below
// looks unreachable. After a non-eom send, initPkt resets b.h.status to 0 and
// clears the cancel marker set just before, so every later packet of a
// cancelled request is sent again with status eom|cancel. If the branch is
// revived, it must also reset pb; otherwise Write's flush loop would spin on a
// full buffer. Confirm the intended wire behaviour before changing it.
func (b *buf) sendPkt(status uint8) (err error) {
b.pb.Bytes()[1] = status
// discard packets until the last one when cancelling.
// When a cancel is spotted, we must first send
// a packet with cancel and eom bit set,
// then ignore all the next packets until the caller
// indicates the end of the conversation by calling
// sendPkt giving a status parameter with eom bit set.
if b.cancelling() {
if b.h.status&cancel == cancel &&
status&eom != eom {
return nil
}
b.h.status |= cancel
b.pb.Bytes()[1] = eom | cancel
}
// set the packet length (header included) in header bytes 2-3, big-endian
binary.BigEndian.PutUint16(b.pb.Bytes()[2:], uint16(b.pb.Len()))
// single call to write, needed for concurrent writes.
_, err = b.pb.WriteTo(b.rw)
// not the last packet, write header for next
if status&eom == 0 {
b.initPkt(b.h.token)
} else if b.cancelling() {
// last packet of a canceled request, process cancel ack
b.h.status = 0
return b.processCancel()
}
return err
}
// Read implements io.Reader on top of the packet stream.
// It copies payload from the current packet into buf. When the current packet
// holds fewer than len(buf) bytes, the available bytes are returned and the
// next packet is read (processing a pending cancel, see readPkt), so the rest
// is returned by the next call. Use io.ReadFull when an exact count is needed.
// A zero-length read while the current packet carries the eom status returns
// io.EOF.
func (b *buf) Read(buf []byte) (n int, err error) {
	n, err = io.ReadFull(&b.pb, buf)
	switch {
	case err == io.ErrUnexpectedEOF || err == io.EOF:
		// Current packet exhausted: fetch the next one. Report the n bytes
		// already copied into buf together with any error, as the io.Reader
		// contract requires, instead of claiming nothing was read.
		if err = b.readPkt(false); err != nil {
			return n, err
		}
	case n == 0 && b.h.status&eom != 0:
		// all data read
		return 0, io.EOF
	}
	return n, err
}
// Write implements io.Writer. It appends p to the current packet buffer. Each
// time the buffer reaches PacketSize bytes (header included) while data
// remains, it sends the packet with sendPkt(0), i.e. without the eom bit, and
// continues in a fresh packet. The last, possibly partial, packet stays
// buffered until a later sendPkt call.
// It returns the number of bytes of p consumed. It fails when PacketSize
// leaves no room for payload after the header; previously that case panicked
// on a negative slice bound or looped forever sending empty packets.
func (b *buf) Write(p []byte) (n int, err error) {
	if b.PacketSize <= headerSize {
		return 0, fmt.Errorf("netlib: packet size %d leaves no room for data after the %d-byte header", b.PacketSize, headerSize)
	}
	for {
		room := b.PacketSize - b.pb.Len()
		if room < 0 {
			room = 0
		}
		chunk := p
		if len(chunk) > room {
			chunk = chunk[:room]
		}
		w, werr := b.pb.Write(chunk)
		n += w
		if werr != nil {
			return n, werr
		}
		p = p[w:]
		if len(p) == 0 {
			return n, nil
		}
		// the current packet is full and data remains: flush it
		if err = b.sendPkt(0); err != nil {
			return n, err
		}
	}
}
// skip discards the next cnt bytes of the incoming stream, reading further
// packets as needed. Data is read through b in chunks of at most len(b.d)
// bytes into the scratch buffer b.d.
// A negative cnt returns an error instead of panicking on an invalid slice
// bound; callers such as processCancel compute the count by subtraction.
func (b *buf) skip(cnt int) (err error) {
	if cnt < 0 {
		return fmt.Errorf("netlib: cannot skip a negative number of bytes (%d)", cnt)
	}
	for cnt > 0 {
		chunk := len(b.d)
		if cnt < chunk {
			chunk = cnt
		}
		if _, err = io.ReadFull(b, b.d[:chunk]); err != nil {
			return err
		}
		cnt -= chunk
	}
	return nil
}
// peek returns the next byte of the incoming stream without consuming it.
// The byte is read through the packet encoder and then pushed back into the
// packet buffer, so the next read sees it again. The error reported is the
// packet encoder's; the result of UnreadByte is not checked.
func (b *buf) peek() (byte, error) {
	next := b.pe.ReadByte()
	readErr := b.pe.Err()
	_ = b.pb.UnreadByte()
	return next, readErr
}
// writeMsg serializes msg and appends it to the outgoing packet stream.
// The body is first encoded into the message buffer mb so its length is
// known. Then the token byte (unless the token is nonePacket), the length
// field (when msg.SizeLen() is 8, 16 or 32 bits) and the body are written
// through the packet encoder, which splits them into packets.
// This is used when the tds needs a length right after the token for
// non-fixed length messages.
//
// The length field is read back as an unsigned integer (see readMsg). A body
// too long for it is therefore rejected before anything is written, instead of
// sending a silently truncated length that would desynchronize the stream.
// mb is dropped after use, on every path, once its capacity exceeds
// maxMsgBufSize, so one large message does not keep its memory alive.
func (b *buf) writeMsg(msg messageReaderWriter) (err error) {
	b.mb.Reset()
	defer func() {
		if b.mb.Cap() > maxMsgBufSize {
			b.mb = bytes.Buffer{}
		}
	}()
	if err = msg.Write(&b.me); err != nil {
		return err
	}
	size := b.mb.Len()

	// largest body length the length field can carry (0: no length field)
	var maxSize uint64
	switch msg.SizeLen() {
	case 8:
		maxSize = 1<<8 - 1
	case 16:
		maxSize = 1<<16 - 1
	case 32:
		maxSize = 1<<32 - 1
	}
	if maxSize != 0 && uint64(size) > maxSize {
		return fmt.Errorf("netlib: message of %d bytes does not fit in a %d-bit length field", size, msg.SizeLen())
	}

	if msg.Token() != token(nonePacket) {
		b.pe.WriteByte(byte(msg.Token()))
	}
	// The signed writers emit the same bits as the unsigned values read back.
	switch msg.SizeLen() {
	case 8:
		b.pe.WriteInt8(int8(size))
	case 16:
		b.pe.WriteInt16(int16(size))
	case 32:
		b.pe.WriteInt32(int32(size))
	}
	if err = b.pe.Err(); err != nil {
		return err
	}
	// Write to packet buffer
	_, err = b.mb.WriteTo(b)
	return err
}
// readMsg decodes the next message from the connection into msg.
// When msg.SizeLen() is 8, 16 or 32, an unsigned length field of that width
// is read first. Some messages give no way to find the end of a field (or a
// series of fields) other than that length. For those (msg.LimitRead()), the
// packet encoder is limited to that many bytes while msg.Read runs, so it
// sees io.EOF at the end of the message. The limit is removed afterwards.
func (b *buf) readMsg(msg messageReader) (err error) {
	size := 0
	switch msg.SizeLen() {
	case 8:
		size = int(b.pe.Uint8())
	case 16:
		size = int(b.pe.Uint16())
	case 32:
		size = int(b.pe.Uint32())
	}
	if err = b.pe.Err(); err != nil {
		return err
	}
	if !msg.LimitRead() {
		return msg.Read(&b.pe)
	}
	b.pe.LimitRead(int64(size))
	defer func() {
		b.pe.UnlimitRead()
	}()
	return msg.Read(&b.pe)
}
// skipMsg skips a message without decoding it. The number of bytes to skip is
// msg.Size() when non-zero. Otherwise it is read from the stream as an
// unsigned length field of msg.SizeLen() bits (8, 16 or 32); any other width
// is an error.
// A failure while reading the length field is now returned, as readMsg does.
// Previously it was ignored and the message was treated as zero-length.
func (b *buf) skipMsg(msg messageReader) (err error) {
	var size int
	if fixed := msg.Size(); fixed != 0 {
		size = int(fixed)
	} else {
		switch msg.SizeLen() {
		default:
			return fmt.Errorf("netlib: unknown token size for %s message", msg)
		case 8:
			size = int(b.pe.Uint8())
		case 16:
			size = int(b.pe.Uint16())
		case 32:
			size = int(b.pe.Uint32())
		}
		if err = b.pe.Err(); err != nil {
			return err
		}
	}
	return b.skip(size)
}
// send writes msgs, in order, as a single request of packet type pt, and
// flushes the final packet with the eom status.
// If ctx is nil and WriteTimeout is positive, the send is bounded by a
// WriteTimeout-second context. When the context is done before the send
// finishes, watchCancel triggers a cancel of the request.
func (b *buf) send(ctx context.Context, pt packetType, msgs ...messageReaderWriter) (err error) {
	b.initPkt(pt)

	if ctx == nil && b.WriteTimeout > 0 {
		timeout := time.Duration(b.WriteTimeout) * time.Second
		var release context.CancelFunc
		ctx, release = context.WithTimeout(context.Background(), timeout)
		defer release()
	}
	if ctx != nil {
		if stopWatch := b.watchCancel(ctx, false); stopWatch != nil {
			defer stopWatch()
		}
	}

	for i := range msgs {
		if err = b.writeMsg(msgs[i]); err != nil {
			return err
		}
	}
	return b.sendPkt(eom)
}
// state is the netlib session state threaded through receive: the last token
// read, where to decode messages, what to run after each decoded message, and
// the error that stopped processing.
type state struct {
t token // last token read
handler func(t token) error // message handler to run after read
msg map[token]messageReader // messages to read into (looked up after buf.defaultMessageMap)
err error // error faced during read
// optional context watched by receive; when it is done, the pending request
// is cancelled. If nil, receive uses a ReadTimeout-based context instead
// (when ReadTimeout > 0).
ctx context.Context
}
// stateFn is one step of the receive state machine: it processes the next
// message using the given state and returns the step to run next, or nil to
// stop (buf.receive is such a function).
type stateFn func(*state) stateFn
// receive reads one message from the connection and updates s accordingly.
// The token is read into s.t. A message registered in b.defaultMessageMap, or
// else in s.msg, is decoded into it and s.handler is called with the token.
// Any other message is skipped. receive returns itself to process the next
// message, or nil with s.err set when reading, decoding or the handler fails.
//
// The read is watched by s.ctx when set. Otherwise, when ReadTimeout > 0, it
// is bounded by a fresh ReadTimeout-second context. That context is kept
// local: it used to be stored in s.ctx, and because it is cancelled when this
// call returns, the next receive call found an already-done context and
// immediately cancelled the request.
func (b *buf) receive(s *state) stateFn {
	ctx := s.ctx
	if ctx == nil && b.ReadTimeout > 0 {
		var cancelFunc context.CancelFunc
		ctx, cancelFunc = context.WithTimeout(context.Background(), time.Duration(b.ReadTimeout)*time.Second)
		defer cancelFunc()
	}
	// start Timeout watcher; deferred after cancelFunc, so it stops first
	if ctx != nil {
		if stop := b.watchCancel(ctx, true); stop != nil {
			defer stop()
		}
	}

	s.t = token(b.pe.ReadByte())
	if s.err = b.pe.Err(); s.err != nil {
		// we should not be at EOF here
		if s.err == io.EOF {
			s.err = fmt.Errorf("netlib: unexpected EOF while reading message")
		}
		return nil
	}
	// expecting reply here
	if b.h.token != normalPacket && b.h.token != replyPacket {
		s.err = fmt.Errorf("netlib: expected reply or normal token, got %s", b.h.token)
		return nil
	}

	// connection-wide messages take precedence over the caller's
	msg, ok := b.defaultMessageMap[s.t]
	if !ok {
		msg, ok = s.msg[s.t]
	}
	if !ok {
		// message not in message maps, skip it
		if s.err = b.skipMsg(emptyMsg{msg: newMsg(s.t)}); s.err != nil {
			return nil
		}
		return b.receive
	}
	if s.err = b.readMsg(msg); s.err != nil {
		return nil
	}
	if s.err = s.handler(s.t); s.err != nil {
		return nil
	}
	return b.receive
}
// watchCancel arranges for the current request to be cancelled if ctx is done
// before the returned stop function is called.
// If ctx can never be done (Done returns nil), nothing is started and nil is
// returned. Otherwise a goroutine waits for ctx; when it is done, the
// goroutine calls b.cancel(ctx.Err(), reading) and discards its error. The
// returned stop function must be called exactly once. It returns only once the
// goroutine is finished, including any cancel it had started.
func (b *buf) watchCancel(ctx context.Context, reading bool) func() {
	doneCh := ctx.Done()
	if doneCh == nil {
		return nil
	}
	stop := make(chan struct{})
	go func() {
		select {
		case <-doneCh:
			_ = b.cancel(ctx.Err(), reading)
			stop <- struct{}{}
		case <-stop:
		}
	}()
	return func() {
		// either collect the goroutine's post-cancel signal or tell it to quit
		select {
		case <-stop:
		case stop <- struct{}{}:
		}
	}
}
// cancel starts the cancellation of the current request. It runs on the
// watchCancel goroutine, concurrently with the goroutine reading or writing
// packets.
//
// Only a call that finds no cancel pending does anything; other calls return
// nil at once. That call sets inCancel. If rw is a net.Conn, it sets a
// CancelTimeout-second deadline on it, which is cleared again before
// returning. When reading is true, it sends a dedicated cancel packet. (While
// sending, sendPkt marks the outgoing packets itself.) cancelErr is always
// delivered on cancelCh when the call returns, for processCancel to report.
// The returned error is the error from setting the deadline or sending the
// cancel packet.
func (b *buf) cancel(cancelErr error, reading bool) (err error) {
	if !atomic.CompareAndSwapInt32(&b.inCancel, 0, 1) {
		// cancel already in progress
		return nil
	}
	// Deferred first so it runs last, after the deadline has been cleared.
	defer func() {
		b.cancelCh <- cancelErr
	}()
	// set deadline on the underlying conn to be sure to process on time
	if conn, ok := b.rw.(net.Conn); ok {
		defer conn.SetDeadline(time.Time{})
		if err = conn.SetDeadline(time.Now().Add(time.Duration(b.CancelTimeout) * time.Second)); err != nil {
			return err
		}
	}
	if reading {
		// A separate buffer keeps this goroutine away from b's packet buffer
		// and header. Its size comes from PacketSize rather than
		// b.h.packetSize, which the reading goroutine rewrites on every packet
		// header (a data race). The cancel buffer never calls Write, so the
		// value is not used for splitting.
		// Assumes PacketSize is not modified while a request is in flight.
		canBuf := newBuf(b.PacketSize, b.rw)
		canBuf.initPkt(cancelPacket)
		err = canBuf.sendPkt(1)
	}
	return err
}
// cancelling reports whether a cancel has been requested and not yet
// processed: inCancel is set by cancel and cleared by processCancel.
func (b *buf) cancelling() bool {
	pending := atomic.LoadInt32(&b.inCancel)
	return pending == 1
}
// processCancel completes a pending cancel. It waits up to CancelTimeout
// seconds for cancel to report on cancelCh, then reads packets until one
// carries the eom status. It checks that this last packet acknowledges the
// cancel, which the server does in one of two ways:
//   - a normal packet with the cancelAck status bit set;
//   - a reply packet whose last 9 bytes hold a done message (token 0xFD)
//     with the 0x0020 (done-cancel) bit set in its 2-byte status.
//
// On success it returns the error delivered by cancel, i.e. the context error
// that triggered the cancel. inCancel is cleared on return whatever the
// outcome. A reply packet too short to hold the done message is reported as
// an error; previously the negative skip length made skip panic.
func (b *buf) processCancel() (err error) {
	var cancelErr error
	defer atomic.StoreInt32(&b.inCancel, 0)
	// this will effectively block until cancel packet is sent
	select {
	case cancelErr = <-b.cancelCh:
	case <-time.After(time.Duration(b.CancelTimeout) * time.Second):
		return fmt.Errorf("netlib: timeout while processing cancel")
	}
	// read until last packet
	for b.h.status&eom == 0 {
		if err = b.readPkt(true); err != nil {
			return fmt.Errorf("netlib: error while reading cancel packet: %w", err)
		}
	}
	switch b.h.token {
	default:
		err = fmt.Errorf("netlib: unexpected token type %s while looking for cancel token", b.h.token)
	case normalPacket:
		err = cancelErr
		if b.h.status&cancelAck == 0 {
			err = errors.New("netlib: Timeout reached, yet the cancel was not acknowledged")
		}
	case replyPacket:
		// Skip to the done message expected in the last doneLen bytes, then
		// read its token byte and status.
		const doneLen = 9
		if b.pb.Len() < doneLen {
			return fmt.Errorf("netlib: cancel reply packet has only %d bytes, too short for a done message", b.pb.Len())
		}
		if err = b.skip(b.pb.Len() - doneLen); err == nil {
			err = cancelErr
			// find done token with cancel ack bit set
			if !(b.pe.ReadByte() == 0xFD && int(b.pe.Uint16())&0x0020 != 0) {
				err = errors.New("netlib: Timeout reached, yet the cancel was not acknowledged")
			}
		}
	}
	return err
}