diff --git a/plugins/transport/websocket/README.md b/plugins/transport/websocket/README.md new file mode 100644 index 0000000..df30598 --- /dev/null +++ b/plugins/transport/websocket/README.md @@ -0,0 +1,18 @@ +# Websocket + +## 什么是WebSocket? + +WebSocket 协议主要为了解决基于 HTTP/1.x 的 Web 应用无法实现服务端向客户端主动推送的问题, 为了兼容现有的设施, WebSocket 协议使用与 HTTP 协议相同的端口, 并使用 HTTP Upgrade 机制来进行 WebSocket 握手, 当握手完成之后, 通信双方便可以按照 WebSocket 协议的方式进行交互 + +WebSocket 使用 TCP 作为传输层协议, 与 HTTP 类似, WebSocket 也支持在 TCP 上层引入 TLS 层, 以建立加密数据传输通道, 即 WebSocket over TLS, WebSocket 的 URI 与 HTTP URI 的结构类似, 对于使用 80 端口的 WebSocket over TCP, 其 URI 的一般形式为 `ws://host:port/path/query` 对于使用 443 端口的 WebSocket over TLS, 其 URI 的一般形式为 `wss://host:port/path/query` + +在 WebSocket 协议中, 帧 (frame) 是通信双方数据传输的基本单元, 与其它网络协议相同, frame 由 Header 和 Payload 两部分构成, frame 有多种类型, frame 的类型由其头部的 Opcode 字段 (将在下面讨论) 来指示, WebSocket 的 frame 可以分为两类, 一类是用于传输控制信息的 frame (如通知对方关闭 WebSocket 连接), 一类是用于传输应用数据的 frame, 使用 WebSocket 协议通信的双方都需要首先进行握手, 只有当握手成功之后才开始使用 frame 传输数据 + +## 参考资料 + +* [RFC 6455 - The WebSocket Protocol](https://tools.ietf.org/html/rfc6455) +* [wikipedia - WebSocket](https://en.wikipedia.org/wiki/WebSocket) +* [HTML5 WebSocket](https://www.runoob.com/html/html5-websocket.html) +* [MDN - WebSocket](https://developer.mozilla.org/zh-CN/docs/Web/API/WebSocket) +* [WebSocket 协议解析 [RFC 6455]](https://sunyunqiang.com/blog/websocket_protocol_rfc6455/) +* [WebSocket 教程](https://www.ruanyifeng.com/blog/2017/05/websocket.html) diff --git a/plugins/transport/websocket/client.go b/plugins/transport/websocket/client.go new file mode 100644 index 0000000..3e4f2a3 --- /dev/null +++ b/plugins/transport/websocket/client.go @@ -0,0 +1,313 @@ +package websocket + +import ( + "encoding/json" + "errors" + "net/url" + "time" + + "github.com/go-kratos/kratos/v2/encoding" + + ws "github.com/gorilla/websocket" + + "github.com/tx7do/kratos-transport/broker" +) + +type ClientMessageHandler func(MessagePayload) error + +type ClientHandlerData struct { + Handler ClientMessageHandler + Binder Binder +} +type ClientMessageHandlerMap map[MessageType]*ClientHandlerData + +type Client struct { + conn *ws.Conn + + url string + endpoint *url.URL + + codec encoding.Codec + messageHandlers ClientMessageHandlerMap + + timeout time.Duration + + payloadType PayloadType +} + +func NewClient(opts ...ClientOption) *Client { + cli := &Client{ + url: "", + timeout: 1 * time.Second, + codec: encoding.GetCodec("json"), + messageHandlers: make(ClientMessageHandlerMap), + payloadType: PayloadTypeBinary, + } + + cli.init(opts...) + + return cli +} + +func (c *Client) init(opts ...ClientOption) { + for _, o := range opts { + o(c) + } + + c.endpoint, _ = url.Parse(c.url) +} + +func (c *Client) Connect() error { + if c.endpoint == nil { + return errors.New("endpoint is nil") + } + + LogInfof("connecting to %s", c.endpoint.String()) + + conn, resp, err := ws.DefaultDialer.Dial(c.endpoint.String(), nil) + if err != nil { + LogErrorf("%s [%v]", err.Error(), resp) + return err + } + c.conn = conn + + go c.run() + + return nil +} + +func (c *Client) Disconnect() { + if c.conn != nil { + if err := c.conn.Close(); err != nil { + LogErrorf("disconnect error: %s", err.Error()) + } + c.conn = nil + } +} + +func (c *Client) RegisterMessageHandler(messageType MessageType, handler ClientMessageHandler, binder Binder) { + if _, ok := c.messageHandlers[messageType]; ok { + return + } + + c.messageHandlers[messageType] = &ClientHandlerData{handler, binder} +} + +func RegisterClientMessageHandler[T any](cli *Client, messageType MessageType, handler func(*T) error) { + cli.RegisterMessageHandler(messageType, + func(payload MessagePayload) error { + switch t := payload.(type) { + case *T: + return handler(t) + default: + LogError("invalid payload struct type:", t) + return errors.New("invalid payload struct type") + } + }, + func() Any { + var t T + return &t + }, + ) +} + +func (c *Client) DeregisterMessageHandler(messageType MessageType) { + delete(c.messageHandlers, messageType) +} + +func (c *Client) marshalMessage(messageType MessageType, message MessagePayload) ([]byte, error) { + var err error + var buff []byte + + switch c.payloadType { + case PayloadTypeBinary: + var msg BinaryMessage + msg.Type = messageType + msg.Body, err = broker.Marshal(c.codec, message) + if err != nil { + return nil, err + } + buff, err = msg.Marshal() + if err != nil { + return nil, err + } + break + + case PayloadTypeText: + var buf []byte + var msg TextMessage + msg.Type = messageType + buf, err = broker.Marshal(c.codec, message) + msg.Body = string(buf) + if err != nil { + return nil, err + } + buff, err = json.Marshal(msg) + if err != nil { + return nil, err + } + break + } + + //LogInfo("marshalMessage:", string(buff)) + + return buff, nil +} + +func (c *Client) SendMessage(messageType MessageType, message interface{}) error { + buff, err := c.marshalMessage(messageType, message) + if err != nil { + LogError("marshal message exception:", err) + return err + } + + switch c.payloadType { + case PayloadTypeBinary: + if err = c.sendBinaryMessage(buff); err != nil { + return err + } + break + + case PayloadTypeText: + if err = c.sendTextMessage(string(buff)); err != nil { + return err + } + break + } + + return nil +} + +func (c *Client) sendPingMessage(message string) error { + return c.conn.WriteMessage(ws.PingMessage, []byte(message)) +} + +func (c *Client) sendPongMessage(message string) error { + return c.conn.WriteMessage(ws.PongMessage, []byte(message)) +} + +func (c *Client) sendTextMessage(message string) error { + return c.conn.WriteMessage(ws.TextMessage, []byte(message)) +} + +func (c *Client) sendBinaryMessage(message []byte) error { + return c.conn.WriteMessage(ws.BinaryMessage, message) +} + +func (c *Client) run() { + defer c.Disconnect() + + for { + messageType, data, err := c.conn.ReadMessage() + if err != nil { + if ws.IsUnexpectedCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseAbnormalClosure) { + LogErrorf("read message error: %v", err) + } + return + } + + switch messageType { + case ws.CloseMessage: + return + + case ws.BinaryMessage: + _ = c.messageHandler(data) + break + + case ws.TextMessage: + _ = c.messageHandler(data) + break + + case ws.PingMessage: + if err := c.sendPongMessage(""); err != nil { + LogError("write pong message error: ", err) + return + } + break + + case ws.PongMessage: + break + } + + } +} + +func (c *Client) unmarshalMessage(buf []byte) (*ClientHandlerData, MessagePayload, error) { + var handler *ClientHandlerData + var payload MessagePayload + + switch c.payloadType { + case PayloadTypeBinary: + var msg BinaryMessage + if err := msg.Unmarshal(buf); err != nil { + LogErrorf("decode message exception: %s", err) + return nil, nil, err + } + + var ok bool + handler, ok = c.messageHandlers[msg.Type] + if !ok { + LogError("message handler not found:", msg.Type) + return nil, nil, errors.New("message handler not found") + } + + if handler.Binder != nil { + payload = handler.Binder() + } else { + payload = msg.Body + } + + if err := broker.Unmarshal(c.codec, msg.Body, &payload); err != nil { + LogErrorf("unmarshal message exception: %s", err) + return nil, nil, err + } + //LogDebug(string(msg.Body)) + + case PayloadTypeText: + var msg TextMessage + if err := msg.Unmarshal(buf); err != nil { + LogErrorf("decode message exception: %s", err) + return nil, nil, err + } + + var ok bool + handler, ok = c.messageHandlers[msg.Type] + if !ok { + LogError("message handler not found:", msg.Type) + return nil, nil, errors.New("message handler not found") + } + + if handler.Binder != nil { + payload = handler.Binder() + } else { + payload = msg.Body + } + + if err := broker.Unmarshal(c.codec, []byte(msg.Body), &payload); err != nil { + LogErrorf("unmarshal message exception: %s", err) + return nil, nil, err + } + //LogDebug(string(msg.Body)) + } + + return handler, payload, nil +} + +func (c *Client) messageHandler(buf []byte) error { + var err error + var handler *ClientHandlerData + var payload MessagePayload + + if handler, payload, err = c.unmarshalMessage(buf); err != nil { + LogErrorf("unmarshal message failed: %s", err) + return err + } + //LogDebug(payload) + + if err = handler.Handler(payload); err != nil { + LogErrorf("message handler exception: %s", err) + return err + } + + return nil +} diff --git a/plugins/transport/websocket/client_test.go b/plugins/transport/websocket/client_test.go new file mode 100644 index 0000000..55942d4 --- /dev/null +++ b/plugins/transport/websocket/client_test.go @@ -0,0 +1,81 @@ +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "fmt" + "io" + "os" + "os/signal" + "syscall" + "testing" +) + +var testClient *Client + +func handleClientChatMessage(message *ChatMessage) error { + fmt.Printf("Payload: %v\n", message) + _ = sendChatMessage(message.Sender, message.Message) + return nil +} + +func sendChatMessage(sender, msg string) error { + chatMsg := &ChatMessage{ + Sender: sender, + Message: msg, + } + return testClient.SendMessage(MessageTypeChat, chatMsg) +} + +func TestClient(t *testing.T) { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + cli := NewClient( + WithEndpoint("ws://localhost:10000/"), + WithClientCodec("json"), + //WithClientPayloadType(PayloadTypeText), + ) + defer cli.Disconnect() + + testClient = cli + + RegisterClientMessageHandler(cli, MessageTypeChat, handleClientChatMessage) + + err := cli.Connect() + if err != nil { + t.Error(err) + } + + _ = sendChatMessage("ws", "Hello, World!") + + <-interrupt +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +func Test1(t *testing.T) { + challengeKey, _ := generateChallengeKey() + fmt.Println(computeAcceptKey(challengeKey)) + + fmt.Println(computeAcceptKey("foIGUMVOg/QOba9qZkaCmg==")) + fmt.Println(computeAcceptKey("UHF9V2jktxC//1zmwLnxMg==")) + fmt.Println(computeAcceptKey("KWtssYGuj2uQiv7bG7tc7A==")) + fmt.Println(computeAcceptKey("3G4O+cC9DDGJS9pJAhzpUA==")) +} diff --git a/plugins/transport/websocket/encoding.go b/plugins/transport/websocket/encoding.go new file mode 100644 index 0000000..708bc8c --- /dev/null +++ b/plugins/transport/websocket/encoding.go @@ -0,0 +1 @@ +package websocket diff --git a/plugins/transport/websocket/go.mod b/plugins/transport/websocket/go.mod new file mode 100644 index 0000000..9e7f73b --- /dev/null +++ b/plugins/transport/websocket/go.mod @@ -0,0 +1,42 @@ +module github.com/tx7do/kratos-transport/transport/websocket + +go 1.19 + +require ( + github.com/go-kratos/kratos/v2 v2.7.0 + github.com/google/uuid v1.3.1 + github.com/gorilla/websocket v1.5.0 + github.com/stretchr/testify v1.8.4 + github.com/tx7do/kratos-transport v1.0.12 +) + +require ( + github.com/cenkalti/backoff/v4 v4.2.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-logr/logr v1.2.4 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/form/v4 v4.2.1 // indirect + github.com/golang/protobuf v1.5.3 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.17.1 // indirect + github.com/openzipkin/zipkin-go v0.4.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.opentelemetry.io/otel v1.17.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.17.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.17.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.17.0 // indirect + go.opentelemetry.io/otel/exporters/zipkin v1.17.0 // indirect + go.opentelemetry.io/otel/metric v1.17.0 // indirect + go.opentelemetry.io/otel/sdk v1.17.0 // indirect + go.opentelemetry.io/otel/trace v1.17.0 // indirect + go.opentelemetry.io/proto/otlp v1.0.0 // indirect + golang.org/x/net v0.15.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/text v0.13.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect + google.golang.org/grpc v1.58.0 // indirect + google.golang.org/protobuf v1.31.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/tx7do/kratos-transport => ../../ diff --git a/plugins/transport/websocket/logger.go b/plugins/transport/websocket/logger.go new file mode 100644 index 0000000..79b5291 --- /dev/null +++ b/plugins/transport/websocket/logger.go @@ -0,0 +1,59 @@ +package websocket + +import ( + "fmt" + + "github.com/go-kratos/kratos/v2/log" +) + +const ( + logKey = "websocket" +) + +/// +/// logger +/// + +func LogDebug(args ...interface{}) { + _ = log.GetLogger().Log(log.LevelDebug, logKey, fmt.Sprint(args...)) +} + +func LogInfo(args ...interface{}) { + _ = log.GetLogger().Log(log.LevelInfo, logKey, fmt.Sprint(args...)) +} + +func LogWarn(args ...interface{}) { + _ = log.GetLogger().Log(log.LevelWarn, logKey, fmt.Sprint(args...)) +} + +func LogError(args ...interface{}) { + _ = log.GetLogger().Log(log.LevelError, logKey, fmt.Sprint(args...)) +} + +func LogFatal(args ...interface{}) { + _ = log.GetLogger().Log(log.LevelFatal, logKey, fmt.Sprint(args...)) +} + +/// +/// logger +/// + +func LogDebugf(format string, args ...interface{}) { + _ = log.GetLogger().Log(log.LevelDebug, logKey, fmt.Sprintf(format, args...)) +} + +func LogInfof(format string, args ...interface{}) { + _ = log.GetLogger().Log(log.LevelInfo, logKey, fmt.Sprintf(format, args...)) +} + +func LogWarnf(format string, args ...interface{}) { + _ = log.GetLogger().Log(log.LevelWarn, logKey, fmt.Sprintf(format, args...)) +} + +func LogErrorf(format string, args ...interface{}) { + _ = log.GetLogger().Log(log.LevelError, logKey, fmt.Sprintf(format, args...)) +} + +func LogFatalf(format string, args ...interface{}) { + _ = log.GetLogger().Log(log.LevelFatal, logKey, fmt.Sprintf(format, args...)) +} diff --git a/plugins/transport/websocket/message.go b/plugins/transport/websocket/message.go new file mode 100644 index 0000000..e7d16da --- /dev/null +++ b/plugins/transport/websocket/message.go @@ -0,0 +1,51 @@ +package websocket + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Any interface{} +type MessageType uint32 +type MessagePayload Any + +type BinaryMessage struct { + Type MessageType + Body []byte +} + +func (m *BinaryMessage) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + if err := binary.Write(buf, binary.LittleEndian, uint32(m.Type)); err != nil { + return nil, err + } + buf.Write(m.Body) + return buf.Bytes(), nil +} + +func (m *BinaryMessage) Unmarshal(buf []byte) error { + network := new(bytes.Buffer) + network.Write(buf) + + if err := binary.Read(network, binary.LittleEndian, &m.Type); err != nil { + return err + } + + m.Body = network.Bytes() + + return nil +} + +type TextMessage struct { + Type MessageType `json:"type" xml:"type"` + Body string `json:"body" xml:"body"` +} + +func (m *TextMessage) Marshal() ([]byte, error) { + return json.Marshal(m) +} + +func (m *TextMessage) Unmarshal(buf []byte) error { + return json.Unmarshal(buf, m) +} diff --git a/plugins/transport/websocket/options.go b/plugins/transport/websocket/options.go new file mode 100644 index 0000000..9b025cd --- /dev/null +++ b/plugins/transport/websocket/options.go @@ -0,0 +1,100 @@ +package websocket + +import ( + "crypto/tls" + "net" + "time" + + "github.com/go-kratos/kratos/v2/encoding" +) + +type PayloadType uint8 + +const ( + PayloadTypeBinary = 0 + PayloadTypeText = 1 +) + +type ServerOption func(o *Server) + +func WithNetwork(network string) ServerOption { + return func(s *Server) { + s.network = network + } +} + +func WithAddress(addr string) ServerOption { + return func(s *Server) { + s.address = addr + } +} + +func WithTimeout(timeout time.Duration) ServerOption { + return func(s *Server) { + s.timeout = timeout + } +} + +func WithPath(path string) ServerOption { + return func(s *Server) { + s.path = path + } +} + +func WithConnectHandle(h ConnectHandler) ServerOption { + return func(s *Server) { + s.sessionMgr.RegisterConnectHandler(h) + } +} + +func WithTLSConfig(c *tls.Config) ServerOption { + return func(o *Server) { + o.tlsConf = c + } +} + +func WithListener(lis net.Listener) ServerOption { + return func(s *Server) { + s.lis = lis + } +} + +func WithCodec(c string) ServerOption { + return func(s *Server) { + s.codec = encoding.GetCodec(c) + } +} + +func WithChannelBufferSize(size int) ServerOption { + return func(_ *Server) { + channelBufSize = size + } +} + +func WithPayloadType(payloadType PayloadType) ServerOption { + return func(s *Server) { + s.payloadType = payloadType + } +} + +//////////////////////////////////////////////////////////////////////////////// + +type ClientOption func(o *Client) + +func WithClientCodec(c string) ClientOption { + return func(o *Client) { + o.codec = encoding.GetCodec(c) + } +} + +func WithEndpoint(uri string) ClientOption { + return func(o *Client) { + o.url = uri + } +} + +func WithClientPayloadType(payloadType PayloadType) ClientOption { + return func(c *Client) { + c.payloadType = payloadType + } +} diff --git a/plugins/transport/websocket/server.go b/plugins/transport/websocket/server.go new file mode 100644 index 0000000..805f421 --- /dev/null +++ b/plugins/transport/websocket/server.go @@ -0,0 +1,393 @@ +package websocket + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-kratos/kratos/v2/encoding" + "github.com/go-kratos/kratos/v2/transport" + + ws "github.com/gorilla/websocket" + + "github.com/tx7do/kratos-transport/broker" +) + +type Binder func() Any + +type ConnectHandler func(SessionID, bool) + +type MessageHandler func(SessionID, MessagePayload) error + +type HandlerData struct { + Handler MessageHandler + Binder Binder +} +type MessageHandlerMap map[MessageType]*HandlerData + +var ( + _ transport.Server = (*Server)(nil) + _ transport.Endpointer = (*Server)(nil) +) + +type Server struct { + *http.Server + + lis net.Listener + tlsConf *tls.Config + upgrader *ws.Upgrader + + network string + address string + path string + strictSlash bool + + timeout time.Duration + + err error + codec encoding.Codec + + messageHandlers MessageHandlerMap + + sessionMgr *SessionManager + + register chan *Session + unregister chan *Session + + payloadType PayloadType +} + +func NewServer(opts ...ServerOption) *Server { + srv := &Server{ + network: "tcp", + address: ":0", + timeout: 1 * time.Second, + strictSlash: true, + path: "/", + + messageHandlers: make(MessageHandlerMap), + + sessionMgr: NewSessionManager(), + upgrader: &ws.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, + }, + + register: make(chan *Session), + unregister: make(chan *Session), + + payloadType: PayloadTypeBinary, + } + + srv.init(opts...) + + srv.err = srv.listen() + + return srv +} + +func (s *Server) Name() string { + return string(KindWebsocket) +} + +func (s *Server) init(opts ...ServerOption) { + for _, o := range opts { + o(s) + } + + s.Server = &http.Server{ + TLSConfig: s.tlsConf, + } + + http.HandleFunc(s.path, s.wsHandler) +} + +func (s *Server) SessionCount() int { + return s.sessionMgr.Count() +} + +func (s *Server) RegisterMessageHandler(messageType MessageType, handler MessageHandler, binder Binder) { + if _, ok := s.messageHandlers[messageType]; ok { + return + } + + s.messageHandlers[messageType] = &HandlerData{ + handler, binder, + } +} + +func RegisterServerMessageHandler[T any](srv *Server, messageType MessageType, handler func(SessionID, *T) error) { + srv.RegisterMessageHandler(messageType, + func(sessionId SessionID, payload MessagePayload) error { + switch t := payload.(type) { + case *T: + return handler(sessionId, t) + default: + LogError("invalid payload struct type:", t) + return errors.New("invalid payload struct type") + } + }, + func() Any { + var t T + return &t + }, + ) +} + +func (s *Server) DeregisterMessageHandler(messageType MessageType) { + delete(s.messageHandlers, messageType) +} + +func (s *Server) marshalMessage(messageType MessageType, message MessagePayload) ([]byte, error) { + var err error + var buff []byte + + switch s.payloadType { + case PayloadTypeBinary: + var msg BinaryMessage + msg.Type = messageType + msg.Body, err = broker.Marshal(s.codec, message) + if err != nil { + return nil, err + } + buff, err = msg.Marshal() + if err != nil { + return nil, err + } + break + + case PayloadTypeText: + var buf []byte + var msg TextMessage + msg.Type = messageType + buf, err = broker.Marshal(s.codec, message) + msg.Body = string(buf) + if err != nil { + return nil, err + } + buff, err = json.Marshal(msg) + if err != nil { + return nil, err + } + break + } + + //LogInfo("marshalMessage:", string(buff)) + + return buff, nil +} + +func (s *Server) SendMessage(sessionId SessionID, messageType MessageType, message MessagePayload) { + c, ok := s.sessionMgr.Get(sessionId) + if !ok { + LogError("session not found:", sessionId) + return + } + + switch s.payloadType { + case PayloadTypeBinary: + buf, err := s.marshalMessage(messageType, message) + if err != nil { + LogError("marshal message exception:", err) + return + } + + c.SendMessage(buf) + break + + case PayloadTypeText: + buf, err := s.codec.Marshal(message) + if err != nil { + LogError("marshal message exception:", err) + return + } + + c.SendMessage(buf) + break + } + +} + +func (s *Server) Broadcast(messageType MessageType, message MessagePayload) { + buf, err := s.marshalMessage(messageType, message) + if err != nil { + LogError(" marshal message exception:", err) + return + } + + s.sessionMgr.Range(func(session *Session) { + session.SendMessage(buf) + }) +} + +func (s *Server) unmarshalMessage(buf []byte) (*HandlerData, MessagePayload, error) { + var handler *HandlerData + var payload MessagePayload + + switch s.payloadType { + case PayloadTypeBinary: + var msg BinaryMessage + if err := msg.Unmarshal(buf); err != nil { + LogErrorf("decode message exception: %s", err) + return nil, nil, err + } + + var ok bool + handler, ok = s.messageHandlers[msg.Type] + if !ok { + LogError("message handler not found:", msg.Type) + return nil, nil, errors.New("message handler not found") + } + + if handler.Binder != nil { + payload = handler.Binder() + } else { + payload = msg.Body + } + + if err := broker.Unmarshal(s.codec, msg.Body, &payload); err != nil { + LogErrorf("unmarshal message exception: %s", err) + return nil, nil, err + } + //LogDebug(string(msg.Body)) + + case PayloadTypeText: + var msg TextMessage + if err := msg.Unmarshal(buf); err != nil { + LogErrorf("decode message exception: %s", err) + return nil, nil, err + } + + var ok bool + handler, ok = s.messageHandlers[msg.Type] + if !ok { + LogError("message handler not found:", msg.Type) + return nil, nil, errors.New("message handler not found") + } + + if handler.Binder != nil { + payload = handler.Binder() + } else { + payload = msg.Body + } + + if err := broker.Unmarshal(s.codec, []byte(msg.Body), &payload); err != nil { + LogErrorf("unmarshal message exception: %s", err) + return nil, nil, err + } + //LogDebug(string(msg.Body)) + } + + return handler, payload, nil +} + +func (s *Server) messageHandler(sessionId SessionID, buf []byte) error { + var err error + var handler *HandlerData + var payload MessagePayload + + if handler, payload, err = s.unmarshalMessage(buf); err != nil { + LogErrorf("unmarshal message failed: %s", err) + return err + } + //LogDebug(payload) + + if err = handler.Handler(sessionId, payload); err != nil { + LogErrorf("message handler failed: %s", err) + return err + } + + return nil +} + +func (s *Server) wsHandler(res http.ResponseWriter, req *http.Request) { + conn, err := s.upgrader.Upgrade(res, req, nil) + if err != nil { + LogError("upgrade exception:", err) + return + } + + session := NewSession(conn, s) + session.server.register <- session + + session.Listen() +} + +func (s *Server) listen() error { + if s.lis == nil { + lis, err := net.Listen(s.network, s.address) + if err != nil { + s.err = err + return err + } + s.lis = lis + } + + return nil +} + +func (s *Server) Endpoint() (*url.URL, error) { + addr := s.address + + prefix := "ws://" + if s.tlsConf == nil { + if !strings.HasPrefix(addr, "ws://") { + prefix = "ws://" + } + } else { + if !strings.HasPrefix(addr, "wss://") { + prefix = "wss://" + } + } + addr = prefix + addr + + var endpoint *url.URL + endpoint, s.err = url.Parse(addr) + return endpoint, nil +} + +func (s *Server) run() { + for { + select { + case client := <-s.register: + s.sessionMgr.Add(client) + case client := <-s.unregister: + s.sessionMgr.Remove(client) + } + } +} + +func (s *Server) Start(ctx context.Context) error { + if s.err != nil { + return s.err + } + s.BaseContext = func(net.Listener) context.Context { + return ctx + } + LogInfof("server listening on: %s", s.lis.Addr().String()) + + go s.run() + + var err error + if s.tlsConf != nil { + err = s.ServeTLS(s.lis, "", "") + } else { + err = s.Serve(s.lis) + } + if !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil +} + +func (s *Server) Stop(ctx context.Context) error { + LogInfo("server stopping") + return s.Shutdown(ctx) +} diff --git a/plugins/transport/websocket/server_test.go b/plugins/transport/websocket/server_test.go new file mode 100644 index 0000000..075a6ac --- /dev/null +++ b/plugins/transport/websocket/server_test.go @@ -0,0 +1,101 @@ +package websocket + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "os" + "os/signal" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +var testServer *Server + +const ( + MessageTypeChat = iota + 1 +) + +type ChatMessage struct { + Type int `json:"type"` + Sender string `json:"sender"` + Message string `json:"message"` +} + +func handleConnect(sessionId SessionID, register bool) { + if register { + fmt.Printf("%s registered\n", sessionId) + } else { + fmt.Printf("%s unregistered\n", sessionId) + } +} + +func handleChatMessage(sessionId SessionID, message *ChatMessage) error { + fmt.Printf("[%s] Payload: %v\n", sessionId, message) + + testServer.Broadcast(MessageTypeChat, *message) + + return nil +} + +func TestServer(t *testing.T) { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + ctx := context.Background() + + srv := NewServer( + WithAddress(":10000"), + WithPath("/"), + WithConnectHandle(handleConnect), + WithCodec("json"), + //WithPayloadType(PayloadTypeText), + ) + + RegisterServerMessageHandler(srv, MessageTypeChat, handleChatMessage) + + testServer = srv + + if err := srv.Start(ctx); err != nil { + panic(err) + } + + defer func() { + if err := srv.Stop(ctx); err != nil { + t.Errorf("expected nil got %v", err) + } + }() + + <-interrupt +} + +func TestGob(t *testing.T) { + var msg BinaryMessage + msg.Type = MessageTypeChat + msg.Body = []byte("") + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + _ = enc.Encode(msg) + + fmt.Printf("%s\n", string(buf.Bytes())) +} + +func TestMessageMarshal(t *testing.T) { + var msg BinaryMessage + msg.Type = 10000 + msg.Body = []byte("Hello World") + + buf, err := msg.Marshal() + assert.Nil(t, err) + + fmt.Printf("%s\n", string(buf)) + + var msg1 BinaryMessage + _ = msg1.Unmarshal(buf) + + fmt.Printf("[%d] [%s]\n", msg1.Type, string(msg1.Body)) +} diff --git a/plugins/transport/websocket/session.go b/plugins/transport/websocket/session.go new file mode 100644 index 0000000..3d873cc --- /dev/null +++ b/plugins/transport/websocket/session.go @@ -0,0 +1,149 @@ +package websocket + +import ( + "github.com/google/uuid" + ws "github.com/gorilla/websocket" +) + +var channelBufSize = 256 + +type SessionID string + +type Session struct { + id SessionID + conn *ws.Conn + send chan []byte + server *Server +} + +func NewSession(conn *ws.Conn, server *Server) *Session { + if conn == nil { + panic("conn cannot be nil") + } + + u1, _ := uuid.NewUUID() + + c := &Session{ + id: SessionID(u1.String()), + conn: conn, + send: make(chan []byte, channelBufSize), + server: server, + } + + return c +} + +func (c *Session) Conn() *ws.Conn { + return c.conn +} + +func (c *Session) SessionID() SessionID { + return c.id +} + +func (c *Session) SendMessage(message []byte) { + select { + case c.send <- message: + } +} + +func (c *Session) Close() { + c.server.unregister <- c + c.closeConnect() +} + +func (c *Session) Listen() { + go c.writePump() + go c.readPump() +} + +func (c *Session) closeConnect() { + //LogInfo(c.SessionID(), " connection closed") + if c.conn != nil { + if err := c.conn.Close(); err != nil { + LogErrorf("disconnect error: %s", err.Error()) + } + c.conn = nil + } +} + +func (c *Session) sendPingMessage(message string) error { + return c.conn.WriteMessage(ws.PingMessage, []byte(message)) +} + +func (c *Session) sendPongMessage(message string) error { + return c.conn.WriteMessage(ws.PongMessage, []byte(message)) +} + +func (c *Session) sendTextMessage(message string) error { + return c.conn.WriteMessage(ws.TextMessage, []byte(message)) +} + +func (c *Session) sendBinaryMessage(message []byte) error { + return c.conn.WriteMessage(ws.BinaryMessage, message) +} + +func (c *Session) writePump() { + defer c.Close() + + for { + select { + case msg := <-c.send: + var err error + switch c.server.payloadType { + case PayloadTypeBinary: + if err = c.sendBinaryMessage(msg); err != nil { + LogError("write binary message error: ", err) + return + } + break + + case PayloadTypeText: + if err = c.sendTextMessage(string(msg)); err != nil { + LogError("write text message error: ", err) + return + } + break + } + + } + } +} + +func (c *Session) readPump() { + defer c.Close() + + for { + messageType, data, err := c.conn.ReadMessage() + if err != nil { + if ws.IsUnexpectedCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseAbnormalClosure) { + LogErrorf("read message error: %v", err) + } + return + } + + switch messageType { + case ws.CloseMessage: + return + + case ws.BinaryMessage: + _ = c.server.messageHandler(c.SessionID(), data) + break + + case ws.TextMessage: + _ = c.server.messageHandler(c.SessionID(), data) + break + + case ws.PingMessage: + if err := c.sendPongMessage(""); err != nil { + LogError("write pong message error: ", err) + return + } + break + + case ws.PongMessage: + break + } + + } +} diff --git a/plugins/transport/websocket/session_manager.go b/plugins/transport/websocket/session_manager.go new file mode 100644 index 0000000..4531da2 --- /dev/null +++ b/plugins/transport/websocket/session_manager.go @@ -0,0 +1,77 @@ +package websocket + +import "sync" + +type SessionMap map[SessionID]*Session + +type SessionManager struct { + sessions SessionMap + mtx sync.RWMutex + connectHandler ConnectHandler +} + +func NewSessionManager() *SessionManager { + return &SessionManager{ + sessions: make(SessionMap), + } +} + +func (s *SessionManager) RegisterConnectHandler(handler ConnectHandler) { + s.connectHandler = handler +} + +func (s *SessionManager) Clean() { + s.mtx.Lock() + defer s.mtx.Unlock() + s.sessions = SessionMap{} +} + +func (s *SessionManager) Count() int { + s.mtx.Lock() + defer s.mtx.Unlock() + return len(s.sessions) +} + +func (s *SessionManager) Get(sessionId SessionID) (*Session, bool) { + s.mtx.Lock() + defer s.mtx.Unlock() + c, ok := s.sessions[sessionId] + return c, ok +} + +func (s *SessionManager) Range(fn func(*Session)) { + s.mtx.Lock() + defer s.mtx.Unlock() + + for _, v := range s.sessions { + fn(v) + } +} + +func (s *SessionManager) Add(c *Session) { + s.mtx.Lock() + defer s.mtx.Unlock() + + //log.Info("[websocket] add session: ", c.SessionID()) + s.sessions[c.SessionID()] = c + + if s.connectHandler != nil { + s.connectHandler(c.SessionID(), true) + } +} + +func (s *SessionManager) Remove(c *Session) { + s.mtx.Lock() + defer s.mtx.Unlock() + + for k, v := range s.sessions { + if c == v { + //log.Info("[websocket] remove session: ", c.SessionID()) + if s.connectHandler != nil { + s.connectHandler(c.SessionID(), false) + } + delete(s.sessions, k) + return + } + } +} diff --git a/plugins/transport/websocket/transport.go b/plugins/transport/websocket/transport.go new file mode 100644 index 0000000..6cada46 --- /dev/null +++ b/plugins/transport/websocket/transport.go @@ -0,0 +1,105 @@ +package websocket + +import ( + "context" + "net/http" + + "github.com/go-kratos/kratos/v2/transport" +) + +const ( + KindWebsocket transport.Kind = "websocket" +) + +var _ Transporter = &Transport{} + +type Transporter interface { + transport.Transporter + Request() *http.Request + PathTemplate() string +} + +// Transport is a websocket transport. +type Transport struct { + endpoint string + operation string + reqHeader headerCarrier + replyHeader headerCarrier + request *http.Request + pathTemplate string +} + +// Kind returns the transport kind. +func (tr *Transport) Kind() transport.Kind { + return KindWebsocket +} + +// Endpoint returns the transport endpoint. +func (tr *Transport) Endpoint() string { + return tr.endpoint +} + +// Operation returns the transport operation. +func (tr *Transport) Operation() string { + return tr.operation +} + +// Request returns the HTTP request. +func (tr *Transport) Request() *http.Request { + return tr.request +} + +// RequestHeader returns the request header. +func (tr *Transport) RequestHeader() transport.Header { + return tr.reqHeader +} + +// ReplyHeader returns the reply header. +func (tr *Transport) ReplyHeader() transport.Header { + return tr.replyHeader +} + +// PathTemplate returns the http path template. +func (tr *Transport) PathTemplate() string { + return tr.pathTemplate +} + +// SetOperation sets the transport operation. +func SetOperation(ctx context.Context, op string) { + if tr, ok := transport.FromServerContext(ctx); ok { + if tr, ok := tr.(*Transport); ok { + tr.operation = op + } + } +} + +type headerCarrier http.Header + +// Get returns the value associated with the passed key. +func (hc headerCarrier) Get(key string) string { + return http.Header(hc).Get(key) +} + +// Set stores the key-value pair. +func (hc headerCarrier) Set(key, value string) { + http.Header(hc).Set(key, value) +} + +// Keys lists the keys stored in this carrier. +func (hc headerCarrier) Keys() []string { + keys := make([]string, 0, len(hc)) + for k := range http.Header(hc) { + keys = append(keys, k) + } + return keys +} + +// Add append value to key-values pair. +func (hc headerCarrier) Add(key string, value string) { + http.Header(hc).Add(key, value) +} + +// Values returns a slice of values associated with the passed key. +func (hc headerCarrier) Values(key string) []string { + return http.Header(hc).Values(key) +} diff --git a/plugins/transport/websocket/transport_test.go b/plugins/transport/websocket/transport_test.go new file mode 100644 index 0000000..fc09403 --- /dev/null +++ b/plugins/transport/websocket/transport_test.go @@ -0,0 +1,93 @@ +package websocket + +import ( + "context" + "github.com/go-kratos/kratos/v2/transport" + "net/http" + "reflect" + "sort" + "testing" +) + +func TestTransport_Kind(t *testing.T) { + o := &Transport{} + if !reflect.DeepEqual(KindWebsocket, o.Kind()) { + t.Errorf("expect %v, got %v", KindWebsocket, o.Kind()) + } +} + +func TestTransport_Endpoint(t *testing.T) { + v := "hello" + o := &Transport{endpoint: v} + if !reflect.DeepEqual(v, o.Endpoint()) { + t.Errorf("expect %v, got %v", v, o.Endpoint()) + } +} + +func TestTransport_Operation(t *testing.T) { + v := "hello" + o := &Transport{operation: v} + if !reflect.DeepEqual(v, o.Operation()) { + t.Errorf("expect %v, got %v", v, o.Operation()) + } +} + +func TestTransport_Request(t *testing.T) { + v := &http.Request{} + o := &Transport{request: v} + if !reflect.DeepEqual(v, o.Request()) { + t.Errorf("expect %v, got %v", v, o.Request()) + } +} + +func TestTransport_RequestHeader(t *testing.T) { + v := headerCarrier{} + v.Set("a", "1") + o := &Transport{reqHeader: v} + if !reflect.DeepEqual("1", o.RequestHeader().Get("a")) { + t.Errorf("expect %v, got %v", "1", o.RequestHeader().Get("a")) + } +} + +func TestTransport_ReplyHeader(t *testing.T) { + v := headerCarrier{} + v.Set("a", "1") + o := &Transport{replyHeader: v} + if !reflect.DeepEqual("1", o.ReplyHeader().Get("a")) { + t.Errorf("expect %v, got %v", "1", o.ReplyHeader().Get("a")) + } +} + +func TestTransport_PathTemplate(t *testing.T) { + v := "template" + o := &Transport{pathTemplate: v} + if !reflect.DeepEqual(v, o.PathTemplate()) { + t.Errorf("expect %v, got %v", v, o.PathTemplate()) + } +} + +func TestHeaderCarrier_Keys(t *testing.T) { + v := headerCarrier{} + v.Set("abb", "1") + v.Set("bcc", "2") + want := []string{"Abb", "Bcc"} + keys := v.Keys() + sort.Slice(want, func(i, j int) bool { + return want[i] < want[j] + }) + sort.Slice(keys, func(i, j int) bool { + return keys[i] < keys[j] + }) + if !reflect.DeepEqual(want, keys) { + t.Errorf("expect %v, got %v", want, keys) + } +} + +func TestSetOperation(t *testing.T) { + tr := &Transport{} + ctx := transport.NewServerContext(context.Background(), tr) + SetOperation(ctx, "kratos") + if !reflect.DeepEqual(tr.operation, "kratos") { + t.Errorf("expect %v, got %v", "kratos", tr.operation) + } +}