diff --git a/conn.go b/conn.go index fe113c9e..e02ed0a4 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,7 @@ import ( "github.com/lxzan/gws/internal" "net" "sync" + "sync/atomic" "time" ) @@ -83,6 +84,45 @@ func (c *Conn) Listen() { } } +func (c *Conn) emitError(err error) { + if err == nil { + return + } + if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + c.handlerError(err, nil) + c.handler.OnError(c, err) + } +} + +func (c *Conn) handlerError(err error, buf *internal.Buffer) { + code := CloseNormalClosure + v, ok := err.(CloseCode) + if ok { + closeCode := v.Uint16() + if closeCode < 1000 || (closeCode >= 1016 && closeCode < 3000) { + code = CloseProtocolError + } else { + switch closeCode { + case 1004, 1005, 1006, 1014: + code = CloseProtocolError + default: + code = v + } + } + } + var content = code.Bytes() + if buf != nil { + content = append(content, buf.Bytes()...) + } else { + content = append(content, err.Error()...) + } + if len(content) > internal.Lv1 { + content = content[:internal.Lv1] + } + _ = c.writeMessage(OpcodeCloseConnection, content, true) + _ = c.conn.SetDeadline(time.Now()) +} + func (c *Conn) isCanceled() bool { select { case <-c.ctx.Done(): diff --git a/examples/bench/main.go b/examples/bench/main.go index 9234c039..6da7e3c5 100644 --- a/examples/bench/main.go +++ b/examples/bench/main.go @@ -50,7 +50,7 @@ func NewWebSocket() *WebSocket { } func (c *WebSocket) OnClose(socket *gws.Conn, message *gws.Message) { - fmt.Printf("onclose: code=%d, payload=%s", message.Code(), string(message.Bytes())) + fmt.Printf("onclose: code=%d, payload=%s\n", message.Code(), string(message.Bytes())) } type WebSocket struct{} diff --git a/frame.go b/frame.go index 4413a386..c00131df 100644 --- a/frame.go +++ b/frame.go @@ -39,13 +39,10 @@ func (c *Message) Bytes() []byte { } func payloadValid(opcode Opcode, buf *internal.Buffer) bool { - if buf.Len() == 0 { + if buf.Len() == 0 && !(opcode == OpcodeCloseConnection || opcode == OpcodeText) { return true } - if opcode == OpcodeCloseConnection || opcode == OpcodeText { - return utf8.Valid(buf.Bytes()) - } - return true + return utf8.Valid(buf.Bytes()) } func maskXOR(b []byte, key []byte) { diff --git a/reader.go b/reader.go index f7da7f5e..7caad9dd 100644 --- a/reader.go +++ b/reader.go @@ -189,7 +189,7 @@ func (c *Conn) readMessage() error { func (c *Conn) emitMessage(msg *Message, compressed bool) error { if atomic.LoadUint32(&c.closed) == 1 { - return nil + return CloseNormalClosure } if c.isCanceled() { return CloseServiceRestart @@ -222,8 +222,10 @@ func (c *Conn) emitMessage(msg *Message, compressed bool) error { c.handler.OnMessage(c, msg) case OpcodeCloseConnection: if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + code := msg.Code() c.handlerError(msg.Code(), msg.buf) c.handler.OnClose(c, msg) + return code } } return nil diff --git a/writer.go b/writer.go index f6317838..3a954dd1 100644 --- a/writer.go +++ b/writer.go @@ -3,7 +3,6 @@ package gws import ( "github.com/lxzan/gws/internal" "io" - "sync/atomic" "time" ) @@ -21,45 +20,6 @@ func writeN(writer io.Writer, content []byte, n int) error { return nil } -func (c *Conn) emitError(err error) { - if err == nil { - return - } - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - c.handlerError(err, nil) - c.handler.OnError(c, err) - } -} - -func (c *Conn) handlerError(err error, buf *internal.Buffer) { - code := CloseNormalClosure - v, ok := err.(CloseCode) - if ok { - closeCode := v.Uint16() - if closeCode < 1000 || (closeCode >= 1016 && closeCode < 3000) { - code = CloseProtocolError - } else { - switch closeCode { - case 1004, 1005, 1006, 1014: - code = CloseProtocolError - default: - code = v - } - } - } - var content = code.Bytes() - if buf != nil { - content = append(content, buf.Bytes()...) - } else { - content = append(content, err.Error()...) - } - if len(content) > internal.Lv1 { - content = content[:internal.Lv1] - } - _ = c.writeMessage(OpcodeCloseConnection, content, true) - _ = c.conn.SetDeadline(time.Now()) -} - // WriteClose write close frame // 发送关闭帧 func (c *Conn) WriteClose(code CloseCode, reason []byte) {