From aff9d139fa8a1f7a3698d322ea0fb0463122aef5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 13 Feb 2020 01:14:02 -0500 Subject: [PATCH] Switch to stateless compression with klauspost/compress --- compress_notjs.go | 7 +++++++ conn_notjs.go | 1 - conn_test.go | 2 ++ go.mod | 1 + go.sum | 2 ++ read.go | 6 +----- write.go | 49 +++++++++++++++-------------------------------- 7 files changed, 28 insertions(+), 40 deletions(-) diff --git a/compress_notjs.go b/compress_notjs.go index 6ab6e284..270a064a 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -125,6 +125,13 @@ func newSlidingWindow(n int) *slidingWindow { } } +func (w *slidingWindow) getBuf() []byte { + if w == nil { + return nil + } + return w.buf +} + func (w *slidingWindow) write(p []byte) { if len(p) >= cap(w.buf) { w.buf = w.buf[:cap(w.buf)] diff --git a/conn_notjs.go b/conn_notjs.go index 4d8762bf..05ef862b 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -141,7 +141,6 @@ func (c *Conn) close(err error) { c.writeFrameMu.Lock(context.Background()) putBufioWriter(c.bw) } - c.msgWriter.close() c.msgReader.close() if c.client { diff --git a/conn_test.go b/conn_test.go index e1e6c35c..9d924ca9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -351,6 +351,8 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) discardLoopErr := xsync.Go(func() error { + defer c.Close(websocket.StatusInternalError, "") + for { _, _, err := c.Read(ctx) if websocket.CloseStatus(err) == websocket.StatusNormalClosure { diff --git a/go.mod b/go.mod index cb372391..a10c7b1e 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/golang/protobuf v1.3.3 github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 + github.com/klauspost/compress v1.10.0 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/go.sum b/go.sum index 8cbc66ce..e4bbd62d 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= +github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/read.go b/read.go index a9c291d1..b6027450 100644 --- a/read.go +++ b/read.go @@ -91,11 +91,7 @@ func (mr *msgReader) resetFlate() { mr.dict = newSlidingWindow(32768) } - if mr.flateContextTakeover() { - mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) - } else { - mr.flateReader = getFlateReader(readerFunc(mr.read), nil) - } + mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.getBuf()) mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } diff --git a/write.go b/write.go index 9d4b670f..96bd7d24 100644 --- a/write.go +++ b/write.go @@ -4,14 +4,13 @@ package websocket import ( "bufio" - "compress/flate" "context" "crypto/rand" "encoding/binary" "io" - "sync" "time" + "github.com/klauspost/compress/flate" "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" @@ -51,16 +50,15 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { type msgWriter struct { c *Conn - mu *mu - writeMu sync.Mutex + mu *mu ctx context.Context opcode opcode closed bool flate bool - trimWriter *trimLastFourBytesWriter - flateWriter *flate.Writer + trimWriter *trimLastFourBytesWriter + dict *slidingWindow } func newMsgWriter(c *Conn) *msgWriter { @@ -72,16 +70,16 @@ func newMsgWriter(c *Conn) *msgWriter { } func (mw *msgWriter) ensureFlate() { + if mw.flateContextTakeover() && mw.dict == nil { + mw.dict = newSlidingWindow(8192) + } + if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } } - if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) - } - mw.flate = true } @@ -138,20 +136,10 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { return nil } -func (mw *msgWriter) returnFlateWriter() { - if mw.flateWriter != nil { - putFlateWriter(mw.flateWriter) - mw.flateWriter = nil - } -} - // Write writes the given bytes to the WebSocket connection. func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") - mw.writeMu.Lock() - defer mw.writeMu.Unlock() - if mw.closed { return 0, xerrors.New("cannot use closed writer") } @@ -165,7 +153,11 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { } if mw.flate { - return mw.flateWriter.Write(p) + err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.getBuf()) + if mw.flateContextTakeover() { + mw.dict.write(p) + } + return len(p), err } return mw.write(p) @@ -184,17 +176,14 @@ func (mw *msgWriter) write(p []byte) (int, error) { func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") - mw.writeMu.Lock() - defer mw.writeMu.Unlock() - if mw.closed { return xerrors.New("cannot use closed writer") } if mw.flate { - err = mw.flateWriter.Flush() + err = flate.StatelessDeflate(mw.trimWriter, nil, true, mw.dict.getBuf()) if err != nil { - return xerrors.Errorf("failed to flush flate writer: %w", err) + return xerrors.Errorf("failed to flush flate: %w", err) } } @@ -207,18 +196,10 @@ func (mw *msgWriter) Close() (err error) { return xerrors.Errorf("failed to write fin frame: %w", err) } - if mw.flate && !mw.flateContextTakeover() { - mw.returnFlateWriter() - } mw.mu.Unlock() return nil } -func (mw *msgWriter) close() { - mw.writeMu.Lock() - mw.returnFlateWriter() -} - func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel()