diff --git a/client_test.go b/client_test.go index f6d2dc9..da8dde7 100644 --- a/client_test.go +++ b/client_test.go @@ -9,14 +9,18 @@ import ( gorilla "github.com/kataras/neffos/gorilla" ) -func runTestClient(addr string, connHandler neffos.ConnHandler, testFn func(string, *neffos.Client)) error { +func runTestClient(addr string, connHandler neffos.ConnHandler, testFn func(string, *neffos.Client)) func() error { gobwasClient, err := neffos.Dial(nil, gobwas.DefaultDialer, fmt.Sprintf("ws://%s/gobwas", addr), connHandler) if err != nil { - return err + return func() error { + return err + } } gorillaClient, err := neffos.Dial(nil, gorilla.DefaultDialer, fmt.Sprintf("ws://%s/gorilla", addr), connHandler) if err != nil { - return err + return func() error { + return err + } } // teardown. @@ -25,9 +29,8 @@ func runTestClient(addr string, connHandler neffos.ConnHandler, testFn func(stri gorillaClient.Close() return nil } - defer teardown() testFn("gobwas", gobwasClient) testFn("gorilla", gorillaClient) - return nil + return teardown } diff --git a/conn_namespace_test.go b/conn_namespace_test.go index 11189e7..c8737bb 100644 --- a/conn_namespace_test.go +++ b/conn_namespace_test.go @@ -85,7 +85,7 @@ func TestJoinAndLeaveRoom(t *testing.T) { t.Fatalf("expected true") } wg.Wait() - }) + })() if err != nil { t.Fatal(err) } diff --git a/conn_test.go b/conn_test.go index 2dfa041..022d914 100644 --- a/conn_test.go +++ b/conn_test.go @@ -86,7 +86,7 @@ func TestConnect(t *testing.T) { t.Fatalf("%s namespace connect should give a local event's error by the client of the neffos.ErrBadNamespace but got: %v", namespaceThatShouldErrOnServer, err) } - }) + })() if err != nil { t.Fatal(err) } @@ -141,7 +141,7 @@ func TestAsk(t *testing.T) { t.Fatal(err) } testMessage(dialer, -1, msg) - }) + })() if err != nil { t.Fatal(err) } @@ -200,7 +200,7 @@ func TestOnAnyEvent(t *testing.T) { t.Fatal(err) } testMessage(msg) - }) + })() if err != nil { t.Fatal(err) } @@ -278,7 +278,7 @@ func TestOnNativeMessageAndMessageError(t *testing.T) { c.Emit(eventThatWillGiveErrorByServer, []byte("doesn't matter")) wg.Wait() - }) + })() if err != nil { t.Fatal(err) } @@ -365,7 +365,7 @@ func TestSimultaneouslyEventsRoutines(t *testing.T) { } wg.Wait() - }) + })() if err != nil { t.Fatal(err) } diff --git a/message.go b/message.go index 8da17ae..7d35fcc 100644 --- a/message.go +++ b/message.go @@ -52,6 +52,13 @@ type Message struct { // the CONN ID, filled automatically if `Server#Broadcast` first parameter of sender connection's ID is not empty, not exposed to the subscribers (rest of the clients). from string + // To is the connection ID of the receiver, used only when `Server#Broadcast` is called, indeed when we only need to send a message to a single connection. + // The Namespace, Room are still respected at all. + // + // However, sending messages to a group of connections is done by the `Room` field for groups inside a namespace or just `Namespace` field as usual. + // This field is not filled on sending/receiving. + To string + // True when event came from local (i.e client if running client) on force disconnection, // i.e OnNamespaceDisconnect and OnRoomLeave when closing a conn. // This field is not filled on sending/receiving. @@ -243,6 +250,7 @@ func deserializeMessage(decrypt MessageDecrypt, b []byte, allowNativeMessages bo isNoOp, isInvalid, "", + "", false, false, allowNativeMessages && event == OnNativeMessage, diff --git a/server.go b/server.go index 0a48240..7721156 100644 --- a/server.go +++ b/server.go @@ -273,13 +273,17 @@ func (s *Server) waitMessage(c *Conn) bool { return false } - if msg.from != c.ID() { - if !c.Write(msg) && c.IsClosed() { - return false - } + // don't send to its own if set-ed. + if msg.from == c.ID() { + return true + } + + // if "To" field is given then send to a specific connection. + if msg.To != "" && msg.To != c.ID() { + return true } - return true + return c.Write(msg) && !c.IsClosed() } // GetTotalConnections returns the total amount of the connected connections to the server, it's fast diff --git a/server_test.go b/server_test.go index a1e7c5f..90e1ea5 100644 --- a/server_test.go +++ b/server_test.go @@ -1,7 +1,11 @@ package neffos_test import ( + "bytes" "net/http" + "sync" + "sync/atomic" + "testing" "time" "github.com/kataras/neffos" @@ -10,10 +14,15 @@ import ( gorilla "github.com/kataras/neffos/gorilla" ) -func runTestServer(addr string, connHandler neffos.ConnHandler) func() error { +func runTestServer(addr string, connHandler neffos.ConnHandler, configureServer ...func(*neffos.Server)) func() error { gobwasServer := neffos.New(gobwas.DefaultUpgrader, connHandler) gorillaServer := neffos.New(gorilla.DefaultUpgrader, connHandler) + for _, cfg := range configureServer { + cfg(gobwasServer) + cfg(gorillaServer) + } + mux := http.NewServeMux() mux.Handle("/gobwas", gobwasServer) mux.Handle("/gorilla", gorillaServer) @@ -32,3 +41,74 @@ func runTestServer(addr string, connHandler neffos.ConnHandler) func() error { return httpServer.Close() } } + +func TestServerBroadcastTo(t *testing.T) { + // we fire up two connections, one with the "conn_ID" and other with the default uuid id generator, + // the message which the second client emits should only be sent to the connection with the ID of "conn_ID". + + var ( + wg sync.WaitGroup + namespace = "default" + body = []byte("data") + to = "conn_ID" + events = neffos.Namespaces{ + namespace: neffos.Events{ + "event": func(c *neffos.NSConn, msg neffos.Message) error { + if c.Conn.IsClient() { + if !bytes.Equal(msg.Body, body) { + t.Fatalf("expected event's incoming data to be: %s but got: %s", string(body), string(msg.Body)) + } + + if c.String() != to { + t.Fatalf("expected the message to be sent only to the connection with an ID of 'conn_ID'") + } + + wg.Done() + } else { + msg.To = to + c.Conn.Server().Broadcast(c, msg) + } + + return nil + }, + }, + } + ) + + teardownServer := runTestServer("localhost:8080", events, func(wsServer *neffos.Server) { + once := new(uint32) + wsServer.IDGenerator = func(w http.ResponseWriter, r *http.Request) string { + if atomic.CompareAndSwapUint32(once, 0, 1) { + return to // set the "to" only to the first conn for test. + } + + return neffos.DefaultIDGenerator(w, r) + } + }) + defer teardownServer() + + wg.Add(2) + + teardownClient1 := runTestClient("localhost:8080", events, + func(dialer string, client *neffos.Client) { + _, err := client.Connect(nil, namespace) + if err != nil { + t.Fatal(err) + } + + }) + + defer teardownClient1() + + teardownClient2 := runTestClient("localhost:8080", events, + func(dialer string, client *neffos.Client) { + c, err := client.Connect(nil, namespace) + if err != nil { + t.Fatal(err) + } + c.Emit("event", body) + }) + defer teardownClient2() + + wg.Wait() +}