diff --git a/Gopkg.lock b/Gopkg.lock index a4d877e..1d73817 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -103,6 +103,12 @@ revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" version = "v1.1.4" +[[projects]] + branch = "master" + name = "github.com/troian/goring" + packages = ["."] + revision = "f23b2d237abc4603ebeb2509c2cf1907debdf83f" + [[projects]] branch = "master" name = "github.com/troian/omap" @@ -154,6 +160,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "ff53b5f9fcbd74c1a6d7008101ed8b2f5145a7ffc9c6fc7e77ee659993d4a1f7" + inputs-digest = "79d21ebc8b01ced112de19100ca600001e9e62e84ceddc6efc2fe21118ac8a63" solver-name = "gps-cdcl" solver-version = 1 diff --git a/buffer/buffer.go b/buffer/buffer.go deleted file mode 100644 index fda2683..0000000 --- a/buffer/buffer.go +++ /dev/null @@ -1,665 +0,0 @@ -// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bufio" - "errors" - "fmt" - "io" - "sync" - "sync/atomic" -) - -var ( - bufCNT int64 -) - -var ( - // ErrInsufficientData buffer has insufficient data - ErrInsufficientData = errors.New("buffer has insufficient data") - - // ErrInsufficientSpace buffer has insufficient space - ErrInsufficientSpace = errors.New("buffer has insufficient space") - - // ErrNotReady buffer is not ready yet - ErrNotReady = errors.New("buffer is not ready") -) - -const ( - // DefaultBufferSize buffer size created by default - DefaultBufferSize = 1024 * 256 - // DefaultReadBlockSize default read block size - DefaultReadBlockSize = 1500 - // DefaultWriteBlockSize default write block size - DefaultWriteBlockSize = 1500 -) - -type sequence struct { - // The current position of the producer or consumer - cursor int64 - - // The previous known position of the consumer (if producer) or producer (if consumer) - gate int64 - - // These are fillers to pad the cache line, which is generally 64 bytes - //p2 int64 - //p3 int64 - //p4 int64 - //p5 int64 - //p6 int64 - //p7 int64 -} - -func newSequence() *sequence { - return &sequence{} -} - -func (b *sequence) get() int64 { - return atomic.LoadInt64(&b.cursor) -} - -func (b *sequence) set(seq int64) { - atomic.StoreInt64(&b.cursor, seq) -} - -// Type of buffer -// align atomic values to prevent panics on 32 bits macnines -// see https://github.com/golang/go/issues/5278 -type Type struct { - id int64 - size int64 - mask int64 - done int64 - - buf []byte - tmp []byte - ExternalBuf []byte - - pSeq *sequence - cSeq *sequence - pCond *sync.Cond - cCond *sync.Cond -} - -// New buffer -func New(size int64) (*Type, error) { - if size < 0 { - return nil, bufio.ErrNegativeCount - } - - if size == 0 { - size = DefaultBufferSize - } - - if !powerOfTwo64(size) { - return nil, fmt.Errorf("Size must be power of two. Try %d", roundUpPowerOfTwo64(size)) - } - - if size < 2*DefaultReadBlockSize { - return nil, fmt.Errorf("Size must at least be %d. Try %d", 2*DefaultReadBlockSize, 2*DefaultReadBlockSize) - } - - return &Type{ - id: atomic.AddInt64(&bufCNT, 1), - ExternalBuf: make([]byte, size), - buf: make([]byte, size), - size: size, - mask: size - 1, - pSeq: newSequence(), - cSeq: newSequence(), - pCond: sync.NewCond(new(sync.Mutex)), - cCond: sync.NewCond(new(sync.Mutex)), - }, nil -} - -// ID of buffer -func (b *Type) ID() int64 { - return b.id -} - -// Close buffer -func (b *Type) Close() error { - atomic.StoreInt64(&b.done, 1) - - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - - b.pCond.L.Lock() - b.cCond.Broadcast() - b.pCond.L.Unlock() - - return nil -} - -// Len of data -func (b *Type) Len() int { - cpos := b.cSeq.get() - ppos := b.pSeq.get() - return int(ppos - cpos) -} - -// Size of buffer -func (b *Type) Size() int64 { - return b.size -} - -// ReadFrom from reader -func (b *Type) ReadFrom(r io.Reader) (int64, error) { - total := int64(0) - - for { - if b.isDone() { - return total, io.EOF - } - - start, cnt, err := b.waitForWriteSpace(DefaultReadBlockSize) - if err != nil { - return 0, err - } - - pStart := start & b.mask - pEnd := pStart + int64(cnt) - if pEnd > b.size { - pEnd = b.size - } - - n, err := r.Read(b.buf[pStart:pEnd]) - if n > 0 { - total += int64(n) - if _, err = b.WriteCommit(n); err != nil { - return total, err - } - } - - if err != nil { - return total, err - } - } -} - -// WriteTo to writer -func (b *Type) WriteTo(w io.Writer) (int64, error) { - total := int64(0) - - for { - if b.isDone() { - return total, io.EOF - } - - p, err := b.ReadPeek(DefaultWriteBlockSize) - // There's some data, let's process it first - if len(p) > 0 { - var n int - n, err = w.Write(p) - total += int64(n) - - if err != nil { - return total, err - } - - _, err = b.ReadCommit(n) - if err != nil { - return total, err - } - } - - if err != nil { - if err != ErrInsufficientData { - return total, err - } - } - } -} - -// Read data -func (b *Type) Read(p []byte) (int, error) { - if b.isDone() && b.Len() == 0 { - return 0, io.EOF - } - - pl := int64(len(p)) - - for { - cPos := b.cSeq.get() - pPos := b.pSeq.get() - cIndex := cPos & b.mask - - // If consumer position is at least len(p) less than producer position, that means - // we have enough data to fill p. There are two scenarios that could happen: - // 1. cIndex + len(p) < buffer size, in this case, we can just copy() data from - // buffer to p, and copy will just copy enough to fill p and stop. - // The number of bytes copied will be len(p). - // 2. cIndex + len(p) > buffer size, this means the data will wrap around to the - // the beginning of the buffer. In thise case, we can also just copy data from - // buffer to p, and copy will just copy until the end of the buffer and stop. - // The number of bytes will NOT be len(p) but less than that. - if cPos+pl < pPos { - n := copy(p, b.buf[cIndex:]) - - b.cSeq.set(cPos + int64(n)) - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - - return n, nil - } - - // If we got here, that means there's not len(p) data available, but there might - // still be data. - - // If cPos < pPos, that means there's at least pPos-cPos bytes to read. Let's just - // send that back for now. - if cPos < pPos { - // n bytes available - avail := pPos - cPos - - // bytes copied - var n int - - // if cIndex+n < size, that means we can copy all n bytes into p. - // No wrapping in this case. - if cIndex+avail < b.size { - n = copy(p, b.buf[cIndex:cIndex+avail]) - } else { - // If cIndex+n >= size, that means we can copy to the end of buffer - n = copy(p, b.buf[cIndex:]) - } - - b.cSeq.set(cPos + int64(n)) - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - return n, nil - } - - // If we got here, that means cPos >= pPos, which means there's no data available. - // If so, let's wait... - - b.cCond.L.Lock() - for pPos = b.pSeq.get(); cPos >= pPos; pPos = b.pSeq.get() { - if b.isDone() { - b.cCond.L.Unlock() - return 0, io.EOF - } - - //b.cWait++ - b.cCond.Wait() - } - b.cCond.L.Unlock() - } -} - -// Write message -func (b *Type) Write(p []byte) (int, error) { - if b.isDone() { - return 0, io.EOF - } - - start, _, err := b.waitForWriteSpace(len(p)) - if err != nil { - return 0, err - } - - // If we are here that means we now have enough space to write the full p. - // Let's copy from p into this.buf, starting at position ppos&this.mask. - total := ringCopy(b.buf, p, start&b.mask) - - b.pSeq.set(start + int64(len(p))) - b.cCond.L.Lock() - b.cCond.Broadcast() - b.cCond.L.Unlock() - - return total, nil -} - -// ReadPeek Description below is copied completely from bufio.Peek() -// http://golang.org/pkg/bufio/#Reader.Peek -// Peek returns the next n bytes without advancing the reader. The bytes stop being valid -// at the next read call. If Peek returns fewer than n bytes, it also returns an error -// explaining why the read is short. The error is bufio.ErrBufferFull if n is larger than -// b's buffer size. -// If there's not enough data to peek, error is ErrBufferInsufficientData. -// If n < 0, error is bufio.ErrNegativeCount -func (b *Type) ReadPeek(n int) ([]byte, error) { - if int64(n) > b.size { - return nil, bufio.ErrBufferFull - } - - if n < 0 { - return nil, bufio.ErrNegativeCount - } - - cPos := b.cSeq.get() - pPos := b.pSeq.get() - - // If there's no data, then let's wait until there is some data - b.cCond.L.Lock() - for ; cPos >= pPos; pPos = b.pSeq.get() { - if b.isDone() { - b.cCond.L.Unlock() - return nil, io.EOF - } - - //b.cWait++ - b.cCond.Wait() - } - b.cCond.L.Unlock() - - // m = the number of bytes available. If m is more than what's requested (n), - // then we make m = n, basically peek max n bytes - m := pPos - cPos - err := error(nil) - - if m >= int64(n) { - m = int64(n) - } else { - err = ErrInsufficientData - } - - // There's data to peek. The size of the data could be <= n. - if cPos+m <= pPos { - cindex := cPos & b.mask - - // If cindex (index relative to buffer) + n is more than buffer size, that means - // the data wrapped - if cindex+m > b.size { - // reset the tmp buffer - b.tmp = b.tmp[0:0] - - l := len(b.buf[cindex:]) - b.tmp = append(b.tmp, b.buf[cindex:]...) - b.tmp = append(b.tmp, b.buf[0:m-int64(l)]...) - return b.tmp, err - } - - return b.buf[cindex : cindex+m], err - } - - return nil, ErrInsufficientData -} - -// ReadWait waits for for n bytes to be ready. If there's not enough data, then it will -// wait until there's enough. This differs from ReadPeek or Readin that Peek will -// return whatever is available and won't wait for full count. -func (b *Type) ReadWait(n int) ([]byte, error) { - if int64(n) > b.size { - return nil, bufio.ErrBufferFull - } - - if n < 0 { - return nil, bufio.ErrNegativeCount - } - - cPos := b.cSeq.get() - pPos := b.pSeq.get() - - // This is the magic read-to position. The producer position must be equal or - // greater than the next position we read to. - next := cPos + int64(n) - - // If there's no data, then let's wait until there is some data - b.cCond.L.Lock() - for ; next > pPos; pPos = b.pSeq.get() { - if b.isDone() { - b.cCond.L.Unlock() - return nil, io.EOF - } - - b.cCond.Wait() - } - b.cCond.L.Unlock() - - //if b.isDone() { - // return nil, io.EOF - //} - - // If we are here that means we have at least n bytes of data available. - cIndex := cPos & b.mask - - // If cIndex (index relative to buffer) + n is more than buffer size, that means - // the data wrapped - if cIndex+int64(n) > b.size { - // reset the tmp buffer - b.tmp = b.tmp[0:0] - - l := len(b.buf[cIndex:]) - b.tmp = append(b.tmp, b.buf[cIndex:]...) - b.tmp = append(b.tmp, b.buf[0:n-l]...) - return b.tmp[:n], nil - } - - return b.buf[cIndex : cIndex+int64(n)], nil -} - -// ReadCommit Commit moves the cursor forward by n bytes. It behaves like Read() except it doesn't -// return any data. If there's enough data, then the cursor will be moved forward and -// n will be returned. If there's not enough data, then the cursor will move forward -// as much as possible, then return the number of positions (bytes) moved. -func (b *Type) ReadCommit(n int) (int, error) { - if int64(n) > b.size { - return 0, bufio.ErrBufferFull - } - - if n < 0 { - return 0, bufio.ErrNegativeCount - } - - cPos := b.cSeq.get() - pPos := b.pSeq.get() - - // If consumer position is at least n less than producer position, that means - // we have enough data to fill p. There are two scenarios that could happen: - // 1. cindex + n < buffer size, in this case, we can just copy() data from - // buffer to p, and copy will just copy enough to fill p and stop. - // The number of bytes copied will be len(p). - // 2. cindex + n > buffer size, this means the data will wrap around to the - // the beginning of the buffer. In thise case, we can also just copy data from - // buffer to p, and copy will just copy until the end of the buffer and stop. - // The number of bytes will NOT be len(p) but less than that. - if cPos+int64(n) <= pPos { - b.cSeq.set(cPos + int64(n)) - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - return n, nil - } - - return 0, ErrInsufficientData -} - -// WriteWait waits for n bytes to be available in the buffer and then returns -// 1. the slice pointing to the location in the buffer to be filled -// 2. a boolean indicating whether the bytes available wraps around the ring -// 3. any errors encountered. If there's error then other return values are invalid -func (b *Type) WriteWait(n int) ([]byte, bool, error) { - start, cnt, err := b.waitForWriteSpace(n) - if err != nil { - return nil, false, err - } - - pStart := start & b.mask - if pStart+int64(cnt) > b.size { - return b.buf[pStart:], true, nil - } - - return b.buf[pStart : pStart+int64(cnt)], false, nil -} - -// WriteCommit write with commit -func (b *Type) WriteCommit(n int) (int, error) { - start, cnt, err := b.waitForWriteSpace(n) - if err != nil { - return 0, err - } - - // If we are here then there's enough bytes to commit - b.pSeq.set(start + int64(cnt)) - - b.cCond.L.Lock() - b.cCond.Broadcast() - b.cCond.L.Unlock() - - return cnt, nil -} - -// Send to -func (b *Type) Send(from [][]byte) (int, error) { - defer func() { - if int64(len(b.ExternalBuf)) > b.size { - b.ExternalBuf = make([]byte, b.size) - } - }() - - var total int - - for _, s := range from { - remaining := len(s) - offset := 0 - for remaining > 0 { - toWrite := remaining - if toWrite > int(b.Size()) { - toWrite = int(b.Size()) - } - - var wrote int - var err error - - if wrote, err = b.Write(s[offset : offset+toWrite]); err != nil { - return 0, err - } - - remaining -= wrote - offset += wrote - } - total += len(s) - } - - return total, nil -} - -func (b *Type) waitForWriteSpace(n int) (int64, int, error) { - if b.isDone() { - return 0, 0, io.EOF - } - - // The current producer position, remember it's a forever inreasing int64, - // NOT the position relative to the buffer - pPos := b.pSeq.get() - - // The next producer position we will get to if we write len(p) - next := pPos + int64(n) - - // For the producer, gate is the previous consumer sequence. - gate := b.pSeq.gate - - wrap := next - b.size - - // If wrap point is greater than gate, that means the consumer hasn't read - // some of the data in the buffer, and if we read in additional data and put - // into the buffer, we would overwrite some of the unread data. It means we - // cannot do anything until the customers have passed it. So we wait... - // - // Let's say size = 16, block = 4, pPos = 0, gate = 0 - // then next = 4 (0+4), and wrap = -12 (4-16) - // _______________________________________________________________________ - // | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | - // ----------------------------------------------------------------------- - // ^ ^ - // pPos, next - // gate - // - // So wrap (-12) > gate (0) = false, and gate (0) > pPos (0) = false also, - // so we move on (no waiting) - // - // Now if we get to pPos = 14, gate = 12, - // then next = 18 (4+14) and wrap = 2 (18-16) - // - // So wrap (2) > gate (12) = false, and gate (12) > pPos (14) = false aos, - // so we move on again - // - // Now let's say we have pPos = 14, gate = 0 still (nothing read), - // then next = 18 (4+14) and wrap = 2 (18-16) - // - // So wrap (2) > gate (0) = true, which means we have to wait because if we - // put data into the slice to the wrap point, it would overwrite the 2 bytes - // that are currently unread. - // - // Another scenario, let's say pPos = 100, gate = 80, - // then next = 104 (100+4) and wrap = 88 (104-16) - // - // So wrap (88) > gate (80) = true, which means we have to wait because if we - // put data into the slice to the wrap point, it would overwrite the 8 bytes - // that are currently unread. - // - if wrap > gate || gate > pPos { - var cPos int64 - b.pCond.L.Lock() - for cPos = b.cSeq.get(); wrap > cPos; cPos = b.cSeq.get() { - if b.isDone() { - return 0, 0, io.EOF - } - - //b.pWait++ - b.pCond.Wait() - } - - b.pSeq.gate = cPos - b.pCond.L.Unlock() - } - - return pPos, n, nil -} - -func (b *Type) isDone() bool { - return atomic.LoadInt64(&b.done) == 1 -} - -func ringCopy(dst, src []byte, start int64) int { - n := len(src) - - var i int - var l int - - for n > 0 { - l = copy(dst[start:], src[i:]) - i += l - n -= l - - if n > 0 { - start = 0 - } - } - - return i -} - -func powerOfTwo64(n int64) bool { - return n != 0 && (n&(n-1)) == 0 -} - -func roundUpPowerOfTwo64(n int64) int64 { - n-- - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n |= n >> 32 - n++ - - return n -} diff --git a/buffer/buffer_test.go b/buffer/buffer_test.go deleted file mode 100644 index 75aa7e5..0000000 --- a/buffer/buffer_test.go +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "io" - "testing" - "time" - - "bufio" - - "github.com/stretchr/testify/require" -) - -func TestBufferSequence(t *testing.T) { - seq := newSequence() - - seq.set(100) - require.Equal(t, int64(100), seq.get()) - - seq.set(20000) - require.Equal(t, int64(20000), seq.get()) -} - -func TestBufferReadFrom(t *testing.T) { - testFillBuffer(t, 144, 16384) - testFillBuffer(t, 2048, 16384) - testFillBuffer(t, 3072, 16384) -} - -func TestBufferReadBytes(t *testing.T) { - buf := testFillBuffer(t, 2048, 16384) - - testReadBytes(t, buf) -} - -func TestBufferCommitBytes(t *testing.T) { - buf := testFillBuffer(t, 2048, 16384) - - testCommit(t, buf) -} - -func TestBufferConsumerProducerRead(t *testing.T) { - buf, err := New(16384) - - require.NoError(t, err) - - testRead(t, buf) -} - -func TestBufferConsumerProducerWriteTo(t *testing.T) { - buf, err := New(16384) - - require.NoError(t, err) - - testWriteTo(t, buf) -} - -func TestBufferConsumerProducerPeekCommit(t *testing.T) { - buf, err := New(16384) - - require.NoError(t, err) - - testPeekCommit(t, buf) -} - -func TestBufferPeek(t *testing.T) { - buf := testFillBuffer(t, 2048, 16384) - - peekBuffer(t, buf, 100) - peekBuffer(t, buf, 1000) -} - -func TestBufferNew(t *testing.T) { - _, err := New(-1) - require.EqualError(t, bufio.ErrNegativeCount, err.Error()) - - _, err = New(863) - require.Error(t, err) - - _, err = New(1024) - require.Error(t, err) - - _, err = New(39666) - require.Error(t, err) -} - -func TestBufferID(t *testing.T) { - buf, err := New(0) - require.NoError(t, err) - - require.NotEqual(t, 0, buf.ID()) -} - -func TestBufferSize(t *testing.T) { - buf, err := New(0) - require.NoError(t, err) - - require.NotEqual(t, 0, buf.Size()) -} - -func TestBufferClosed(t *testing.T) { - buf, err := New(0) - require.NoError(t, err) - - p := make([]byte, 1024) - for i := range p { - p[i] = 'a' - } - - _, err = buf.ReadFrom(bytes.NewBuffer(p)) - require.EqualError(t, io.EOF, err.Error()) - - go func() { - buf.Close() // nolint: errcheck - }() -} - -func BenchmarkBufferConsumerProducerRead(b *testing.B) { - buf, _ := New(0) - benchmarkRead(b, buf) -} - -func testFillBuffer(t *testing.T, bufsize, ringsize int64) *Type { - buf, err := New(ringsize) - - require.NoError(t, err) - - fillBuffer(t, buf, bufsize) - - require.Equal(t, int(bufsize), buf.Len()) - - return buf -} - -func fillBuffer(t *testing.T, buf *Type, bufsize int64) { - p := make([]byte, bufsize) - for i := range p { - p[i] = 'a' - } - - n, err := buf.ReadFrom(bytes.NewBuffer(p)) - - require.Equal(t, bufsize, n) - require.Equal(t, err, io.EOF) -} - -func peekBuffer(t *testing.T, buf *Type, n int) { - pkbuf, err := buf.ReadPeek(n) - - require.NoError(t, err) - require.Equal(t, n, len(pkbuf)) - - for _, b := range pkbuf { - require.Equal(t, byte('a'), b) - } -} - -func testPeekCommit(t *testing.T, buf *Type) { - n := 20000 - - go func(n int64) { - fillBuffer(t, buf, n) - }(int64(n)) - - i := 0 - - for n > 0 { - pkbuf, _ := buf.ReadPeek(1024) - l, err := buf.ReadCommit(len(pkbuf)) - - require.NoError(t, err) - - n -= l - i += l - } -} - -func testWriteTo(t *testing.T, buf *Type) { - n := int64(20000) - - go func(n int64) { - fillBuffer(t, buf, n) - time.Sleep(time.Millisecond * 100) - buf.Close() // nolint: errcheck - }(n) - - m, err := buf.WriteTo(bytes.NewBuffer(make([]byte, n))) - - require.Equal(t, io.EOF, err) - require.Equal(t, int64(20000), m) -} - -func testRead(t *testing.T, buf *Type) { - n := int64(20000) - - go func(n int64) { - fillBuffer(t, buf, n) - }(n) - - p := make([]byte, n) - i := 0 - - for n > 0 { - l, err := buf.Read(p[i:]) - - require.NoError(t, err) - - n -= int64(l) - i += l - } -} - -func testCommit(t *testing.T, buf *Type) { - n, err := buf.ReadCommit(256) - - require.NoError(t, err) - require.Equal(t, 256, n) - - _, err = buf.ReadCommit(2048) - - require.Equal(t, ErrInsufficientData, err) -} - -func testReadBytes(t *testing.T, buf *Type) { - p := make([]byte, 256) - n, err := buf.Read(p) - - require.NoError(t, err) - require.Equal(t, 256, n) - - p2 := make([]byte, 4096) - n, err = buf.Read(p2) - - require.NoError(t, err) - require.Equal(t, 2048-256, n) -} - -func benchmarkRead(b *testing.B, buf *Type) { - n := int64(b.N) - - go func(n int64) { - p := make([]byte, n) - buf.ReadFrom(bytes.NewBuffer(p)) // nolint: errcheck - }(n) - - p := make([]byte, n) - i := 0 - - for n > 0 { - l, _ := buf.Read(p[i:]) - - n -= int64(l) - i += l - } -} diff --git a/connection/net.go b/connection/net.go index 5f6c21b..b772098 100644 --- a/connection/net.go +++ b/connection/net.go @@ -9,7 +9,7 @@ import ( "errors" "sync" - "github.com/troian/surgemq/buffer" + "github.com/troian/goring" "github.com/troian/surgemq/configuration" "github.com/troian/surgemq/packet" "github.com/troian/surgemq/systree" @@ -41,7 +41,7 @@ type netConfig struct { on onProcess // Conn is network connection - conn io.Closer + conn net.Conn // PacketsMetric interface to metric packets packetsMetric systree.PacketsMetric @@ -56,15 +56,34 @@ type netConfig struct { protoVersion packet.ProtocolVersion } +type keepAlive struct { + period time.Duration + conn net.Conn + timer *time.Timer +} + +func (k *keepAlive) Read(b []byte) (int, error) { + if k.period > 0 { + if !k.timer.Stop() { + <-k.timer.C + } + k.timer.Reset(k.period) + } + return k.conn.Read(b) +} + // netConn implementation of the connection type netConn struct { // Incoming data buffer. Bytes are read from the connection and put in here - in *buffer.Type - - // Outgoing data buffer. Bytes written here are in turn written out to the connection - out *buffer.Type + in *goring.Buffer config *netConfig + sendTicker *time.Timer + currLock sync.Mutex + currOutBuffer net.Buffers + outBuffers chan net.Buffers + keepAlive keepAlive + // Wait for the various goroutines to finish starting and stopping wg struct { routines struct { @@ -85,21 +104,10 @@ type netConn struct { // Quit signal for determining when this service should end. If channel is closed, then exit expireIn *time.Duration done chan struct{} - wmu sync.Mutex onStop types.Once will bool } -type netReader interface { - io.Reader - SetReadDeadline(t time.Time) error -} - -type timeoutReader struct { - d time.Duration - conn netReader -} - // newNet connection func newNet(config *netConfig) (f *netConn, err error) { defer func() { @@ -109,11 +117,15 @@ func newNet(config *netConfig) (f *netConn, err error) { }() f = &netConn{ - config: config, - done: make(chan struct{}), - will: true, + config: config, + done: make(chan struct{}), + will: true, + outBuffers: make(chan net.Buffers, 5), + sendTicker: time.NewTimer(5 * time.Millisecond), } + f.sendTicker.Stop() + f.log.prod = configuration.GetProdLogger().Named("session.conn." + config.id) f.log.dev = configuration.GetDevLogger().Named("session.conn." + config.id) @@ -121,15 +133,16 @@ func newNet(config *netConfig) (f *netConn, err error) { f.wg.conn.stopped.Add(1) // Create the incoming ring buffer - f.in, err = buffer.New(buffer.DefaultBufferSize) + f.in, err = goring.New(goring.DefaultBufferSize) if err != nil { return nil, err } - // Create the outgoing ring buffer - f.out, err = buffer.New(buffer.DefaultBufferSize) - if err != nil { - return nil, err + f.keepAlive.conn = f.config.conn + + if f.config.keepAlive > 0 { + f.keepAlive.period = time.Second * time.Duration(f.config.keepAlive) + f.keepAlive.period = f.keepAlive.period + (f.keepAlive.period / 2) } return f, nil @@ -139,7 +152,11 @@ func newNet(config *netConfig) (f *netConn, err error) { func (s *netConn) start() { defer s.wg.conn.started.Done() - s.wg.routines.stopped.Add(3) + if s.keepAlive.period > 0 { + s.wg.routines.stopped.Add(4) + } else { + s.wg.routines.stopped.Add(3) + } // these routines must start in specified order // and next proceed next one only when previous finished @@ -151,6 +168,12 @@ func (s *netConn) start() { go s.processIncoming() s.wg.routines.started.Wait() + if s.keepAlive.period > 0 { + s.wg.routines.started.Add(1) + go s.readTimeOutWorker() + s.wg.routines.started.Wait() + } + s.wg.routines.started.Add(1) go s.receiver() s.wg.routines.started.Wait() @@ -183,9 +206,7 @@ func (s *netConn) stop(reason *packet.ReasonCode) { s.log.prod.Error("close input buffer error", zap.String("ClientID", s.config.id), zap.Error(err)) } - if err := s.out.Close(); err != nil { - s.log.prod.Error("close output buffer error", zap.String("ClientID", s.config.id), zap.Error(err)) - } + s.sendTicker.Stop() // Wait for all the connection goroutines are finished s.wg.routines.stopped.Wait() @@ -196,14 +217,6 @@ func (s *netConn) stop(reason *packet.ReasonCode) { }) } -// Read -func (r timeoutReader) Read(b []byte) (int, error) { - if err := r.conn.SetReadDeadline(time.Now().Add(r.d)); err != nil { - return 0, err - } - return r.conn.Read(b) -} - // isDone func (s *netConn) isDone() bool { select { @@ -314,45 +327,61 @@ func (s *netConn) processIncoming() { } } +func (s *netConn) readTimeOutWorker() { + defer s.onRoutineReturn() + + s.keepAlive.timer = time.NewTimer(s.keepAlive.period) + s.wg.routines.started.Done() + + select { + case <-s.keepAlive.timer.C: + s.log.prod.Error("Keep alive timed out") + return + case <-s.done: + s.keepAlive.timer.Stop() + return + } +} + // receiver reads data from the network, and writes the data into the incoming buffer func (s *netConn) receiver() { defer s.onRoutineReturn() s.wg.routines.started.Done() - switch conn := s.config.conn.(type) { - case net.Conn: - keepAlive := time.Second * time.Duration(s.config.keepAlive) - r := timeoutReader{ - d: keepAlive + (keepAlive / 2), - conn: conn, - } - - for { - if _, err := s.in.ReadFrom(r); err != nil { - return - } - } - default: - s.log.prod.Error("Invalid connection type", zap.String("ClientID", s.config.id)) - } + s.in.ReadFrom(&s.keepAlive) // nolint: errcheck } // sender writes data from the outgoing buffer to the network func (s *netConn) sender() { defer s.onRoutineReturn() - s.wg.routines.started.Done() - switch conn := s.config.conn.(type) { - case net.Conn: - for { - if _, err := s.out.WriteTo(conn); err != nil { + for { + bufs := net.Buffers{} + select { + case <-s.sendTicker.C: + s.currLock.Lock() + s.outBuffers <- s.currOutBuffer + s.currOutBuffer = net.Buffers{} + s.currLock.Unlock() + case buf, ok := <-s.outBuffers: + s.sendTicker.Stop() + if !ok { + return + } + bufs = buf + case <-s.done: + s.sendTicker.Stop() + close(s.outBuffers) + return + } + + if len(bufs) > 0 { + if _, err := bufs.WriteTo(s.config.conn); err != nil { return } } - default: - s.log.prod.Error("Invalid connection type", zap.String("ClientID", s.config.id)) } } @@ -364,7 +393,7 @@ func (s *netConn) peekMessageSize() (packet.Type, int, error) { cnt := 2 if s.in == nil { - err = buffer.ErrNotReady + err = goring.ErrNotReady return 0, 0, err } @@ -420,7 +449,7 @@ func (s *netConn) readMessage(total int) (packet.Provider, int, error) { var msg packet.Provider if s.in == nil { - err = buffer.ErrNotReady + err = goring.ErrNotReady return nil, 0, err } @@ -445,7 +474,7 @@ func (s *netConn) readMessage(total int) (packet.Provider, int, error) { s.log.prod.Error("Incoming and outgoing length does not match", zap.Int("in", total), zap.Int("out", dTotal)) - return nil, 0, buffer.ErrNotReady + return nil, 0, goring.ErrNotReady } return msg, n, err @@ -454,26 +483,37 @@ func (s *netConn) readMessage(total int) (packet.Provider, int, error) { // WriteMessage writes a message to the outgoing buffer func (s *netConn) WriteMessage(msg packet.Provider, lastMessage bool) (int, error) { if s.isDone() { - return 0, buffer.ErrNotReady + return 0, goring.ErrNotReady } if lastMessage { close(s.done) } - defer s.wmu.Unlock() - s.wmu.Lock() + var total int + + expectedSize, err := msg.Size() + if err != nil { + return 0, err + } - if s.out == nil { - return 0, buffer.ErrNotReady + buf := make([]byte, expectedSize) + total, err = msg.Encode(buf) + if err != nil { + return 0, err } - var total int - var err error + s.currLock.Lock() + s.currOutBuffer = append(s.currOutBuffer, buf) + if len(s.currOutBuffer) == 1 { + s.sendTicker.Reset(1 * time.Millisecond) + } - if total, err = packet.WriteToBuffer(msg, s.out); err == nil { - s.config.packetsMetric.Sent(msg.Type()) + if len(s.currOutBuffer) == 10 { + s.outBuffers <- s.currOutBuffer + s.currOutBuffer = net.Buffers{} } + s.currLock.Unlock() return total, err } diff --git a/connection/netCallbacks.go b/connection/netCallbacks.go index 199c93b..5d4774d 100644 --- a/connection/netCallbacks.go +++ b/connection/netCallbacks.go @@ -29,9 +29,6 @@ func (s *Type) getState() *persistenceTypes.SessionMessages { outMessages := [][]byte{} unAckMessages := [][]byte{} - //messages := s.publisher.messages.GetAll() - - //for _, v := range messages { var next *list.Element for elem := s.publisher.messages.Front(); elem != nil; elem = next { next = elem.Next() @@ -258,7 +255,7 @@ func (s *Type) onSubscribe(msg *packet.Subscribe) error { t := kv.Key.(string) ops := kv.Value.(packet.SubscriptionOptions) - reason := packet.CodeSuccess + reason := packet.CodeSuccess // nolint: ineffassign //authorized := true // TODO: check permissions here diff --git a/examples/surgemq/surgemq.go b/examples/surgemq/surgemq.go index 12a4bcb..6632f2e 100644 --- a/examples/surgemq/surgemq.go +++ b/examples/surgemq/surgemq.go @@ -29,6 +29,7 @@ import ( "go.uber.org/zap" _ "net/http/pprof" + "runtime" _ "runtime/debug" ) @@ -44,7 +45,7 @@ func main() { var err error logger.Info("Starting application") - + logger.Info("Allocated cores", zap.Int("GOMAXPROCS", runtime.GOMAXPROCS(0))) viper.SetConfigName("config") viper.AddConfigPath("conf") viper.SetConfigType("json") @@ -119,17 +120,19 @@ func main() { logger.Error("Couldn't start listener", zap.Error(err)) } - configWs := transport.NewConfigWS( - &transport.Config{ - Port: 8080, - AuthManager: authMng, - }) - - if err = srv.ListenAndServe(configWs); err != nil { - logger.Error("Couldn't start listener", zap.Error(err)) - } + //configWs := transport.NewConfigWS( + // &transport.Config{ + // Port: 8080, + // AuthManager: authMng, + // }) + // + //if err = srv.ListenAndServe(configWs); err != nil { + // logger.Error("Couldn't start listener", zap.Error(err)) + //} go func() { + runtime.SetBlockProfileRate(1) + runtime.SetMutexProfileFraction(1) logger.Info(http.ListenAndServe("localhost:6061", nil).Error()) }() diff --git a/packet/connack_test.go b/packet/connack_test.go index dc22e06..50d8a60 100644 --- a/packet/connack_test.go +++ b/packet/connack_test.go @@ -18,7 +18,7 @@ import ( "testing" "github.com/stretchr/testify/require" - "github.com/troian/surgemq/buffer" + "github.com/troian/goring" ) func TestConnAckMessageFields(t *testing.T) { @@ -184,7 +184,7 @@ func TestConnAckEncodeEnsureSize(t *testing.T) { } func TestConnAckCodeWrite(t *testing.T) { - buf, err := buffer.New(16384) + buf, err := goring.New(16384) require.NoError(t, err) buf.ExternalBuf = make([]byte, 1) diff --git a/packet/routines.go b/packet/routines.go index 912f87a..673c76b 100644 --- a/packet/routines.go +++ b/packet/routines.go @@ -3,11 +3,11 @@ package packet import ( "encoding/binary" - "github.com/troian/surgemq/buffer" + "github.com/troian/goring" ) // WriteToBuffer encode and send message into ring buffer -func WriteToBuffer(msg Provider, to *buffer.Type) (int, error) { +func WriteToBuffer(msg Provider, to *goring.Buffer) (int, error) { expectedSize, err := msg.Size() if err != nil { return 0, err @@ -38,14 +38,10 @@ func ReadLPBytes(buf []byte) ([]byte, int, error) { n = int(binary.BigEndian.Uint16(buf)) total += 2 - if len(buf[total:]) < n { - return nil, total, ErrInsufficientDataSize - } - // Check for malformed length-prefixed field // if remaining space is less than length-prefixed size the packet seems to be broken if len(buf[total:]) < n { - return nil, total, ErrInvalidLength + return nil, total, ErrInsufficientDataSize } total += n diff --git a/routines/misc.go b/routines/misc.go index 31c8cd1..a443054 100644 --- a/routines/misc.go +++ b/routines/misc.go @@ -116,9 +116,3 @@ func WriteMessageBuffer(c io.Closer, b []byte) error { _, err := conn.Write(b) return err } - -// Copied from http://golang.org/src/pkg/net/timeout_test.go -//func isTimeout(err error) bool { -// e, ok := err.(net.Error) -// return ok && e.Timeout() -//} diff --git a/transport/base.go b/transport/base.go index 04133d8..605f773 100644 --- a/transport/base.go +++ b/transport/base.go @@ -115,6 +115,8 @@ func (c *baseConfig) handleConnection(conn conn) { } } } else { + // Disable read deadline. Will set it later if keep-alive interval is bigger than 0 + conn.SetReadDeadline(time.Time{}) // nolint: errcheck switch r := req.(type) { case *packet.Connect: m, _ := packet.NewMessage(req.Version(), packet.CONNACK)