diff --git a/clientsession.go b/clientsession.go index 68d61869..ee302a14 100644 --- a/clientsession.go +++ b/clientsession.go @@ -1173,7 +1173,16 @@ func (s *ClientSession) filterMessage(message *ServerMessage) *ServerMessage { switch message.Event.Type { case "join": if s.HasPermission(PERMISSION_HIDE_DISPLAYNAMES) { - message.Event.Join = filterDisplayNames(message.Event.Join) + // Create unique copy of message for only this client. + message = &ServerMessage{ + Id: message.Id, + Type: message.Type, + Event: &EventServerMessage{ + Type: message.Event.Type, + Target: message.Event.Target, + Join: filterDisplayNames(message.Event.Join), + }, + } } case "message": if message.Event.Message == nil || message.Event.Message.Data == nil || len(*message.Event.Message.Data) == 0 || !s.HasPermission(PERMISSION_HIDE_DISPLAYNAMES) { @@ -1189,7 +1198,19 @@ func (s *ClientSession) filterMessage(message *ServerMessage) *ServerMessage { if displayName, found := (*data.Chat.Comment)["actorDisplayName"]; found && displayName != "" { (*data.Chat.Comment)["actorDisplayName"] = "" if encoded, err := json.Marshal(data); err == nil { - message.Event.Message.Data = (*json.RawMessage)(&encoded) + // Create unique copy of message for only this client. + message = &ServerMessage{ + Id: message.Id, + Type: message.Type, + Event: &EventServerMessage{ + Type: message.Event.Type, + Target: message.Event.Target, + Message: &RoomEventMessage{ + RoomId: message.Event.Message.RoomId, + Data: (*json.RawMessage)(&encoded), + }, + }, + } } } } diff --git a/grpc_common.go b/grpc_common.go index 62bd437d..48461799 100644 --- a/grpc_common.go +++ b/grpc_common.go @@ -36,13 +36,19 @@ import ( type reloadableCredentials struct { config *tls.Config - pool *CertPoolReloader + loader *CertificateReloader + pool *CertPoolReloader } func (c *reloadableCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { // use local cfg to avoid clobbering ServerName if using multiple endpoints cfg := c.config.Clone() - cfg.RootCAs = c.pool.GetCertPool() + if c.loader != nil { + cfg.GetClientCertificate = c.loader.GetClientCertificate + } + if c.pool != nil { + cfg.RootCAs = c.pool.GetCertPool() + } if cfg.ServerName == "" { serverName, _, err := net.SplitHostPort(authority) if err != nil { @@ -78,7 +84,12 @@ func (c *reloadableCredentials) ClientHandshake(ctx context.Context, authority s func (c *reloadableCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { cfg := c.config.Clone() - cfg.ClientCAs = c.pool.GetCertPool() + if c.loader != nil { + cfg.GetCertificate = c.loader.GetCertificate + } + if c.pool != nil { + cfg.ClientCAs = c.pool.GetCertPool() + } conn := tls.Server(rawConn, cfg) if err := conn.Handshake(); err != nil { @@ -130,21 +141,18 @@ func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentia cfg := &tls.Config{ NextProtos: []string{"h2"}, } + var loader *CertificateReloader + var err error if certificateFile != "" && keyFile != "" { - loader, err := NewCertificateReloader(certificateFile, keyFile) + loader, err = NewCertificateReloader(certificateFile, keyFile) if err != nil { return nil, fmt.Errorf("invalid GRPC %s certificate / key in %s / %s: %w", prefix, certificateFile, keyFile, err) } - - if server { - cfg.GetCertificate = loader.GetCertificate - } else { - cfg.GetClientCertificate = loader.GetClientCertificate - } } + var pool *CertPoolReloader if caFile != "" { - pool, err := NewCertPoolReloader(caFile) + pool, err = NewCertPoolReloader(caFile) if err != nil { return nil, err } @@ -152,14 +160,9 @@ func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentia if server { cfg.ClientAuth = tls.RequireAndVerifyClientCert } - creds := &reloadableCredentials{ - config: cfg, - pool: pool, - } - return creds, nil } - if cfg.GetCertificate == nil { + if loader == nil && pool == nil { if server { log.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted") } else { @@ -168,5 +171,10 @@ func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentia return insecure.NewCredentials(), nil } - return credentials.NewTLS(cfg), nil + creds := &reloadableCredentials{ + config: cfg, + loader: loader, + pool: pool, + } + return creds, nil } diff --git a/grpc_common_test.go b/grpc_common_test.go index 038859b3..a525b496 100644 --- a/grpc_common_test.go +++ b/grpc_common_test.go @@ -27,6 +27,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "io/fs" "math/big" "net" "os" @@ -86,3 +87,32 @@ func WritePublicKey(key *rsa.PublicKey, filename string) error { return os.WriteFile(filename, data, 0755) } + +func replaceFile(t *testing.T, filename string, data []byte, perm fs.FileMode) { + t.Helper() + oldStat, err := os.Stat(filename) + if err != nil { + t.Fatalf("can't stat old file %s: %s", filename, err) + return + } + + for { + if err := os.WriteFile(filename, data, perm); err != nil { + t.Fatalf("can't write file %s: %s", filename, err) + return + } + + newStat, err := os.Stat(filename) + if err != nil { + t.Fatalf("can't stat new file %s: %s", filename, err) + return + } + + // We need different modification times. + if !newStat.ModTime().Equal(oldStat.ModTime()) { + break + } + + time.Sleep(time.Millisecond) + } +} diff --git a/grpc_server_test.go b/grpc_server_test.go index 4ce17a4c..4c4abed9 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -126,7 +126,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { org2 := "Updated certificate" cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key) - os.WriteFile(certFile, cert2, 0755) // nolint + replaceFile(t, certFile, cert2, 0755) cp2 := x509.NewCertPool() if !cp2.AppendCertsFromPEM(cert2) { @@ -215,7 +215,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { org2 := "Updated client" clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey) - os.WriteFile(caFile, clientCert2, 0755) // nolint + replaceFile(t, caFile, clientCert2, 0755) pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY",