diff --git a/README.md b/README.md index 7dc4f772c..7975c386f 100644 --- a/README.md +++ b/README.md @@ -201,9 +201,11 @@ Once up and running, don't forget to [register one or more users](#creating-jack - [XEP-0030: Service Discovery](https://xmpp.org/extensions/xep-0030.html) *2.5rc3* - [XEP-0049: Private XML Storage](https://xmpp.org/extensions/xep-0049.html) *1.2* - [XEP-0054: vcard-temp](https://xmpp.org/extensions/xep-0054.html) *1.2* +- [XEP-0059: Result Set Management](https://xmpp.org/extensions/xep-0059.html) *1.0* - [XEP-0092: Software Version](https://xmpp.org/extensions/xep-0092.html) *1.1* - [XEP-0114: Jabber Component Protocol](https://xmpp.org/extensions/xep-0114.html) *1.6* -- [XEP-0115: Entity Capabilities](https://xmpp.org/extensions/xep-0115.html) *1.5.2* +- [XEP-0115: Entity Capabilities](https://xmpp.org/extensions/xep-0115.html) *1.5.2* +- [XEP-0122: Data Forms Validation](https://xmpp.org/extensions/xep-0122.html) *1.0.2* - [XEP-0138: Stream Compression](https://xmpp.org/extensions/xep-0138.html) *2.0* - [XEP-0160: Best Practices for Handling Offline Messages](https://xmpp.org/extensions/xep-0160.html) *1.0.1* - [XEP-0190: Best Practice for Closing Idle Streams](https://xmpp.org/extensions/xep-0190.html) *1.1* @@ -214,6 +216,8 @@ Once up and running, don't forget to [register one or more users](#creating-jack - [XEP-0220: Server Dialback](https://xmpp.org/extensions/xep-0220.html) *1.1.1* - [XEP-0237: Roster Versioning](https://xmpp.org/extensions/xep-0237.html) *1.3* - [XEP-0280: Message Carbons](https://xmpp.org/extensions/xep-0280.html) *0.13.3* +- [XEP-0297: Stanza Forwarding](https://xmpp.org/extensions/xep-0297.html) *1.0* +- [XEP-0313: Message Archive Management](https://xmpp.org/extensions/xep-0313.html) *1.0.1* - [XEP-0368: SRV records for XMPP over TLS](https://xmpp.org/extensions/xep-0368.html) *1.1.0* ## Join and Contribute diff --git a/config/example.config.yaml b/config/example.config.yaml index 34e9f18b9..7dbf2c3f9 100644 --- a/config/example.config.yaml +++ b/config/example.config.yaml @@ -24,20 +24,20 @@ # cert_file: "" # privkey_file: "" -#storage: -# type: pgsql -# pgsql: -# host: 127.0.0.1:5432 -# user: jackal -# password: a-secret-key -# database: jackal -# max_open_conns: 16 -# -# cache: -# type: redis -# redis: -# addresses: -# - localhost:6379 +storage: + type: pgsql + pgsql: + host: 127.0.0.1:5432 + user: jackal + password: a-secret-key + database: jackal + max_open_conns: 16 + + cache: + type: redis + redis: + addresses: + - localhost:6379 #cluster: # type: kv @@ -128,6 +128,7 @@ modules: # - ping # XEP-0199: XMPP Ping # - time # XEP-0202: Entity Time # - carbons # XEP-0280: Message Carbons +# - mam # XEP-0313: Message Archive Management # # version: # show_os: true @@ -140,6 +141,10 @@ modules: # interval: 3m # send_pings: true # timeout_action: kill +# +# mam: +# queue_size: 1500 +# components: secret: a-super-secret-key diff --git a/go.mod b/go.mod index 46d649ada..1363d1b55 100644 --- a/go.mod +++ b/go.mod @@ -17,11 +17,12 @@ require ( github.com/google/uuid v1.1.2 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/jackal-xmpp/runqueue/v2 v2.0.0 - github.com/jackal-xmpp/stravaganza v1.2.3 + github.com/jackal-xmpp/stravaganza v1.2.4 github.com/kkyr/fig v0.2.0 github.com/lib/pq v1.8.0 github.com/mattn/go-sqlite3 v1.14.5 // indirect github.com/prometheus/client_golang v1.11.0 + github.com/samber/lo v1.25.0 github.com/spf13/cobra v1.1.3 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.1 @@ -61,6 +62,7 @@ require ( go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.17.0 // indirect + golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/net v0.0.0-20220526153639-5463443f8c37 // indirect golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect golang.org/x/text v0.3.7 // indirect diff --git a/go.sum b/go.sum index 5c5106ade..1a40ef4a4 100644 --- a/go.sum +++ b/go.sum @@ -225,8 +225,8 @@ github.com/iris-contrib/i18n v0.0.0-20171121225848-987a633949d0/go.mod h1:pMCz62 github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= github.com/jackal-xmpp/runqueue/v2 v2.0.0 h1:QfvOfL6zF5yK1LN5TKabpj+VBuELMwtR8Xpkz0CrjoI= github.com/jackal-xmpp/runqueue/v2 v2.0.0/go.mod h1:tXZARVqBMGeV8BTc/qDPg0qXILTUWmER7wlYbN9Xcac= -github.com/jackal-xmpp/stravaganza v1.2.3 h1:fxxyvtkj94CHYfooy7YsFRue7jFtJaMg3BozfWlzSOY= -github.com/jackal-xmpp/stravaganza v1.2.3/go.mod h1:oesgQpMM0I5gnJM80NsEfSspzDDCArQex+oA0/swCWU= +github.com/jackal-xmpp/stravaganza v1.2.4 h1:xz3L2lNEPezOn43az4W4omK1at9tSuR4BDaWOSKo6aE= +github.com/jackal-xmpp/stravaganza v1.2.4/go.mod h1:oesgQpMM0I5gnJM80NsEfSspzDDCArQex+oA0/swCWU= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -306,6 +306,7 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/nats-io/nats.go v1.8.1/go.mod h1:BrFz9vVn0fU3AcH9Vn4Kd7W0NpJ651tD5omQ3M8LwxM= github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7TDb/4= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -363,6 +364,8 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/samber/lo v1.25.0 h1:H8F6cB0RotRdgcRCivTByAQePaYhGMdOTJIj2QFS2I0= +github.com/samber/lo v1.25.0/go.mod h1:2I7tgIv8Q1SG2xEIkRq0F2i2zgxVpnyPOP0d3Gj2r+A= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= @@ -397,6 +400,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= +github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= @@ -450,6 +454,8 @@ golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -642,8 +648,8 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= diff --git a/helm/sql/postgres.up.psql b/helm/sql/postgres.up.psql index 3be5cf8da..39d0ad1c1 100644 --- a/helm/sql/postgres.up.psql +++ b/helm/sql/postgres.up.psql @@ -170,3 +170,25 @@ CREATE TABLE IF NOT EXISTS vcards ( ); SELECT enable_updated_at('vcards'); + +-- archives + +CREATE TABLE IF NOT EXISTS archives ( + serial SERIAL PRIMARY KEY, + archive_id VARCHAR(1023), + id VARCHAR(255) NOT NULL, + "from" TEXT NOT NULL, + from_bare TEXT NOT NULL, + "to" TEXT NOT NULL, + to_bare TEXT NOT NULL, + message BYTEA NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS i_archives_archive_id ON archives(archive_id); +CREATE INDEX IF NOT EXISTS i_archives_id ON archives(id); +CREATE INDEX IF NOT EXISTS i_archives_to ON archives("to"); +CREATE INDEX IF NOT EXISTS i_archives_to_bare ON archives(to_bare); +CREATE INDEX IF NOT EXISTS i_archives_from ON archives("from"); +CREATE INDEX IF NOT EXISTS i_archives_from_bare ON archives(from_bare); +CREATE INDEX IF NOT EXISTS i_archives_created_at ON archives(created_at); diff --git a/helm/values.yaml b/helm/values.yaml index 8f9e21369..d28dc9c05 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -125,6 +125,7 @@ jackal: - ping # XEP-0199: XMPP Ping - time # XEP-0202: Entity Time - carbons # XEP-0280: Message Carbons + - mam # XEP-0313: Message Archive Management version: show_os: true @@ -138,6 +139,9 @@ jackal: send_pings: true timeout_action: kill + mam: + queue_size: 1500 + components: # listeners: # - port: 5275 diff --git a/pkg/admin/pb/users.pb.go b/pkg/admin/pb/users.pb.go index 2dbd15c7f..9ade76daa 100644 --- a/pkg/admin/pb/users.pb.go +++ b/pkg/admin/pb/users.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/admin/v1/users.proto package pb diff --git a/pkg/c2s/in.go b/pkg/c2s/in.go index 11f5d30a6..042ac8bf4 100644 --- a/pkg/c2s/in.go +++ b/pkg/c2s/in.go @@ -404,13 +404,13 @@ func (s *inC2S) connTimeout() { func (s *inC2S) handleElement(ctx context.Context, elem stravaganza.Element) error { // run received element hook - hInf := &hook.C2SStreamInfo{ + hi := &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Element: elem, } - halted, err := s.runHook(ctx, hook.C2SStreamElementReceived, hInf) + halted, err := s.runHook(ctx, hook.C2SStreamElementReceived, hi) if halted { return nil } @@ -421,15 +421,15 @@ func (s *inC2S) handleElement(ctx context.Context, elem stravaganza.Element) err t0 := time.Now() switch s.getState() { case inConnecting: - err = s.handleConnecting(ctx, hInf.Element) + err = s.handleConnecting(ctx, hi.Element) case inConnected: - err = s.handleConnected(ctx, hInf.Element) + err = s.handleConnected(ctx, hi.Element) case inAuthenticating: - err = s.handleAuthenticating(ctx, hInf.Element) + err = s.handleAuthenticating(ctx, hi.Element) case inAuthenticated: - err = s.handleAuthenticated(ctx, hInf.Element) + err = s.handleAuthenticated(ctx, hi.Element) case inBinded: - err = s.handleBinded(ctx, hInf.Element) + err = s.handleBinded(ctx, hi.Element) } reportIncomingRequest( elem.Name(), @@ -553,15 +553,18 @@ func (s *inC2S) processStanza(ctx context.Context, stanza stravaganza.Stanza) er func (s *inC2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error { // run iq received hook - _, err := s.runHook(ctx, hook.C2SStreamIQReceived, &hook.C2SStreamInfo{ + hi := &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Element: iq, - }) + } + _, err := s.runHook(ctx, hook.C2SStreamIQReceived, hi) if err != nil { return err } + iq = hi.Element.(*stravaganza.IQ) + if iq.IsSet() && iq.ChildNamespace("session", sessionNamespace) != nil { if !s.flags.isSessionStarted() { s.flags.setSessionStarted() @@ -576,24 +579,24 @@ func (s *inC2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error { return s.mods.ProcessIQ(ctx, iq) } // run will route iq hook - hInf := &hook.C2SStreamInfo{ + hi = &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Element: iq, } - halted, err := s.runHook(ctx, hook.C2SStreamWillRouteElement, hInf) + halted, err := s.runHook(ctx, hook.C2SStreamWillRouteElement, hi) if halted { return nil } if err != nil { return err } - outIQ, ok := hInf.Element.(*stravaganza.IQ) + iq, ok := hi.Element.(*stravaganza.IQ) if !ok { return nil } - targets, err := s.router.Route(ctx, outIQ) + targets, err := s.router.Route(ctx, iq) switch err { case router.ErrResourceNotFound: return s.sendElement(ctx, stanzaerror.E(stanzaerror.ServiceUnavailable, iq).Element()) @@ -604,8 +607,8 @@ func (s *inC2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error { case router.ErrRemoteServerTimeout: return s.sendElement(ctx, stanzaerror.E(stanzaerror.RemoteServerTimeout, iq).Element()) - case nil: - _, err := s.runHook(ctx, hook.C2SStreamIQRouted, &hook.C2SStreamInfo{ + case nil, router.ErrUserNotAvailable: + _, err = s.runHook(ctx, hook.C2SStreamIQRouted, &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), @@ -619,38 +622,40 @@ func (s *inC2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error { func (s *inC2S) processPresence(ctx context.Context, presence *stravaganza.Presence) error { // run presence received hook - _, err := s.runHook(ctx, hook.C2SStreamPresenceReceived, &hook.C2SStreamInfo{ + hi := &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Element: presence, - }) + } + _, err := s.runHook(ctx, hook.C2SStreamPresenceReceived, hi) if err != nil { return err } + presence = hi.Element.(*stravaganza.Presence) if presence.ToJID().IsFullWithUser() { // run will route presence hook - hInf := &hook.C2SStreamInfo{ + hi = &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Element: presence, } - halted, err := s.runHook(ctx, hook.C2SStreamWillRouteElement, hInf) + halted, err := s.runHook(ctx, hook.C2SStreamWillRouteElement, hi) if halted { return nil } if err != nil { return err } - outPr, ok := hInf.Element.(*stravaganza.Presence) + presence, ok := hi.Element.(*stravaganza.Presence) if !ok { return nil } - targets, err := s.router.Route(ctx, outPr) + targets, err := s.router.Route(ctx, presence) switch err { - case nil: + case nil, router.ErrUserNotAvailable: _, err = s.runHook(ctx, hook.C2SStreamPresenceRouted, &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), @@ -672,37 +677,38 @@ func (s *inC2S) processPresence(ctx context.Context, presence *stravaganza.Prese func (s *inC2S) processMessage(ctx context.Context, message *stravaganza.Message) error { // run message received hook - _, err := s.runHook(ctx, hook.C2SStreamMessageReceived, &hook.C2SStreamInfo{ + hi := &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Element: message, - }) + } + _, err := s.runHook(ctx, hook.C2SStreamMessageReceived, hi) if err != nil { return err } - msg := message + msg := hi.Element.(*stravaganza.Message) sendMsg: - // run will route Message hook - hInf := &hook.C2SStreamInfo{ + // run will route message hook + hi = &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Element: msg, } - halted, err := s.runHook(ctx, hook.C2SStreamWillRouteElement, hInf) + halted, err := s.runHook(ctx, hook.C2SStreamWillRouteElement, hi) if halted { return nil } if err != nil { return err } - outMsg, ok := hInf.Element.(*stravaganza.Message) + msg, ok := hi.Element.(*stravaganza.Message) if !ok { return nil } - targets, err := s.router.Route(ctx, outMsg) + targets, err := s.router.Route(ctx, msg) switch err { case router.ErrResourceNotFound: // treat the stanza as if it were addressed to @@ -721,18 +727,21 @@ sendMsg: case router.ErrRemoteServerTimeout: return s.sendElement(ctx, stanzaerror.E(stanzaerror.RemoteServerTimeout, message).Element()) - case router.ErrUserNotAvailable: - return s.sendElement(ctx, stanzaerror.E(stanzaerror.ServiceUnavailable, message).Element()) - - case nil: - _, err = s.runHook(ctx, hook.C2SStreamMessageRouted, &hook.C2SStreamInfo{ + case nil, router.ErrUserNotAvailable: + halted, hErr := s.runHook(ctx, hook.C2SStreamMessageRouted, &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), Presence: s.Presence(), Targets: targets, Element: msg, }) - return err + if halted { + return nil + } + if errors.Is(err, router.ErrUserNotAvailable) { + return s.sendElement(ctx, stanzaerror.E(stanzaerror.ServiceUnavailable, message).Element()) + } + return hErr default: return err @@ -1108,6 +1117,7 @@ func (s *inC2S) close(ctx context.Context, disconnectErr error) error { halted, err := s.runHook(ctx, hook.C2SStreamDisconnected, &hook.C2SStreamInfo{ ID: s.ID().String(), JID: s.JID(), + Presence: s.Presence(), DisconnectError: disconnectErr, }) if halted { diff --git a/pkg/c2s/pb/resourceinfo.pb.go b/pkg/c2s/pb/resourceinfo.pb.go index 9b7c2685d..ff94ee504 100644 --- a/pkg/c2s/pb/resourceinfo.pb.go +++ b/pkg/c2s/pb/resourceinfo.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/c2s/v1/resourceinfo.proto package pb diff --git a/pkg/c2s/router.go b/pkg/c2s/router.go index 30285008c..a5205aeb9 100644 --- a/pkg/c2s/router.go +++ b/pkg/c2s/router.go @@ -16,6 +16,7 @@ package c2s import ( "context" + "fmt" "sort" kitlog "github.com/go-kit/log" @@ -115,8 +116,12 @@ func (r *c2sRouter) Unregister(stm stream.C2S) error { return nil } -func (r *c2sRouter) LocalStream(username, resource string) stream.C2S { - return r.local.Stream(username, resource) +func (r *c2sRouter) LocalStream(username, resource string) (stream.C2S, error) { + stm := r.local.Stream(username, resource) + if stm == nil { + return nil, fmt.Errorf("c2s: local stream not found: %s/%s", username, resource) + } + return stm, nil } func (r *c2sRouter) Start(ctx context.Context) error { diff --git a/pkg/cluster/pb/cluster.pb.go b/pkg/cluster/pb/cluster.pb.go index d5a18a03e..50e337584 100644 --- a/pkg/cluster/pb/cluster.pb.go +++ b/pkg/cluster/pb/cluster.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/cluster/v1/cluster.proto package pb diff --git a/pkg/hook/c2s.go b/pkg/hook/c2s.go index 744361b1f..16d763b40 100644 --- a/pkg/hook/c2s.go +++ b/pkg/hook/c2s.go @@ -47,16 +47,16 @@ const ( // C2SStreamWillRouteElement hook runs when an XMPP element is about to be routed over a C2S stream. C2SStreamWillRouteElement = "c2s.stream.will_route_element" - // C2SStreamIQRouted hook runs when an iq stanza is successfully routed to one ore more C2S streams. + // C2SStreamIQRouted hook runs when an iq stanza is successfully routed to zero or more C2S streams. C2SStreamIQRouted = "c2s.stream.iq_routed" - // C2SStreamPresenceRouted hook runs when a presence stanza is successfully routed to one ore more C2S streams. + // C2SStreamPresenceRouted hook runs when a presence stanza is successfully routed to zero or more C2S streams. C2SStreamPresenceRouted = "c2s.stream.presence_routed" - // C2SStreamMessageRouted hook runs when a message stanza is successfully routed to one ore more C2S streams. + // C2SStreamMessageRouted hook runs when a message stanza is successfully routed to zero or more C2S streams. C2SStreamMessageRouted = "c2s.stream.message_routed" - // C2SStreamElementSent hook runs when a XMPP element is sent over a C2S stream. + // C2SStreamElementSent hook runs when an XMPP element is sent over a C2S stream. C2SStreamElementSent = "c2s.stream.element_sent" ) diff --git a/pkg/hook/hooks.go b/pkg/hook/hooks.go index f61542dee..e7ea5600d 100644 --- a/pkg/hook/hooks.go +++ b/pkg/hook/hooks.go @@ -28,13 +28,19 @@ type Priority int32 const ( // LowestPriority defines lowest hook execution priority. - LowestPriority = Priority(math.MinInt32 + 100) + LowestPriority = Priority(math.MinInt32) + + // LowPriority defines low hook execution priority. + LowPriority = Priority(math.MinInt32 + 1000) // DefaultPriority defines default hook execution priority. DefaultPriority = Priority(0) + // HighPriority defines high hook execution priority. + HighPriority = Priority(math.MaxInt32 - 1000) + // HighestPriority defines highest hook execution priority. - HighestPriority = Priority(math.MaxInt32 - 100) + HighestPriority = Priority(math.MaxInt32) ) // Handler defines a generic hook handler function. diff --git a/pkg/hook/s2s.go b/pkg/hook/s2s.go index 0eab5ed2f..6b13c4d6a 100644 --- a/pkg/hook/s2s.go +++ b/pkg/hook/s2s.go @@ -16,6 +16,7 @@ package hook import ( "github.com/jackal-xmpp/stravaganza" + "github.com/jackal-xmpp/stravaganza/jid" ) const ( @@ -25,7 +26,7 @@ const ( // S2SOutStreamDisconnected hook runs when an outgoing S2S connection is unregistered. S2SOutStreamDisconnected = "s2s.out.stream.disconnected" - // S2SOutStreamElementSent hook runs whenever a XMPP element is sent over an outgoing S2S stream. + // S2SOutStreamElementSent hook runs whenever an XMPP element is sent over an outgoing S2S stream. S2SOutStreamElementSent = "s2s.out.stream.element_sent" // S2SInStreamRegistered hook runs when an incoming S2S connection is registered. @@ -34,7 +35,7 @@ const ( // S2SInStreamUnregistered hook runs when an incoming S2S connection is unregistered. S2SInStreamUnregistered = "s2s.in.stream.unregistered" - // S2SInStreamElementReceived hook runs when a XMPP element is received over an incoming S2S stream. + // S2SInStreamElementReceived hook runs when an XMPP element is received over an incoming S2S stream. S2SInStreamElementReceived = "s2s.in.stream.stanza_received" // S2SInStreamIQReceived hook runs when an iq stanza is received over an incoming S2S stream. @@ -49,13 +50,13 @@ const ( // S2SInStreamWillRouteElement hook runs when an XMPP element is about to be routed on an incoming S2S stream. S2SInStreamWillRouteElement = "s2s.in.stream.will_route_element" - // S2SInStreamIQRouted hook runs when an iq stanza is successfully routed to one ore more S2S streams. + // S2SInStreamIQRouted hook runs when an iq stanza is successfully routed to zero or more C2S streams. S2SInStreamIQRouted = "s2s.in.stream.iq_routed" - // S2SInStreamPresenceRouted hook runs when a presence stanza is successfully routed to one ore more S2S streams. + // S2SInStreamPresenceRouted hook runs when a presence stanza is successfully routed to zero or more C2S streams. S2SInStreamPresenceRouted = "s2s.in.stream.presence_routed" - // S2SInStreamMessageRouted hook runs when a message stanza is successfully routed to one ore more S2S streams. + // S2SInStreamMessageRouted hook runs when a message stanza is successfully routed to zero or more C2S streams. S2SInStreamMessageRouted = "s2s.in.stream.message_routed" ) @@ -70,6 +71,9 @@ type S2SStreamInfo struct { // Target is the S2S target domain. Target string + // Targets contains all JIDs to which event stanza was routed. + Targets []jid.JID + // Element is the event associated XMPP element. Element stravaganza.Element } diff --git a/pkg/jackal/config.go b/pkg/jackal/config.go index ae4c95173..95452d224 100644 --- a/pkg/jackal/config.go +++ b/pkg/jackal/config.go @@ -17,6 +17,8 @@ package jackal import ( "path/filepath" + "github.com/ortuman/jackal/pkg/module/xep0313" + "github.com/kkyr/fig" adminserver "github.com/ortuman/jackal/pkg/admin/server" "github.com/ortuman/jackal/pkg/auth/pepper" @@ -95,6 +97,9 @@ type ModulesConfig struct { // XEP-0199: XMPP Ping Ping xep0199.Config `fig:"ping"` + + // XEP-0313: Message Archive Management + Mam xep0313.Config `fig:"mam"` } // Config defines jackal application configuration. diff --git a/pkg/jackal/jackal.go b/pkg/jackal/jackal.go index 6ccdac89a..1eeb6dfff 100644 --- a/pkg/jackal/jackal.go +++ b/pkg/jackal/jackal.go @@ -26,8 +26,6 @@ import ( "syscall" "time" - streamqueue "github.com/ortuman/jackal/pkg/module/xep0198/queue" - kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -47,6 +45,7 @@ import ( "github.com/ortuman/jackal/pkg/host" "github.com/ortuman/jackal/pkg/log" "github.com/ortuman/jackal/pkg/module" + streamqueue "github.com/ortuman/jackal/pkg/module/xep0198/queue" "github.com/ortuman/jackal/pkg/router" "github.com/ortuman/jackal/pkg/s2s" "github.com/ortuman/jackal/pkg/shaper" diff --git a/pkg/jackal/modules.go b/pkg/jackal/modules.go index d03b87fff..d6b11bcda 100644 --- a/pkg/jackal/modules.go +++ b/pkg/jackal/modules.go @@ -30,6 +30,7 @@ import ( "github.com/ortuman/jackal/pkg/module/xep0199" "github.com/ortuman/jackal/pkg/module/xep0202" "github.com/ortuman/jackal/pkg/module/xep0280" + "github.com/ortuman/jackal/pkg/module/xep0313" ) var defaultModules = []string{ @@ -45,6 +46,7 @@ var defaultModules = []string{ xep0198.ModuleName, xep0199.ModuleName, xep0280.ModuleName, + xep0313.ModuleName, } var modFns = map[string]func(a *Jackal, cfg *ModulesConfig) module.Module{ @@ -56,7 +58,7 @@ var modFns = map[string]func(a *Jackal, cfg *ModulesConfig) module.Module{ // Offline // (https://xmpp.org/extensions/xep-0160.html) offline.ModuleName: func(j *Jackal, cfg *ModulesConfig) module.Module { - return offline.New(cfg.Offline, j.router, j.hosts, j.resMng, j.rep, j.hk, j.logger) + return offline.New(cfg.Offline, j.router, j.hosts, j.rep, j.hk, j.logger) }, // XEP-0012: Last Activity // (https://xmpp.org/extensions/xep-0012.html) @@ -114,4 +116,9 @@ var modFns = map[string]func(a *Jackal, cfg *ModulesConfig) module.Module{ xep0280.ModuleName: func(j *Jackal, _ *ModulesConfig) module.Module { return xep0280.New(j.router, j.hosts, j.resMng, j.hk, j.logger) }, + // XEP-0313: Message Archive Management + // (https://xmpp.org/extensions/xep-0313.html) + xep0313.ModuleName: func(j *Jackal, cfg *ModulesConfig) module.Module { + return xep0313.New(cfg.Mam, j.router, j.hosts, j.rep, j.hk, j.logger) + }, } diff --git a/pkg/model/archive/archive.pb.go b/pkg/model/archive/archive.pb.go new file mode 100644 index 000000000..a9249431c --- /dev/null +++ b/pkg/model/archive/archive.pb.go @@ -0,0 +1,514 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.21.5 +// source: proto/model/v1/archive.proto + +package archivemodel + +import ( + stravaganza "github.com/jackal-xmpp/stravaganza" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Message represents an archive message entity. +type Message struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // archived_id is the message archive identifier. + ArchiveId string `protobuf:"bytes,1,opt,name=archive_id,json=archiveId,proto3" json:"archive_id,omitempty"` + // id is the message archive unique identifier. + Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` + // from_jid is the message from jid value. + FromJid string `protobuf:"bytes,3,opt,name=from_jid,json=fromJid,proto3" json:"from_jid,omitempty"` + // to_jid is the message from jid value. + ToJid string `protobuf:"bytes,4,opt,name=to_jid,json=toJid,proto3" json:"to_jid,omitempty"` + // message is the archived message. + Message *stravaganza.PBElement `protobuf:"bytes,5,opt,name=message,proto3" json:"message,omitempty"` + // stamp is the timestamp in which the message was archived. + Stamp *timestamppb.Timestamp `protobuf:"bytes,9,opt,name=stamp,proto3" json:"stamp,omitempty"` +} + +func (x *Message) Reset() { + *x = Message{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_model_v1_archive_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_proto_model_v1_archive_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_proto_model_v1_archive_proto_rawDescGZIP(), []int{0} +} + +func (x *Message) GetArchiveId() string { + if x != nil { + return x.ArchiveId + } + return "" +} + +func (x *Message) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *Message) GetFromJid() string { + if x != nil { + return x.FromJid + } + return "" +} + +func (x *Message) GetToJid() string { + if x != nil { + return x.ToJid + } + return "" +} + +func (x *Message) GetMessage() *stravaganza.PBElement { + if x != nil { + return x.Message + } + return nil +} + +func (x *Message) GetStamp() *timestamppb.Timestamp { + if x != nil { + return x.Stamp + } + return nil +} + +// Messages represents a set of archive messages. +type Messages struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ArchiveMessages []*Message `protobuf:"bytes,1,rep,name=archive_messages,json=archiveMessages,proto3" json:"archive_messages,omitempty"` +} + +func (x *Messages) Reset() { + *x = Messages{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_model_v1_archive_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Messages) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Messages) ProtoMessage() {} + +func (x *Messages) ProtoReflect() protoreflect.Message { + mi := &file_proto_model_v1_archive_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Messages.ProtoReflect.Descriptor instead. +func (*Messages) Descriptor() ([]byte, []int) { + return file_proto_model_v1_archive_proto_rawDescGZIP(), []int{1} +} + +func (x *Messages) GetArchiveMessages() []*Message { + if x != nil { + return x.ArchiveMessages + } + return nil +} + +// Metadata represents an archive metadata information. +type Metadata struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // start_timestamp is the identifier of the first archive message. + StartId string `protobuf:"bytes,1,opt,name=start_id,json=startId,proto3" json:"start_id,omitempty"` + // start_timestamp is the timestamp value of the first archive message. + StartTimestamp string `protobuf:"bytes,2,opt,name=start_timestamp,json=startTimestamp,proto3" json:"start_timestamp,omitempty"` + // end_id is the identifier of the last archive message. + EndId string `protobuf:"bytes,3,opt,name=end_id,json=endId,proto3" json:"end_id,omitempty"` + // end_timestamp is the timestamp value of the last archive message. + EndTimestamp string `protobuf:"bytes,4,opt,name=end_timestamp,json=endTimestamp,proto3" json:"end_timestamp,omitempty"` +} + +func (x *Metadata) Reset() { + *x = Metadata{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_model_v1_archive_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Metadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Metadata) ProtoMessage() {} + +func (x *Metadata) ProtoReflect() protoreflect.Message { + mi := &file_proto_model_v1_archive_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Metadata.ProtoReflect.Descriptor instead. +func (*Metadata) Descriptor() ([]byte, []int) { + return file_proto_model_v1_archive_proto_rawDescGZIP(), []int{2} +} + +func (x *Metadata) GetStartId() string { + if x != nil { + return x.StartId + } + return "" +} + +func (x *Metadata) GetStartTimestamp() string { + if x != nil { + return x.StartTimestamp + } + return "" +} + +func (x *Metadata) GetEndId() string { + if x != nil { + return x.EndId + } + return "" +} + +func (x *Metadata) GetEndTimestamp() string { + if x != nil { + return x.EndTimestamp + } + return "" +} + +// Filters define a set of filters to be applied when fetching archive messages. +type Filters struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // start is used to filter out messages before a certain date/time. + Start *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=start,proto3" json:"start,omitempty"` + // end is used to filter out messages after a certain date/time. + End *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=end,proto3" json:"end,omitempty"` + // with contains a JID against which to match messages. + With string `protobuf:"bytes,3,opt,name=with,proto3" json:"with,omitempty"` + // before_id is the id of the newest message user wants to fetch. + BeforeId string `protobuf:"bytes,4,opt,name=before_id,json=beforeId,proto3" json:"before_id,omitempty"` + // after_id is the id of the oldest message user wants to fetch. + AfterId string `protobuf:"bytes,5,opt,name=after_id,json=afterId,proto3" json:"after_id,omitempty"` + // ids contains one or more ids the user wants to fetch. + Ids []string `protobuf:"bytes,6,rep,name=ids,proto3" json:"ids,omitempty"` +} + +func (x *Filters) Reset() { + *x = Filters{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_model_v1_archive_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Filters) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Filters) ProtoMessage() {} + +func (x *Filters) ProtoReflect() protoreflect.Message { + mi := &file_proto_model_v1_archive_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Filters.ProtoReflect.Descriptor instead. +func (*Filters) Descriptor() ([]byte, []int) { + return file_proto_model_v1_archive_proto_rawDescGZIP(), []int{3} +} + +func (x *Filters) GetStart() *timestamppb.Timestamp { + if x != nil { + return x.Start + } + return nil +} + +func (x *Filters) GetEnd() *timestamppb.Timestamp { + if x != nil { + return x.End + } + return nil +} + +func (x *Filters) GetWith() string { + if x != nil { + return x.With + } + return "" +} + +func (x *Filters) GetBeforeId() string { + if x != nil { + return x.BeforeId + } + return "" +} + +func (x *Filters) GetAfterId() string { + if x != nil { + return x.AfterId + } + return "" +} + +func (x *Filters) GetIds() []string { + if x != nil { + return x.Ids + } + return nil +} + +var File_proto_model_v1_archive_proto protoreflect.FileDescriptor + +var file_proto_model_v1_archive_proto_rawDesc = []byte{ + 0x0a, 0x1c, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x2f, 0x76, 0x31, + 0x2f, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x10, + 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x2e, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x2e, 0x76, 0x31, + 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x1a, 0x34, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6a, 0x61, + 0x63, 0x6b, 0x61, 0x6c, 0x2d, 0x78, 0x6d, 0x70, 0x70, 0x2f, 0x73, 0x74, 0x72, 0x61, 0x76, 0x61, + 0x67, 0x61, 0x6e, 0x7a, 0x61, 0x2f, 0x73, 0x74, 0x72, 0x61, 0x76, 0x61, 0x67, 0x61, 0x6e, 0x7a, + 0x61, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xce, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, + 0x49, 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, + 0x69, 0x64, 0x12, 0x19, 0x0a, 0x08, 0x66, 0x72, 0x6f, 0x6d, 0x5f, 0x6a, 0x69, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x66, 0x72, 0x6f, 0x6d, 0x4a, 0x69, 0x64, 0x12, 0x15, 0x0a, + 0x06, 0x74, 0x6f, 0x5f, 0x6a, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, + 0x6f, 0x4a, 0x69, 0x64, 0x12, 0x30, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x73, 0x74, 0x72, 0x61, 0x76, 0x61, 0x67, 0x61, + 0x6e, 0x7a, 0x61, 0x2e, 0x50, 0x42, 0x45, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x07, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x30, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, + 0x70, 0x52, 0x05, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x22, 0x50, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x73, 0x12, 0x44, 0x0a, 0x10, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x5f, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x19, + 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x2e, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x2e, 0x76, + 0x31, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x0f, 0x61, 0x72, 0x63, 0x68, 0x69, + 0x76, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x22, 0x8a, 0x01, 0x0a, 0x08, 0x4d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x19, 0x0a, 0x08, 0x73, 0x74, 0x61, 0x72, 0x74, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x74, 0x61, 0x72, 0x74, + 0x49, 0x64, 0x12, 0x27, 0x0a, 0x0f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, + 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x65, + 0x6e, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6e, 0x64, + 0x49, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x6e, 0x64, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, + 0x61, 0x6d, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x65, 0x6e, 0x64, 0x54, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x22, 0xc7, 0x01, 0x0a, 0x07, 0x46, 0x69, 0x6c, 0x74, + 0x65, 0x72, 0x73, 0x12, 0x30, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x05, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x03, + 0x65, 0x6e, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x77, 0x69, 0x74, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x77, 0x69, 0x74, 0x68, 0x12, 0x1b, 0x0a, 0x09, 0x62, 0x65, 0x66, 0x6f, 0x72, + 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x62, 0x65, 0x66, 0x6f, + 0x72, 0x65, 0x49, 0x64, 0x12, 0x19, 0x0a, 0x08, 0x61, 0x66, 0x74, 0x65, 0x72, 0x5f, 0x69, 0x64, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x66, 0x74, 0x65, 0x72, 0x49, 0x64, 0x12, + 0x10, 0x0a, 0x03, 0x69, 0x64, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x69, 0x64, + 0x73, 0x42, 0x21, 0x5a, 0x1f, 0x70, 0x6b, 0x67, 0x2f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x2f, 0x61, + 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x2f, 0x3b, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x6d, + 0x6f, 0x64, 0x65, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_proto_model_v1_archive_proto_rawDescOnce sync.Once + file_proto_model_v1_archive_proto_rawDescData = file_proto_model_v1_archive_proto_rawDesc +) + +func file_proto_model_v1_archive_proto_rawDescGZIP() []byte { + file_proto_model_v1_archive_proto_rawDescOnce.Do(func() { + file_proto_model_v1_archive_proto_rawDescData = protoimpl.X.CompressGZIP(file_proto_model_v1_archive_proto_rawDescData) + }) + return file_proto_model_v1_archive_proto_rawDescData +} + +var file_proto_model_v1_archive_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_proto_model_v1_archive_proto_goTypes = []interface{}{ + (*Message)(nil), // 0: model.archive.v1.Message + (*Messages)(nil), // 1: model.archive.v1.Messages + (*Metadata)(nil), // 2: model.archive.v1.Metadata + (*Filters)(nil), // 3: model.archive.v1.Filters + (*stravaganza.PBElement)(nil), // 4: stravaganza.PBElement + (*timestamppb.Timestamp)(nil), // 5: google.protobuf.Timestamp +} +var file_proto_model_v1_archive_proto_depIdxs = []int32{ + 4, // 0: model.archive.v1.Message.message:type_name -> stravaganza.PBElement + 5, // 1: model.archive.v1.Message.stamp:type_name -> google.protobuf.Timestamp + 0, // 2: model.archive.v1.Messages.archive_messages:type_name -> model.archive.v1.Message + 5, // 3: model.archive.v1.Filters.start:type_name -> google.protobuf.Timestamp + 5, // 4: model.archive.v1.Filters.end:type_name -> google.protobuf.Timestamp + 5, // [5:5] is the sub-list for method output_type + 5, // [5:5] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name +} + +func init() { file_proto_model_v1_archive_proto_init() } +func file_proto_model_v1_archive_proto_init() { + if File_proto_model_v1_archive_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_proto_model_v1_archive_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Message); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_model_v1_archive_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Messages); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_model_v1_archive_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Metadata); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_model_v1_archive_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Filters); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_proto_model_v1_archive_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_proto_model_v1_archive_proto_goTypes, + DependencyIndexes: file_proto_model_v1_archive_proto_depIdxs, + MessageInfos: file_proto_model_v1_archive_proto_msgTypes, + }.Build() + File_proto_model_v1_archive_proto = out.File + file_proto_model_v1_archive_proto_rawDesc = nil + file_proto_model_v1_archive_proto_goTypes = nil + file_proto_model_v1_archive_proto_depIdxs = nil +} diff --git a/pkg/model/archive/codec.go b/pkg/model/archive/codec.go new file mode 100644 index 000000000..e3d126426 --- /dev/null +++ b/pkg/model/archive/codec.go @@ -0,0 +1,37 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package archivemodel + +import "google.golang.org/protobuf/proto" + +// MarshalBinary satisfies encoding.BinaryMarshaler interface. +func (x *Message) MarshalBinary() (data []byte, err error) { + return proto.Marshal(x) +} + +// UnmarshalBinary satisfies encoding.BinaryUnmarshaler interface. +func (x *Message) UnmarshalBinary(data []byte) error { + return proto.Unmarshal(data, x) +} + +// MarshalBinary satisfies encoding.BinaryMarshaler interface. +func (x *Messages) MarshalBinary() (data []byte, err error) { + return proto.Marshal(x) +} + +// UnmarshalBinary satisfies encoding.BinaryUnmarshaler interface. +func (x *Messages) UnmarshalBinary(data []byte) error { + return proto.Unmarshal(data, x) +} diff --git a/pkg/model/blocklist/blocklist.pb.go b/pkg/model/blocklist/blocklist.pb.go index aa55dec13..06b254ddd 100644 --- a/pkg/model/blocklist/blocklist.pb.go +++ b/pkg/model/blocklist/blocklist.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/model/v1/blocklist.proto package blocklistmodel diff --git a/pkg/model/caps/caps.pb.go b/pkg/model/caps/caps.pb.go index 30da67176..6a7ea0bb1 100644 --- a/pkg/model/caps/caps.pb.go +++ b/pkg/model/caps/caps.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/model/v1/caps.proto package capsmodel diff --git a/pkg/model/last/last.pb.go b/pkg/model/last/last.pb.go index e21a61af8..ebe05bc21 100644 --- a/pkg/model/last/last.pb.go +++ b/pkg/model/last/last.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/model/v1/last.proto package lastmodel diff --git a/pkg/model/roster/roster.pb.go b/pkg/model/roster/roster.pb.go index dbd15aa9b..6aa120fed 100644 --- a/pkg/model/roster/roster.pb.go +++ b/pkg/model/roster/roster.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/model/v1/roster.proto package rostermodel diff --git a/pkg/model/user/user.pb.go b/pkg/model/user/user.pb.go index dfb888413..ba960ae07 100644 --- a/pkg/model/user/user.pb.go +++ b/pkg/model/user/user.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.19.4 +// protoc v3.21.5 // source: proto/model/v1/user.proto package usermodel diff --git a/pkg/module/offline/interface.go b/pkg/module/offline/interface.go index 1b8962dae..388731e3a 100644 --- a/pkg/module/offline/interface.go +++ b/pkg/module/offline/interface.go @@ -17,6 +17,7 @@ package offline import ( "github.com/ortuman/jackal/pkg/cluster/resourcemanager" "github.com/ortuman/jackal/pkg/router" + "github.com/ortuman/jackal/pkg/router/stream" "github.com/ortuman/jackal/pkg/storage/repository" ) @@ -39,3 +40,8 @@ type hosts interface { type resourceManager interface { resourcemanager.Manager } + +//go:generate moq -out stream.mock_test.go . c2sStream +type c2sStream interface { + stream.C2S +} diff --git a/pkg/module/offline/offline.go b/pkg/module/offline/offline.go index 49f018468..588896299 100644 --- a/pkg/module/offline/offline.go +++ b/pkg/module/offline/offline.go @@ -23,10 +23,12 @@ import ( "github.com/go-kit/log/level" "github.com/jackal-xmpp/stravaganza" stanzaerror "github.com/jackal-xmpp/stravaganza/errors/stanza" - "github.com/ortuman/jackal/pkg/cluster/resourcemanager" + "github.com/jackal-xmpp/stravaganza/jid" "github.com/ortuman/jackal/pkg/hook" "github.com/ortuman/jackal/pkg/host" + "github.com/ortuman/jackal/pkg/module/xep0313" "github.com/ortuman/jackal/pkg/router" + "github.com/ortuman/jackal/pkg/router/stream" "github.com/ortuman/jackal/pkg/storage/repository" xmpputil "github.com/ortuman/jackal/pkg/util/xmpp" ) @@ -51,7 +53,6 @@ type Offline struct { cfg Config hosts hosts router router.Router - resMng resourcemanager.Manager rep repository.Repository hk *hook.Hooks logger kitlog.Logger @@ -62,7 +63,6 @@ func New( cfg Config, router router.Router, hosts *host.Hosts, - resMng resourcemanager.Manager, rep repository.Repository, hk *hook.Hooks, logger kitlog.Logger, @@ -71,7 +71,6 @@ func New( cfg: cfg, router: router, hosts: hosts, - resMng: resMng, rep: rep, hk: hk, logger: kitlog.With(logger, "module", ModuleName), @@ -96,8 +95,8 @@ func (m *Offline) AccountFeatures(_ context.Context) ([]string, error) { return // Start starts offline module. func (m *Offline) Start(_ context.Context) error { - m.hk.AddHook(hook.C2SStreamWillRouteElement, m.onWillRouteElement, hook.LowestPriority) - m.hk.AddHook(hook.S2SInStreamWillRouteElement, m.onWillRouteElement, hook.LowestPriority) + m.hk.AddHook(hook.C2SStreamMessageRouted, m.onMessageRouted, hook.LowestPriority) + m.hk.AddHook(hook.S2SInStreamMessageRouted, m.onMessageRouted, hook.LowestPriority) m.hk.AddHook(hook.C2SStreamPresenceReceived, m.onC2SPresenceRecv, hook.DefaultPriority) m.hk.AddHook(hook.UserDeleted, m.onUserDeleted, hook.DefaultPriority) @@ -108,8 +107,8 @@ func (m *Offline) Start(_ context.Context) error { // Stop stops offline module. func (m *Offline) Stop(_ context.Context) error { - m.hk.RemoveHook(hook.C2SStreamWillRouteElement, m.onWillRouteElement) - m.hk.RemoveHook(hook.S2SInStreamWillRouteElement, m.onWillRouteElement) + m.hk.RemoveHook(hook.C2SStreamMessageRouted, m.onMessageRouted) + m.hk.RemoveHook(hook.S2SInStreamMessageRouted, m.onMessageRouted) m.hk.RemoveHook(hook.C2SStreamPresenceReceived, m.onC2SPresenceRecv) m.hk.RemoveHook(hook.UserDeleted, m.onUserDeleted) @@ -118,15 +117,23 @@ func (m *Offline) Stop(_ context.Context) error { return nil } -func (m *Offline) onWillRouteElement(execCtx *hook.ExecutionContext) error { +func (m *Offline) onMessageRouted(execCtx *hook.ExecutionContext) error { var elem stravaganza.Element + var targets []jid.JID switch inf := execCtx.Info.(type) { case *hook.C2SStreamInfo: + targets = inf.Targets elem = inf.Element case *hook.S2SStreamInfo: + targets = inf.Targets elem = inf.Element } + // message was successufully routed to one of the available resources + if len(targets) > 0 { + return nil + } + msg, ok := elem.(*stravaganza.Message) if !ok || !isMessageArchievable(msg) { return nil @@ -135,17 +142,15 @@ func (m *Offline) onWillRouteElement(execCtx *hook.ExecutionContext) error { if !m.hosts.IsLocalHost(toJID.Domain()) { return nil } - rss, err := m.resMng.GetResources(execCtx.Context, toJID.Node()) - if err != nil { - return err - } - if len(rss) > 0 { - return nil - } return m.archiveMessage(execCtx.Context, msg) } func (m *Offline) onC2SPresenceRecv(execCtx *hook.ExecutionContext) error { + stm := execCtx.Sender.(stream.C2S) + if xep0313.IsArchiveRequested(stm.Info()) { + // user has already queried the MAM archive. + return nil + } inf := execCtx.Info.(*hook.C2SStreamInfo) pr := inf.Element.(*stravaganza.Presence) @@ -156,7 +161,7 @@ func (m *Offline) onC2SPresenceRecv(execCtx *hook.ExecutionContext) error { if !pr.IsAvailable() || pr.Priority() < 0 { return nil } - return m.deliverOfflineMessages(execCtx.Context, toJID.Node()) + return m.deliverOfflineMessages(execCtx.Context, stm) } func (m *Offline) onUserDeleted(execCtx *hook.ExecutionContext) error { @@ -168,18 +173,20 @@ func (m *Offline) onUserDeleted(execCtx *hook.ExecutionContext) error { if err := m.rep.Lock(ctx, lockID); err != nil { return err } - defer func() { _ = m.rep.Unlock(ctx, lockID) }() + defer m.releaseLock(ctx, lockID) return m.rep.DeleteOfflineMessages(ctx, inf.Username) } -func (m *Offline) deliverOfflineMessages(ctx context.Context, username string) error { +func (m *Offline) deliverOfflineMessages(ctx context.Context, stm stream.C2S) error { + username := stm.Username() + lockID := offlineQueueLockID(username) if err := m.rep.Lock(ctx, lockID); err != nil { return err } - defer func() { _ = m.rep.Unlock(ctx, lockID) }() + defer m.releaseLock(ctx, lockID) ms, err := m.rep.FetchOfflineMessages(ctx, username) if err != nil { @@ -194,7 +201,7 @@ func (m *Offline) deliverOfflineMessages(ctx context.Context, username string) e } // route offline messages for _, msg := range ms { - _, _ = m.router.Route(ctx, msg) + stm.SendElement(msg) } level.Info(m.logger).Log("msg", "delivered offline messages", "queue_size", len(ms), "username", username) @@ -210,7 +217,7 @@ func (m *Offline) archiveMessage(ctx context.Context, msg *stravaganza.Message) if err := m.rep.Lock(ctx, lockID); err != nil { return err } - defer func() { _ = m.rep.Unlock(ctx, lockID) }() + defer m.releaseLock(ctx, lockID) qSize, err := m.rep.CountOfflineMessages(ctx, username) if err != nil { @@ -243,6 +250,12 @@ func (m *Offline) archiveMessage(ctx context.Context, msg *stravaganza.Message) return hook.ErrStopped // already handled } +func (m *Offline) releaseLock(ctx context.Context, lockID string) { + if err := m.rep.Unlock(ctx, lockID); err != nil { + level.Warn(m.logger).Log("msg", "failed to release lock", "err", err) + } +} + func isMessageArchievable(msg *stravaganza.Message) bool { if msg.ChildNamespace("no-store", hintsNamespace) != nil { return false diff --git a/pkg/module/offline/offline_test.go b/pkg/module/offline/offline_test.go index b59985572..b71e80596 100644 --- a/pkg/module/offline/offline_test.go +++ b/pkg/module/offline/offline_test.go @@ -43,15 +43,10 @@ func TestOffline_ArchiveOfflineMessage(t *testing.T) { hostsMock := &hostsMock{} hostsMock.IsLocalHostFunc = func(h string) bool { return h == "jackal.im" } - resManagerMock := &resourceManagerMock{} - resManagerMock.GetResourcesFunc = func(ctx context.Context, username string) ([]c2smodel.ResourceDesc, error) { - return nil, nil - } hk := hook.NewHooks() m := &Offline{ cfg: Config{QueueSize: 100}, hosts: hostsMock, - resMng: resManagerMock, rep: repMock, hk: hk, logger: kitlog.NewNopLogger(), @@ -70,7 +65,7 @@ func TestOffline_ArchiveOfflineMessage(t *testing.T) { _ = m.Start(context.Background()) defer func() { _ = m.Stop(context.Background()) }() - _, _ = hk.Run(hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ + _, _ = hk.Run(hook.C2SStreamMessageRouted, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: msg, }, @@ -114,7 +109,6 @@ func TestOffline_ArchiveOfflineMessageQueueFull(t *testing.T) { cfg: Config{QueueSize: 100}, router: routerMock, hosts: hostsMock, - resMng: resManagerMock, rep: repMock, hk: hk, logger: kitlog.NewNopLogger(), @@ -133,7 +127,7 @@ func TestOffline_ArchiveOfflineMessageQueueFull(t *testing.T) { _ = m.Start(context.Background()) defer func() { _ = m.Stop(context.Background()) }() - halted, err := hk.Run(hook.C2SStreamWillRouteElement, &hook.ExecutionContext{ + halted, err := hk.Run(hook.C2SStreamMessageRouted, &hook.ExecutionContext{ Info: &hook.C2SStreamInfo{ Element: msg, }, @@ -154,11 +148,6 @@ func TestOffline_DeliverOfflineMessages(t *testing.T) { // given routerMock := &routerMock{} - output := bytes.NewBuffer(nil) - routerMock.RouteFunc = func(ctx context.Context, stanza stravaganza.Stanza) ([]jid.JID, error) { - _ = stanza.ToXML(output, true) - return nil, nil - } hostsMock := &hostsMock{} hostsMock.IsLocalHostFunc = func(h string) bool { return h == "jackal.im" } @@ -186,6 +175,22 @@ func TestOffline_DeliverOfflineMessages(t *testing.T) { return nil } + stmMock := &c2sStreamMock{} + stmMock.UsernameFunc = func() string { + return "ortuman" + } + stmMock.InfoFunc = func() c2smodel.Info { + return c2smodel.NewInfoMap() + } + + output := bytes.NewBuffer(nil) + stmMock.SendElementFunc = func(elem stravaganza.Element) <-chan error { + _ = elem.ToXML(output, true) + ch := make(chan error) + close(ch) + return ch + } + hk := hook.NewHooks() m := &Offline{ cfg: Config{QueueSize: 100}, @@ -208,6 +213,7 @@ func TestOffline_DeliverOfflineMessages(t *testing.T) { Info: &hook.C2SStreamInfo{ Element: pr, }, + Sender: stmMock, Context: context.Background(), }) diff --git a/pkg/module/roster/roster.go b/pkg/module/roster/roster.go index 71f09229a..92b699be3 100644 --- a/pkg/module/roster/roster.go +++ b/pkg/module/roster/roster.go @@ -31,7 +31,6 @@ import ( "github.com/ortuman/jackal/pkg/host" rostermodel "github.com/ortuman/jackal/pkg/model/roster" "github.com/ortuman/jackal/pkg/router" - "github.com/ortuman/jackal/pkg/router/stream" "github.com/ortuman/jackal/pkg/storage/repository" xmpputil "github.com/ortuman/jackal/pkg/util/xmpp" ) @@ -204,7 +203,7 @@ func (r *Roster) sendRoster(ctx context.Context, iq *stravaganza.IQ) error { if err != nil { return err } - stm, err := r.getStream(usrJID.Node(), usrJID.Resource()) + stm, err := r.router.C2S().LocalStream(usrJID.Node(), usrJID.Resource()) if err != nil { return err } @@ -234,7 +233,7 @@ func (r *Roster) sendRoster(ctx context.Context, iq *stravaganza.IQ) error { if err != nil { return err } - stm, err := r.getStream(usrJID.Node(), usrJID.Resource()) + stm, err := r.router.C2S().LocalStream(usrJID.Node(), usrJID.Resource()) if err != nil { return err } @@ -553,7 +552,7 @@ func (r *Roster) processAvailability(ctx context.Context, presence *stravaganza. } isAvailable := presence.IsAvailable() if isAvailable { - stm, err := r.getStream(fromJID.Node(), fromJID.Resource()) + stm, err := r.router.C2S().LocalStream(fromJID.Node(), fromJID.Resource()) if err != nil { return err } @@ -831,14 +830,6 @@ func (r *Roster) routePresencesFrom(ctx context.Context, username string, toJID return nil } -func (r *Roster) getStream(username, resource string) (stream.C2S, error) { - stm := r.router.C2S().LocalStream(username, resource) - if stm == nil { - return nil, errStreamNotFound(username, resource) - } - return stm, nil -} - func (r *Roster) runHook(ctx context.Context, hookName string, inf *hook.RosterInfo) error { _, err := r.hk.Run(hookName, &hook.ExecutionContext{ Info: inf, @@ -914,7 +905,3 @@ func parseVer(ver string) int { } return 0 } - -func errStreamNotFound(username, resource string) error { - return fmt.Errorf("roster: local stream not found: %s/%s", username, resource) -} diff --git a/pkg/module/roster/roster_test.go b/pkg/module/roster/roster_test.go index 948797bcb..b9cda35c0 100644 --- a/pkg/module/roster/roster_test.go +++ b/pkg/module/roster/roster_test.go @@ -55,8 +55,8 @@ func TestRoster_SendRoster(t *testing.T) { return nil } c2sRouterMock := &c2sRouterMock{} - c2sRouterMock.LocalStreamFunc = func(username string, resource string) stream.C2S { - return stmMock + c2sRouterMock.LocalStreamFunc = func(username string, resource string) (stream.C2S, error) { + return stmMock, nil } routerMock := &routerMock{} @@ -870,8 +870,8 @@ func TestRoster_Available(t *testing.T) { return c2smodel.NewInfoMap() } c2sRouterMock := &c2sRouterMock{} - c2sRouterMock.LocalStreamFunc = func(username string, resource string) stream.C2S { - return stmMock + c2sRouterMock.LocalStreamFunc = func(username string, resource string) (stream.C2S, error) { + return stmMock, nil } routerMock := &routerMock{} diff --git a/pkg/module/xep0004/field.go b/pkg/module/xep0004/field.go index 6396290dd..102e1dbc7 100644 --- a/pkg/module/xep0004/field.go +++ b/pkg/module/xep0004/field.go @@ -71,6 +71,7 @@ type Field struct { Description string Values []string Options []Option + Validate *Validate } // NewFieldFromElement returns a new form field entity reading it from it's XML representation. @@ -110,6 +111,28 @@ func NewFieldFromElement(elem stravaganza.Element) (*Field, error) { } f.Options = append(f.Options, Option{Label: label, Value: value}) } + + validateElem := elem.ChildNamespace("validate", validateNamespace) + if validateElem != nil { + v := &Validate{ + DataType: validateElem.Attribute("datatype"), + } + if validateElem.Child("open") != nil { + v.Validator = &OpenValidator{} + } else if validateElem.Child("basic") != nil { + v.Validator = &BasicValidator{} + } else if rng := validateElem.Child("range"); rng != nil { + v.Validator = &RangeValidator{ + Max: rng.Attribute("max"), + Min: rng.Attribute("min"), + } + } else if rgx := validateElem.Child("regex"); rgx != nil { + v.Validator = &RegExValidator{ + RegEx: rgx.Text(), + } + } + f.Validate = v + } return f, nil } @@ -154,6 +177,9 @@ func (f *Field) Element() stravaganza.Element { ) b.WithChild(sb.Build()) } + if f.Validate != nil { + b.WithChild(f.Validate.Element()) + } return b.Build() } diff --git a/pkg/module/xep0004/field_test.go b/pkg/module/xep0004/field_test.go index 565678372..12d0702fc 100644 --- a/pkg/module/xep0004/field_test.go +++ b/pkg/module/xep0004/field_test.go @@ -116,6 +116,12 @@ func TestField_Element(t *testing.T) { f.Description = "A description" f.Values = []string{"A value"} f.Options = []Option{{"opt_label", "An option value"}} + f.Validate = &Validate{ + DataType: BooleanDataType, + Validator: &RegExValidator{ + RegEx: "([0-9]{3})-([0-9]{2})-([0-9]{4})", + }, + } elem := f.Element() require.Equal(t, "field", elem.Name()) @@ -134,4 +140,11 @@ func TestField_Element(t *testing.T) { valElem = optElem.Child("value") require.Equal(t, "An option value", valElem.Text()) + + validateElem := elem.ChildNamespace("validate", validateNamespace) + require.NotNil(t, validateElem) + + regexElem := validateElem.Child("regex") + require.NotNil(t, regexElem) + require.Equal(t, "([0-9]{3})-([0-9]{2})-([0-9]{4})", regexElem.Text()) } diff --git a/pkg/module/xep0004/fields.go b/pkg/module/xep0004/fields.go index e1dec7415..3cf37a6a7 100644 --- a/pkg/module/xep0004/fields.go +++ b/pkg/module/xep0004/fields.go @@ -42,7 +42,7 @@ func (f Fields) ValuesForFieldOfType(fieldName, typ string) []string { var res []string for _, field := range f { if field.Var == fieldName && field.Type == typ && len(field.Values) > 0 { - res = append(res, field.Values[0]) + res = append(res, field.Values...) } } return res diff --git a/pkg/module/xep0004/validate.go b/pkg/module/xep0004/validate.go new file mode 100644 index 000000000..a1651382b --- /dev/null +++ b/pkg/module/xep0004/validate.go @@ -0,0 +1,116 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xep0004 + +import "github.com/jackal-xmpp/stravaganza" + +const ( + // StringDataType datatype represents character strings in XML. + StringDataType = "xs:string" + + // BooleanDataType represents the values of two-valued logic. + BooleanDataType = "xs:boolean" + + // DecimalDataType represents a subset of the real numbers, which can be represented by decimal numerals. + DecimalDataType = "xs:decimal" + + // FloatDataType is patterned after the IEEE single-precision 32-bit floating point datatype + FloatDataType = "xs:float" + + // DoubleDataType is patterned after the IEEE double-precision 64-bit floating point datatype. + DoubleDataType = "xs:double" + + // DurationDataType is a datatype that represents durations of time. + DurationDataType = "xs:duration" + + // DateTimeDataType represents instants of time, optionally marked with a particular time zone offset. + DateTimeDataType = "xs:dateTime" + + // HexBinaryDataType represents arbitrary hex-encoded binary data. + HexBinaryDataType = "xs:hexBinary" + + // Base64BinaryDataType represents arbitrary Base64-encoded binary data + Base64BinaryDataType = "xs:base64Binary" +) + +const validateNamespace = "http://jabber.org/protocol/xdata-validate" + +// Validator defines validation type interface. +type Validator interface { + Element() stravaganza.Element +} + +// Validate represents a field validation type. +type Validate struct { + DataType string + Validator Validator +} + +// Element returns validation type element representation. +func (v *Validate) Element() stravaganza.Element { + b := stravaganza.NewBuilder("validate"). + WithAttribute(stravaganza.Namespace, validateNamespace). + WithAttribute("datatype", v.DataType) + if v.Validator != nil { + b.WithChild(v.Validator.Element()) + } + return b.Build() +} + +// OpenValidator represents open validation type. +type OpenValidator struct{} + +// Element satisfies Validator interface. +func (v *OpenValidator) Element() stravaganza.Element { + return stravaganza.NewBuilder("open").Build() +} + +// BasicValidator represents basic validation type. +type BasicValidator struct{} + +// Element satisfies Validator interface. +func (v *BasicValidator) Element() stravaganza.Element { + return stravaganza.NewBuilder("basic").Build() +} + +// RangeValidator represents range validation type. +type RangeValidator struct { + Min string + Max string +} + +// Element satisfies Validator interface. +func (v *RangeValidator) Element() stravaganza.Element { + b := stravaganza.NewBuilder("range") + if len(v.Min) > 0 { + b.WithAttribute("min", v.Min) + } + if len(v.Max) > 0 { + b.WithAttribute("max", v.Max) + } + return b.Build() +} + +// RegExValidator represents regex validation type. +type RegExValidator struct { + RegEx string +} + +// Element satisfies Validator interface. +func (v *RegExValidator) Element() stravaganza.Element { + b := stravaganza.NewBuilder("regex") + b.WithText(v.RegEx) + return b.Build() +} diff --git a/pkg/module/xep0004/validate_test.go b/pkg/module/xep0004/validate_test.go new file mode 100644 index 000000000..82f0194de --- /dev/null +++ b/pkg/module/xep0004/validate_test.go @@ -0,0 +1,40 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xep0004 + +import ( + "testing" + + "github.com/jackal-xmpp/stravaganza" + "github.com/stretchr/testify/require" +) + +func TestValidator_Element(t *testing.T) { + v := Validate{ + DataType: StringDataType, + Validator: &OpenValidator{}, + } + + elem := v.Element() + + require.NotNil(t, elem) + require.Equal(t, "validate", elem.Name()) + require.Equal(t, validateNamespace, elem.Attribute(stravaganza.Namespace)) + require.Equal(t, StringDataType, elem.Attribute("datatype")) + + validatorElem := elem.Child("open") + require.NotNil(t, validatorElem) + require.Equal(t, "open", validatorElem.Name()) +} diff --git a/pkg/module/xep0049/private.go b/pkg/module/xep0049/private.go index a568a60ce..3b1db61b4 100644 --- a/pkg/module/xep0049/private.go +++ b/pkg/module/xep0049/private.go @@ -18,6 +18,8 @@ import ( "context" "strings" + "github.com/jackal-xmpp/stravaganza/jid" + kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/jackal-xmpp/stravaganza" @@ -87,8 +89,8 @@ func (m *Private) MatchesNamespace(namespace string, serverTarget bool) bool { func (m *Private) ProcessIQ(ctx context.Context, iq *stravaganza.IQ) error { fromJid := iq.FromJID() toJid := iq.ToJID() - validTo := toJid.Node() == fromJid.Node() - if !validTo { + + if !fromJid.MatchesWithOptions(toJid, jid.MatchesBare) { _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.Forbidden)) return nil } diff --git a/pkg/module/xep0059/rsm.go b/pkg/module/xep0059/rsm.go new file mode 100644 index 000000000..c54d45ec0 --- /dev/null +++ b/pkg/module/xep0059/rsm.go @@ -0,0 +1,241 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xep0059 + +import ( + "errors" + "fmt" + "strconv" + + "github.com/jackal-xmpp/stravaganza" +) + +const ( + // RSMNamespace specifies XEP-0059 namespace constant value. + RSMNamespace = "http://jabber.org/protocol/rsm" +) + +var ( + // ErrPageNotFound will be returned by GetResultSetPage when page request cannot be satisfied. + ErrPageNotFound = errors.New("page not found") +) + +// Request represents a rsm request value. +type Request struct { + After string + Before string + Index int + Max int + LastPage bool +} + +// Result represents a rsm result value. +type Result struct { + Index int + First string + Last string + Count int + Complete bool +} + +// NewRequestFromElement returns a Request derived from an XML element. +func NewRequestFromElement(elem stravaganza.Element) (*Request, error) { + var req Request + var err error + + if n := elem.Name(); n != "set" { + return nil, fmt.Errorf("xep0059: invalid set name: %s", n) + } + if ns := elem.Attribute(stravaganza.Namespace); ns != RSMNamespace { + return nil, fmt.Errorf("xep0059: invalid set namespace: %s", ns) + } + if maxEl := elem.Child("max"); maxEl != nil { + req.Max, err = strconv.Atoi(maxEl.Text()) + if err != nil { + return nil, err + } + } + if indexEl := elem.Child("index"); indexEl != nil { + req.Index, err = strconv.Atoi(indexEl.Text()) + if err != nil { + return nil, err + } + } + if afterEl := elem.Child("after"); afterEl != nil { + req.After = afterEl.Text() + } + if beforeEl := elem.Child("before"); beforeEl != nil { + if beforeID := beforeEl.Text(); len(beforeID) > 0 { + req.Before = beforeID + } else { + req.LastPage = true + } + } + return &req, nil +} + +// Element returns XML representation of a Result instance. +func (r *Result) Element() stravaganza.Element { + sb := stravaganza.NewBuilder("set"). + WithAttribute(stravaganza.Namespace, RSMNamespace) + + if len(r.First) > 0 { + sb.WithChild( + stravaganza.NewBuilder("first"). + WithAttribute("index", strconv.Itoa(r.Index)). + WithText(r.First). + Build(), + ) + } + if len(r.Last) > 0 { + sb.WithChild( + stravaganza.NewBuilder("last"). + WithText(r.Last). + Build(), + ) + } + sb.WithChild( + stravaganza.NewBuilder("count"). + WithText(strconv.Itoa(r.Count)). + Build(), + ) + return sb.Build() +} + +// GetResultSetPage returns result page based on the passed request. +func GetResultSetPage[T any](rs []T, req *Request, getID func(i T) string) ([]T, *Result, error) { + var page []T + var res *Result + var err error + + switch { + case len(rs) == 0 && req.Index == 0: + return nil, &Result{Complete: true}, nil + + case req.LastPage: + page, res, err = getPageByIndex(rs, lastIndex(len(rs), req.Max), req.Max) + + case req.Index > 0: + page, res, err = getPageByIndex(rs, req.Index, req.Max) + + case len(req.After) > 0: + page, res, err = getPageAfterID(rs, getID, req.After, req.Max) + + case len(req.Before) > 0: + page, res, err = getPageBeforeID(rs, getID, req.Before, req.Max) + + case req.Max == 0: + return nil, &Result{Count: len(rs)}, nil + + default: + page, res, err = getPageByIndex(rs, 0, req.Max) // request first page + } + if err != nil { + return nil, nil, err + } + res.First = getID(page[0]) + res.Last = getID(page[len(page)-1]) + + return page, res, nil +} + +func getPageByIndex[T any](rs []T, idx, max int) ([]T, *Result, error) { + var page []T + var res Result + + i := idx * max + if i > len(rs)-1 { + return nil, nil, ErrPageNotFound + } + + lastIdx := len(rs) - 1 + for ; i < len(rs) && res.Count < max; i++ { + if i >= lastIdx { + res.Complete = true + } + page = append(page, rs[i]) + res.Count++ + } + res.Index = idx + + return page, &res, nil +} + +func getPageAfterID[T any](rs []T, getID func(i T) string, id string, max int) ([]T, *Result, error) { + var page []T + var res Result + + idIdx := getIDIndex(rs, getID, id) + if idIdx == -1 { + return nil, nil, ErrPageNotFound + } + startIdx := idIdx + 1 + + lastIdx := len(rs) - 1 + for i := startIdx; i < len(rs) && res.Count < max; i++ { + if i >= lastIdx { + res.Complete = true + } + page = append(page, rs[i]) + res.Count++ + } + res.Index = startIdx / max + + return page, &res, nil +} + +func getPageBeforeID[T any](rs []T, getID func(i T) string, id string, max int) ([]T, *Result, error) { + var page []T + var res Result + + idIdx := getIDIndex(rs, getID, id) + if idIdx == -1 { + return nil, nil, ErrPageNotFound + } + startIdx := idIdx - max + if startIdx < 0 { + startIdx = 0 + } + + lastIdx := len(rs) - 1 + for i := startIdx; i < len(rs) && res.Count < max; i++ { + if i >= lastIdx { + res.Complete = true + } + page = append(page, rs[i]) + res.Count++ + } + res.Index = startIdx / max + + return page, &res, nil +} + +func getIDIndex[T any](rs []T, getID func(i T) string, id string) int { + for i := 0; i < len(rs); i++ { + if getID(rs[i]) != id { + continue + } + return i + } + return -1 +} + +func lastIndex(len, max int) int { + li := len/max - 1 + if len%max > 0 { + li++ + } + return li +} diff --git a/pkg/module/xep0059/rsm_test.go b/pkg/module/xep0059/rsm_test.go new file mode 100644 index 000000000..0d70e6a75 --- /dev/null +++ b/pkg/module/xep0059/rsm_test.go @@ -0,0 +1,158 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xep0059 + +import ( + "testing" + + "github.com/jackal-xmpp/stravaganza" + "github.com/stretchr/testify/require" +) + +func TestRequest_NewFromElement(t *testing.T) { + // given + el := stravaganza.NewBuilder("set"). + WithAttribute(stravaganza.Namespace, RSMNamespace). + WithChild( + stravaganza.NewBuilder("max"). + WithText("10"). + Build(), + ). + WithChild( + stravaganza.NewBuilder("index"). + WithText("1"). + Build(), + ). + WithChild( + stravaganza.NewBuilder("after"). + WithText("peter@pixyland.org"). + Build(), + ). + WithChild( + stravaganza.NewBuilder("before"). + WithText("peter@rabbit.lit"). + Build(), + ). + Build() + + // when + req, err := NewRequestFromElement(el) + + // then + require.NoError(t, err) + + require.Equal(t, 10, req.Max) + require.Equal(t, 1, req.Index) + require.Equal(t, "peter@pixyland.org", req.After) + require.Equal(t, "peter@rabbit.lit", req.Before) +} + +func TestResult_Element(t *testing.T) { + // given + r := Result{ + Index: 1, + First: "f0", + Last: "l1", + Count: 800, + } + + // when + el := r.Element() + + // then + require.Equal(t, `f0l1800`, el.String()) +} + +func Test_GetResultSetPage(t *testing.T) { + tcs := map[string]struct { + rs []string + req Request + expectedPage []string + expectedResult Result + expectsError bool + }{ + "empty set": { + req: Request{Max: 10}, + expectedResult: Result{Count: 0, Complete: true}, + }, + "get page by index": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{Index: 2, Max: 3}, + expectedPage: []string{"7", "8", "9"}, + expectedResult: Result{Index: 2, Count: 3, First: "7", Last: "9"}, + }, + "get out of bound index": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{Index: 4, Max: 3}, + expectsError: true, + }, + "get last page": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{LastPage: true, Max: 3}, + expectedPage: []string{"10"}, + expectedResult: Result{Index: 3, Count: 1, First: "10", Last: "10", Complete: true}, + }, + "get page after id": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{After: "3", Max: 4}, + expectedPage: []string{"4", "5", "6", "7"}, + expectedResult: Result{Index: 0, Count: 4, First: "4", Last: "7"}, + }, + "get page after id - last page": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{After: "8", Max: 4}, + expectedPage: []string{"9", "10"}, + expectedResult: Result{Index: 2, Count: 2, First: "9", Last: "10", Complete: true}, + }, + "get page after id - not found": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{After: "11", Max: 4}, + expectsError: true, + }, + "get before id": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{Before: "9", Max: 2}, + expectedPage: []string{"7", "8"}, + expectedResult: Result{Index: 3, Count: 2, First: "7", Last: "8"}, + }, + "get before id - first page": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{Before: "2", Max: 4}, + expectedPage: []string{"1", "2", "3", "4"}, + expectedResult: Result{Index: 0, Count: 4, First: "1", Last: "4"}, + }, + "get before id - not found": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{Before: "11", Max: 4}, + expectsError: true, + }, + "get results count": { + rs: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, + req: Request{Max: 0}, + expectedResult: Result{Count: 10}, + }, + } + for tName, tc := range tcs { + t.Run(tName, func(t *testing.T) { + page, res, err := GetResultSetPage(tc.rs, &tc.req, func(s string) string { return s }) + if tc.expectsError { + require.Error(t, err) + } else { + require.Equal(t, &tc.expectedResult, res) + require.Equal(t, tc.expectedPage, page) + } + }) + } +} diff --git a/pkg/module/xep0191/blocklist.go b/pkg/module/xep0191/blocklist.go index 5ea173e6a..9e818aecc 100644 --- a/pkg/module/xep0191/blocklist.go +++ b/pkg/module/xep0191/blocklist.go @@ -16,7 +16,6 @@ package xep0191 import ( "context" - "fmt" kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" @@ -111,7 +110,8 @@ func (m *BlockList) MatchesNamespace(namespace string, serverTarget bool) bool { func (m *BlockList) ProcessIQ(ctx context.Context, iq *stravaganza.IQ) error { fromJID := iq.FromJID() toJID := iq.ToJID() - if fromJID.Node() != toJID.Node() { + + if !fromJID.MatchesWithOptions(toJID, jid.MatchesBare) { _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.Forbidden)) return nil } @@ -288,10 +288,10 @@ func (m *BlockList) getBlockList(ctx context.Context, iq *stravaganza.IQ) error username := fromJID.Node() res := fromJID.Resource() - stm := m.router.C2S().LocalStream(username, res) - if stm == nil { + stm, err := m.router.C2S().LocalStream(username, res) + if err != nil { _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.InternalServerError)) - return fmt.Errorf("xep0191: local stream not found: %s/%s", username, res) + return err } if err := stm.SetInfoValue(ctx, blockListRequestedCtxKey, true); err != nil { _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.InternalServerError)) diff --git a/pkg/module/xep0191/blocklist_test.go b/pkg/module/xep0191/blocklist_test.go index 1a5f4ba75..7f19c63f4 100644 --- a/pkg/module/xep0191/blocklist_test.go +++ b/pkg/module/xep0191/blocklist_test.go @@ -47,8 +47,8 @@ func TestBlockList_GetBlockList(t *testing.T) { return nil } c2sRouterMock := &c2sRouterMock{} - c2sRouterMock.LocalStreamFunc = func(username string, resource string) stream.C2S { - return stmMock + c2sRouterMock.LocalStreamFunc = func(username string, resource string) (stream.C2S, error) { + return stmMock, nil } var respStanzas []stravaganza.Stanza diff --git a/pkg/module/xep0198/stream.go b/pkg/module/xep0198/stream.go index 3dab7ffb4..c6f991df6 100644 --- a/pkg/module/xep0198/stream.go +++ b/pkg/module/xep0198/stream.go @@ -25,20 +25,17 @@ import ( "sync" "time" - "github.com/ortuman/jackal/pkg/cluster/instance" - - clusterconnmanager "github.com/ortuman/jackal/pkg/cluster/connmanager" - - streamqueue "github.com/ortuman/jackal/pkg/module/xep0198/queue" - kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/jackal-xmpp/stravaganza" streamerror "github.com/jackal-xmpp/stravaganza/errors/stream" "github.com/jackal-xmpp/stravaganza/jid" + clusterconnmanager "github.com/ortuman/jackal/pkg/cluster/connmanager" + "github.com/ortuman/jackal/pkg/cluster/instance" "github.com/ortuman/jackal/pkg/cluster/resourcemanager" "github.com/ortuman/jackal/pkg/hook" "github.com/ortuman/jackal/pkg/host" + streamqueue "github.com/ortuman/jackal/pkg/module/xep0198/queue" xmppparser "github.com/ortuman/jackal/pkg/parser" "github.com/ortuman/jackal/pkg/router" "github.com/ortuman/jackal/pkg/router/stream" @@ -236,8 +233,11 @@ func (m *Stream) onDisconnect(execCtx *hook.ExecutionContext) error { inf := execCtx.Info.(*hook.C2SStreamInfo) discErr := inf.DisconnectError - _, ok := discErr.(*streamerror.Error) - if ok || errors.Is(discErr, xmppparser.ErrStreamClosedByPeer) { + _, isStreamErr := discErr.(*streamerror.Error) + + shouldHibernate := inf.Presence.IsAvailable() && !isStreamErr && !errors.Is(discErr, xmppparser.ErrStreamClosedByPeer) + + if !shouldHibernate { return nil } // schedule stream termination diff --git a/pkg/module/xep0199/ping.go b/pkg/module/xep0199/ping.go index 1eb3dd2c3..db4f3f621 100644 --- a/pkg/module/xep0199/ping.go +++ b/pkg/module/xep0199/ping.go @@ -214,7 +214,8 @@ func (p *Ping) timeout(jd *jid.JID) { // perform timeout action switch p.cfg.TimeoutAction { case killAction: - if stm := p.router.C2S().LocalStream(jd.Node(), jd.Resource()); stm != nil { + stm, _ := p.router.C2S().LocalStream(jd.Node(), jd.Resource()) + if stm != nil { _ = stm.Disconnect(streamerror.E(streamerror.ConnectionTimeout)) } } diff --git a/pkg/module/xep0199/ping_test.go b/pkg/module/xep0199/ping_test.go index d3643b390..4cab64230 100644 --- a/pkg/module/xep0199/ping_test.go +++ b/pkg/module/xep0199/ping_test.go @@ -111,8 +111,8 @@ func TestPing_Timeout(t *testing.T) { return nil } c2sRouterMock := &c2sRouterMock{} - c2sRouterMock.LocalStreamFunc = func(username string, resource string) stream.C2S { - return c2sStream + c2sRouterMock.LocalStreamFunc = func(username string, resource string) (stream.C2S, error) { + return c2sStream, nil } routerMock.C2SFunc = func() router.C2SRouter { return c2sRouterMock diff --git a/pkg/module/xep0280/carbons.go b/pkg/module/xep0280/carbons.go index 02916a62c..2f4b8d000 100644 --- a/pkg/module/xep0280/carbons.go +++ b/pkg/module/xep0280/carbons.go @@ -16,7 +16,8 @@ package xep0280 import ( "context" - "fmt" + + "github.com/ortuman/jackal/pkg/module/xep0313" kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" @@ -36,7 +37,6 @@ const ( carbonsNamespace = "urn:xmpp:carbons:2" deliveryReceiptsNamespace = "urn:xmpp:receipts" - forwardingNamespace = "urn:xmpp:forward:0" chatStatesNamespace = "http://jabber.org/protocol/chatstates" hintsNamespace = "urn:xmpp:hints" ) @@ -97,8 +97,8 @@ func (p *Carbons) AccountFeatures(_ context.Context) ([]string, error) { func (p *Carbons) Start(_ context.Context) error { p.hk.AddHook(hook.C2SStreamWillRouteElement, p.onC2SElementWillRoute, hook.DefaultPriority) p.hk.AddHook(hook.S2SInStreamWillRouteElement, p.onS2SElementWillRoute, hook.DefaultPriority) - p.hk.AddHook(hook.C2SStreamMessageRouted, p.onC2SMessageRouted, hook.DefaultPriority) - p.hk.AddHook(hook.S2SInStreamMessageRouted, p.onS2SMessageRouted, hook.DefaultPriority) + p.hk.AddHook(hook.C2SStreamMessageRouted, p.onC2SMessageRouted, hook.LowestPriority+1) + p.hk.AddHook(hook.S2SInStreamMessageRouted, p.onS2SMessageRouted, hook.LowestPriority+1) level.Info(p.logger).Log("msg", "started carbons module") return nil @@ -175,7 +175,7 @@ func (p *Carbons) onS2SMessageRouted(execCtx *hook.ExecutionContext) error { if !ok { return nil } - return p.processMessage(ctx, msg, nil) + return p.processMessage(ctx, msg, inf.Targets) } func (p *Carbons) processIQ(ctx context.Context, iq *stravaganza.IQ) error { @@ -207,9 +207,9 @@ func (p *Carbons) processIQ(ctx context.Context, iq *stravaganza.IQ) error { } func (p *Carbons) setCarbonsEnabled(ctx context.Context, username, resource string, enabled bool) error { - stm := p.router.C2S().LocalStream(username, resource) - if stm == nil { - return errStreamNotFound(username, resource) + stm, err := p.router.C2S().LocalStream(username, resource) + if err != nil { + return err } return stm.SetInfoValue(ctx, carbonsEnabledCtxKey, enabled) } @@ -243,7 +243,7 @@ func (p *Carbons) routeSentCC(ctx context.Context, msg *stravaganza.Message, use if !res.Info().Bool(carbonsEnabledCtxKey) { continue } - _, _ = p.router.Route(ctx, sentMsgCC(msg, res.JID())) + _, _ = p.router.Route(ctx, sentMsgCC(ctx, msg, res.JID())) } return nil } @@ -257,7 +257,7 @@ func (p *Carbons) routeReceivedCC(ctx context.Context, msg *stravaganza.Message, if !res.Info().Bool(carbonsEnabledCtxKey) { continue } - _, _ = p.router.Route(ctx, receivedMsgCC(msg, res.JID())) + _, _ = p.router.Route(ctx, receivedMsgCC(ctx, msg, res.JID())) } return nil } @@ -316,46 +316,38 @@ func isCCMessage(msg *stravaganza.Message) bool { return msg.ChildNamespace("sent", carbonsNamespace) != nil || msg.ChildNamespace("received", carbonsNamespace) != nil } -func sentMsgCC(msg *stravaganza.Message, dest *jid.JID) *stravaganza.Message { +func sentMsgCC(ctx context.Context, originalMsg *stravaganza.Message, dest *jid.JID) *stravaganza.Message { + msg := originalMsg + if sentArchiveID := xep0313.ExtractSentArchiveID(ctx); len(sentArchiveID) > 0 { + msg = xmpputil.MakeStanzaIDMessage(msg, sentArchiveID, dest.ToBareJID().String()) + } ccMsg, _ := stravaganza.NewMessageBuilder(). WithAttribute(stravaganza.From, dest.ToBareJID().String()). WithAttribute(stravaganza.To, dest.String()). - WithAttribute(stravaganza.Type, stravaganza.ChatType). + WithAttribute(stravaganza.Type, msg.Type()). WithChild( stravaganza.NewBuilder("sent"). WithAttribute(stravaganza.Namespace, carbonsNamespace). - WithChild( - stravaganza.NewBuilder("forwarded"). - WithAttribute(stravaganza.Namespace, forwardingNamespace). - WithChild(msg). - Build(), - ). + WithChild(xmpputil.MakeForwardedStanza(msg, nil)). Build(), - ). - BuildMessage() + ).BuildMessage() return ccMsg } -func receivedMsgCC(msg *stravaganza.Message, dest *jid.JID) *stravaganza.Message { +func receivedMsgCC(ctx context.Context, originalMsg *stravaganza.Message, dest *jid.JID) *stravaganza.Message { + msg := originalMsg + if receivedArchiveID := xep0313.ExtractReceivedArchiveID(ctx); len(receivedArchiveID) > 0 { + msg = xmpputil.MakeStanzaIDMessage(msg, receivedArchiveID, dest.ToBareJID().String()) + } ccMsg, _ := stravaganza.NewMessageBuilder(). WithAttribute(stravaganza.From, dest.ToBareJID().String()). WithAttribute(stravaganza.To, dest.String()). - WithAttribute(stravaganza.Type, stravaganza.ChatType). + WithAttribute(stravaganza.Type, msg.Type()). WithChild( stravaganza.NewBuilder("received"). WithAttribute(stravaganza.Namespace, carbonsNamespace). - WithChild( - stravaganza.NewBuilder("forwarded"). - WithAttribute(stravaganza.Namespace, forwardingNamespace). - WithChild(msg). - Build(), - ). + WithChild(xmpputil.MakeForwardedStanza(msg, nil)). Build(), - ). - BuildMessage() + ).BuildMessage() return ccMsg } - -func errStreamNotFound(username, resource string) error { - return fmt.Errorf("xep0280: local stream not found: %s/%s", username, resource) -} diff --git a/pkg/module/xep0280/carbons_test.go b/pkg/module/xep0280/carbons_test.go index 2f8994aee..fee217402 100644 --- a/pkg/module/xep0280/carbons_test.go +++ b/pkg/module/xep0280/carbons_test.go @@ -44,8 +44,8 @@ func TestCarbons_Enable(t *testing.T) { return c2smodel.NewInfoMap() } c2sRouterMock := &c2sRouterMock{} - c2sRouterMock.LocalStreamFunc = func(username string, resource string) stream.C2S { - return stmMock + c2sRouterMock.LocalStreamFunc = func(username string, resource string) (stream.C2S, error) { + return stmMock, nil } routerMock := &routerMock{} @@ -112,8 +112,8 @@ func TestCarbons_Disable(t *testing.T) { return c2smodel.NewInfoMap() } c2sRouterMock := &c2sRouterMock{} - c2sRouterMock.LocalStreamFunc = func(username string, resource string) stream.C2S { - return stmMock + c2sRouterMock.LocalStreamFunc = func(username string, resource string) (stream.C2S, error) { + return stmMock, nil } routerMock := &routerMock{} diff --git a/pkg/module/xep0313/interface.go b/pkg/module/xep0313/interface.go new file mode 100644 index 000000000..3153fbbc5 --- /dev/null +++ b/pkg/module/xep0313/interface.go @@ -0,0 +1,51 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xep0313 + +import ( + "github.com/ortuman/jackal/pkg/router" + "github.com/ortuman/jackal/pkg/router/stream" + "github.com/ortuman/jackal/pkg/storage/repository" +) + +//go:generate moq -out repository.mock_test.go . globalRepository:repositoryMock +type globalRepository interface { + repository.Repository +} + +//go:generate moq -out tx.mock_test.go . repTransaction:txMock +type repTransaction interface { + repository.Transaction +} + +//go:generate moq -out router.mock_test.go . globalRouter:routerMock +type globalRouter interface { + router.Router +} + +//go:generate moq -out c2srouter.mock_test.go . c2sRouter +type c2sRouter interface { + router.C2SRouter +} + +//go:generate moq -out stream.mock_test.go . c2sStream +type c2sStream interface { + stream.C2S +} + +//go:generate moq -out hosts.mock_test.go . hosts +type hosts interface { + IsLocalHost(h string) bool +} diff --git a/pkg/module/xep0313/mam.go b/pkg/module/xep0313/mam.go new file mode 100644 index 000000000..4fa8e23df --- /dev/null +++ b/pkg/module/xep0313/mam.go @@ -0,0 +1,522 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xep0313 + +import ( + "context" + "errors" + "time" + + kitlog "github.com/go-kit/log" + "github.com/go-kit/log/level" + "github.com/google/uuid" + "github.com/jackal-xmpp/stravaganza" + stanzaerror "github.com/jackal-xmpp/stravaganza/errors/stanza" + "github.com/jackal-xmpp/stravaganza/jid" + "github.com/ortuman/jackal/pkg/hook" + "github.com/ortuman/jackal/pkg/host" + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + c2smodel "github.com/ortuman/jackal/pkg/model/c2s" + "github.com/ortuman/jackal/pkg/module/xep0004" + "github.com/ortuman/jackal/pkg/module/xep0059" + "github.com/ortuman/jackal/pkg/router" + "github.com/ortuman/jackal/pkg/storage/repository" + xmpputil "github.com/ortuman/jackal/pkg/util/xmpp" + "github.com/samber/lo" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + // ModuleName represents mam module name. + ModuleName = "mam" + + // XEPNumber represents mam XEP number. + XEPNumber = "0313" + + mamNamespace = "urn:xmpp:mam:2" + extendedMamNamespace = "urn:xmpp:mam:2#extended" + + dateTimeFormat = "2006-01-02T15:04:05Z" + + archiveRequestedCtxKey = "mam:requested" + + defaultPageSize = 50 + maxPageSize = 250 +) + +type archiveIDCtxKey int + +const ( + sentArchiveIDKey archiveIDCtxKey = iota + receivedArchiveIDKey +) + +// Config contains mam module configuration options. +type Config struct { + // QueueSize defines maximum number of archive messages stanzas. + // When the limit is reached, the oldest message will be purged to make room for the new one. + QueueSize int `fig:"queue_size" default:"1000"` +} + +// Mam represents a mam (XEP-0313) module type. +type Mam struct { + cfg Config + hosts hosts + router router.Router + hk *hook.Hooks + rep repository.Repository + logger kitlog.Logger +} + +// New returns a new initialized mam instance. +func New( + cfg Config, + router router.Router, + hosts *host.Hosts, + rep repository.Repository, + hk *hook.Hooks, + logger kitlog.Logger, +) *Mam { + return &Mam{ + cfg: cfg, + router: router, + hosts: hosts, + rep: rep, + hk: hk, + logger: kitlog.With(logger, "module", ModuleName, "xep", XEPNumber), + } +} + +// Name returns mam module name. +func (m *Mam) Name() string { return ModuleName } + +// StreamFeature returns mam module stream feature. +func (m *Mam) StreamFeature(_ context.Context, _ string) (stravaganza.Element, error) { + return nil, nil +} + +// ServerFeatures returns mam server disco features. +func (m *Mam) ServerFeatures(_ context.Context) ([]string, error) { + return nil, nil +} + +// AccountFeatures returns mam account disco features. +func (m *Mam) AccountFeatures(_ context.Context) ([]string, error) { + return []string{mamNamespace, extendedMamNamespace}, nil +} + +// Start starts mam module. +func (m *Mam) Start(_ context.Context) error { + m.hk.AddHook(hook.C2SStreamMessageReceived, m.onMessageReceived, hook.HighestPriority) + m.hk.AddHook(hook.S2SInStreamMessageReceived, m.onMessageReceived, hook.HighestPriority) + + m.hk.AddHook(hook.C2SStreamMessageRouted, m.onMessageRouted, hook.LowestPriority+2) + m.hk.AddHook(hook.S2SInStreamMessageRouted, m.onMessageRouted, hook.LowestPriority+2) + m.hk.AddHook(hook.UserDeleted, m.onUserDeleted, hook.DefaultPriority) + + level.Info(m.logger).Log("msg", "started mam module") + return nil +} + +// Stop stops mam module. +func (m *Mam) Stop(_ context.Context) error { + m.hk.RemoveHook(hook.C2SStreamMessageReceived, m.onMessageReceived) + m.hk.RemoveHook(hook.S2SInStreamMessageReceived, m.onMessageReceived) + m.hk.RemoveHook(hook.C2SStreamMessageRouted, m.onMessageRouted) + m.hk.RemoveHook(hook.S2SInStreamMessageRouted, m.onMessageRouted) + m.hk.RemoveHook(hook.UserDeleted, m.onUserDeleted) + + level.Info(m.logger).Log("msg", "stopped mam module") + return nil +} + +// MatchesNamespace tells whether namespace matches mam module. +func (m *Mam) MatchesNamespace(namespace string, serverTarget bool) bool { + if serverTarget { + return false + } + return namespace == mamNamespace +} + +// ProcessIQ process a mam iq. +func (m *Mam) ProcessIQ(ctx context.Context, iq *stravaganza.IQ) error { + fromJID := iq.FromJID() + toJID := iq.ToJID() + + if !fromJID.MatchesWithOptions(toJID, jid.MatchesBare) { + _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.Forbidden)) + return nil + } + switch { + case iq.IsGet() && iq.ChildNamespace("metadata", mamNamespace) != nil: + return m.sendArchiveMetadata(ctx, iq) + + case iq.IsGet() && iq.ChildNamespace("query", mamNamespace) != nil: + return m.sendFormFields(ctx, iq) + + case iq.IsSet() && iq.ChildNamespace("query", mamNamespace) != nil: + return m.sendArchiveMessages(ctx, iq) + } + return nil +} + +func (m *Mam) sendArchiveMetadata(ctx context.Context, iq *stravaganza.IQ) error { + archiveID := iq.FromJID().Node() + + metadata, err := m.rep.FetchArchiveMetadata(ctx, archiveID) + if err != nil { + _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.InternalServerError)) + return err + } + // send reply + metadataBuilder := stravaganza.NewBuilder("metadata").WithAttribute(stravaganza.Namespace, mamNamespace) + + startBuilder := stravaganza.NewBuilder("start") + if metadata != nil { + startBuilder.WithAttribute("id", metadata.StartId) + startBuilder.WithAttribute("timestamp", metadata.StartTimestamp) + } + endBuilder := stravaganza.NewBuilder("end") + if metadata != nil { + endBuilder.WithAttribute("id", metadata.EndId) + endBuilder.WithAttribute("timestamp", metadata.EndTimestamp) + } + + metadataBuilder.WithChildren(startBuilder.Build(), endBuilder.Build()) + + resIQ := xmpputil.MakeResultIQ(iq, metadataBuilder.Build()) + _, _ = m.router.Route(ctx, resIQ) + + level.Info(m.logger).Log("msg", "requested archive metadata", "archive_id", archiveID) + + return nil +} + +func (m *Mam) sendFormFields(ctx context.Context, iq *stravaganza.IQ) error { + form := xep0004.DataForm{ + Type: xep0004.Form, + } + + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.Hidden, + Var: xep0004.FormType, + Values: []string{mamNamespace}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.JidSingle, + Var: "with", + }) + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.TextSingle, + Var: "start", + }) + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.TextSingle, + Var: "end", + }) + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.TextSingle, + Var: "before-id", + }) + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.TextSingle, + Var: "after-id", + }) + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.ListMulti, + Var: "ids", + Validate: &xep0004.Validate{ + DataType: xep0004.StringDataType, + Validator: &xep0004.OpenValidator{}, + }, + }) + + qChild := stravaganza.NewBuilder("query"). + WithAttribute(stravaganza.Namespace, mamNamespace). + WithChild(form.Element()). + Build() + + _, _ = m.router.Route(ctx, xmpputil.MakeResultIQ(iq, qChild)) + + level.Info(m.logger).Log("msg", "requested form fields") + + return nil +} + +func (m *Mam) sendArchiveMessages(ctx context.Context, iq *stravaganza.IQ) error { + fromJID := iq.FromJID() + + stm, err := m.router.C2S().LocalStream(fromJID.Node(), fromJID.Resource()) + if err != nil { + return err + } + + qChild := iq.ChildNamespace("query", mamNamespace) + + // filter archive result + filters := &archivemodel.Filters{} + if x := qChild.ChildNamespace("x", xep0004.FormNamespace); x != nil { + form, err := xep0004.NewFormFromElement(x) + if err != nil { + return err + } + filters, err = formToFilters(form) + if err != nil { + return err + } + } + archiveID := fromJID.Node() + + messages, err := m.rep.FetchArchiveMessages(ctx, filters, archiveID) + if err != nil { + _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.InternalServerError)) + return err + } + + // return not found error if any requested id cannot be found + switch { + case len(filters.Ids) > 0 && (len(messages) != len(filters.Ids)): + fallthrough + + case (len(filters.AfterId) > 0 || len(filters.BeforeId) > 0) && len(messages) == 0: + _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.ItemNotFound)) + return nil + } + + // apply RSM paging + var req *xep0059.Request + var res *xep0059.Result + + if set := qChild.ChildNamespace("set", xep0059.RSMNamespace); set != nil { + req, err = xep0059.NewRequestFromElement(set) + if err != nil { + _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.BadRequest)) + return err + } + if req.Max > maxPageSize { + req.Max = maxPageSize + } + } else { + req = &xep0059.Request{Max: defaultPageSize} + } + messages, res, err = xep0059.GetResultSetPage(messages, req, func(m *archivemodel.Message) string { + return m.Id + }) + if err != nil { + if errors.Is(err, xep0059.ErrPageNotFound) { + _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.ItemNotFound)) + return nil + } + _, _ = m.router.Route(ctx, xmpputil.MakeErrorStanza(iq, stanzaerror.InternalServerError)) + return err + } + + // flip result page + if qChild.Child("flip-page") != nil { + messages = lo.Reverse(messages) + + lastID := res.Last + res.Last = res.First + res.First = lastID + } + + // route archive messages + for _, msg := range messages { + msgStanza, _ := stravaganza.NewBuilderFromProto(msg.Message). + BuildStanza() + stamp := msg.Stamp.AsTime() + + resultElem := stravaganza.NewBuilder("result"). + WithAttribute(stravaganza.Namespace, mamNamespace). + WithAttribute("queryid", qChild.Attribute("queryid")). + WithAttribute(stravaganza.ID, uuid.New().String()). + WithChild(xmpputil.MakeForwardedStanza(msgStanza, &stamp)). + Build() + + archiveMsg, _ := stravaganza.NewMessageBuilder(). + WithAttribute(stravaganza.From, iq.ToJID().String()). + WithAttribute(stravaganza.To, iq.FromJID().String()). + WithAttribute(stravaganza.ID, uuid.New().String()). + WithChild(resultElem). + BuildMessage() + + _, _ = m.router.Route(ctx, archiveMsg) + } + + finB := stravaganza.NewBuilder("fin"). + WithChild(res.Element()). + WithAttribute(stravaganza.Namespace, mamNamespace) + if res.Complete { + finB.WithAttribute("complete", "true") + } + _, _ = m.router.Route(ctx, xmpputil.MakeResultIQ(iq, finB.Build())) + + level.Info(m.logger).Log("msg", "archive messages requested", "archive_id", fromJID.Node(), "count", len(messages), "complete", res.Complete) + + return stm.SetInfoValue(ctx, archiveRequestedCtxKey, true) +} + +func (m *Mam) onMessageReceived(execCtx *hook.ExecutionContext) error { + var msg *stravaganza.Message + + switch inf := execCtx.Info.(type) { + case *hook.C2SStreamInfo: + msg = inf.Element.(*stravaganza.Message) + inf.Element = m.addRecipientStanzaID(msg) + execCtx.Info = inf + + case *hook.S2SStreamInfo: + msg = inf.Element.(*stravaganza.Message) + inf.Element = m.addRecipientStanzaID(msg) + execCtx.Info = inf + } + return nil +} + +func (m *Mam) onMessageRouted(execCtx *hook.ExecutionContext) error { + var elem stravaganza.Element + + switch inf := execCtx.Info.(type) { + case *hook.C2SStreamInfo: + elem = inf.Element + case *hook.S2SStreamInfo: + elem = inf.Element + } + return m.handleRoutedMessage(execCtx, elem) +} + +func (m *Mam) onUserDeleted(execCtx *hook.ExecutionContext) error { + inf := execCtx.Info.(*hook.UserInfo) + return m.rep.DeleteArchive(execCtx.Context, inf.Username) +} + +func (m *Mam) handleRoutedMessage(execCtx *hook.ExecutionContext, elem stravaganza.Element) error { + msg, ok := elem.(*stravaganza.Message) + if !ok { + return nil + } + if !isMessageArchievable(msg) { + return nil + } + + fromJID := msg.FromJID() + if m.hosts.IsLocalHost(fromJID.Domain()) { + sentArchiveID := uuid.New().String() + archiveMsg := xmpputil.MakeStanzaIDMessage(msg, sentArchiveID, fromJID.ToBareJID().String()) + if err := m.archiveMessage(execCtx.Context, archiveMsg, fromJID.Node(), sentArchiveID); err != nil { + return err + } + execCtx.Context = context.WithValue(execCtx.Context, sentArchiveIDKey, sentArchiveID) + } + toJID := msg.ToJID() + if !m.hosts.IsLocalHost(toJID.Domain()) { + return nil + } + recievedArchiveID := xmpputil.MessageStanzaID(msg) + if err := m.archiveMessage(execCtx.Context, msg, toJID.Node(), recievedArchiveID); err != nil { + return err + } + execCtx.Context = context.WithValue(execCtx.Context, receivedArchiveIDKey, recievedArchiveID) + return nil +} + +func (m *Mam) archiveMessage(ctx context.Context, message *stravaganza.Message, archiveID, id string) error { + return m.rep.InTransaction(ctx, func(ctx context.Context, tx repository.Transaction) error { + err := tx.InsertArchiveMessage(ctx, &archivemodel.Message{ + ArchiveId: archiveID, + Id: id, + FromJid: message.FromJID().String(), + ToJid: message.ToJID().String(), + Message: message.Proto(), + Stamp: timestamppb.Now(), + }) + if err != nil { + return err + } + return tx.DeleteArchiveOldestMessages(ctx, archiveID, m.cfg.QueueSize) + }) +} + +func (m *Mam) addRecipientStanzaID(originalMsg *stravaganza.Message) *stravaganza.Message { + toJID := originalMsg.ToJID() + if !m.hosts.IsLocalHost(toJID.Domain()) { + return originalMsg + } + archiveID := uuid.New().String() + return xmpputil.MakeStanzaIDMessage(originalMsg, archiveID, toJID.ToBareJID().String()) +} + +// IsArchiveRequested determines whether archive has been requested over a C2S stream by inspecting inf parameter. +func IsArchiveRequested(inf c2smodel.Info) bool { + return inf.Bool(archiveRequestedCtxKey) +} + +// ExtractSentArchiveID returns message sent archive ID by inspecting the passed context. +func ExtractSentArchiveID(ctx context.Context) string { + ret, ok := ctx.Value(sentArchiveIDKey).(string) + if ok { + return ret + } + return "" +} + +// ExtractReceivedArchiveID returns message received archive ID by inspecting the passed context. +func ExtractReceivedArchiveID(ctx context.Context) string { + ret, ok := ctx.Value(receivedArchiveIDKey).(string) + if ok { + return ret + } + return "" +} + +func formToFilters(fm *xep0004.DataForm) (*archivemodel.Filters, error) { + var retVal archivemodel.Filters + + fmType := fm.Fields.ValueForFieldOfType(xep0004.FormType, xep0004.Hidden) + if fm.Type != xep0004.Submit || fmType != mamNamespace { + return nil, errors.New("unexpected form type value") + } + if start := fm.Fields.ValueForField("start"); len(start) > 0 { + startTm, err := time.Parse(dateTimeFormat, start) + if err != nil { + return nil, err + } + retVal.Start = timestamppb.New(startTm) + } + if end := fm.Fields.ValueForField("end"); len(end) > 0 { + endTm, err := time.Parse(dateTimeFormat, end) + if err != nil { + return nil, err + } + retVal.End = timestamppb.New(endTm) + } + if with := fm.Fields.ValueForField("with"); len(with) > 0 { + retVal.With = with + } + if beforeID := fm.Fields.ValueForField("before-id"); len(beforeID) > 0 { + retVal.BeforeId = beforeID + } + if afterID := fm.Fields.ValueForField("after-id"); len(afterID) > 0 { + retVal.AfterId = afterID + } + if ids := fm.Fields.ValuesForField("ids"); len(ids) > 0 { + retVal.Ids = ids + } + return &retVal, nil +} + +func isMessageArchievable(msg *stravaganza.Message) bool { + return (msg.IsNormal() || msg.IsChat()) && msg.IsMessageWithBody() +} diff --git a/pkg/module/xep0313/mam_test.go b/pkg/module/xep0313/mam_test.go new file mode 100644 index 000000000..a0f410536 --- /dev/null +++ b/pkg/module/xep0313/mam_test.go @@ -0,0 +1,477 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xep0313 + +import ( + "context" + "errors" + "testing" + "time" + + kitlog "github.com/go-kit/log" + "github.com/jackal-xmpp/stravaganza" + "github.com/jackal-xmpp/stravaganza/jid" + "github.com/ortuman/jackal/pkg/hook" + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + c2smodel "github.com/ortuman/jackal/pkg/model/c2s" + "github.com/ortuman/jackal/pkg/module/xep0004" + "github.com/ortuman/jackal/pkg/module/xep0059" + "github.com/ortuman/jackal/pkg/router" + "github.com/ortuman/jackal/pkg/router/stream" + "github.com/ortuman/jackal/pkg/storage/repository" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestMam_FormFields(t *testing.T) { + // given + routerMock := &routerMock{} + + var respStanzas []stravaganza.Stanza + routerMock.RouteFunc = func(ctx context.Context, stanza stravaganza.Stanza) ([]jid.JID, error) { + respStanzas = append(respStanzas, stanza) + return nil, nil + } + mam := &Mam{ + router: routerMock, + logger: kitlog.NewNopLogger(), + } + + iq, _ := stravaganza.NewIQBuilder(). + WithAttribute(stravaganza.ID, "form1"). + WithAttribute(stravaganza.Type, stravaganza.GetType). + WithAttribute(stravaganza.From, "ortuman@jackal.im/chamber"). + WithAttribute(stravaganza.To, "ortuman@jackal.im"). + WithChild( + stravaganza.NewBuilder("query"). + WithAttribute(stravaganza.Namespace, mamNamespace). + Build(), + ). + BuildIQ() + + // when + _ = mam.ProcessIQ(context.Background(), iq) + + // then + require.Len(t, respStanzas, 1) + require.Equal(t, "iq", respStanzas[0].Name()) + require.Equal(t, stravaganza.ResultType, respStanzas[0].Type()) + + qChild := respStanzas[0].ChildNamespace("query", mamNamespace) + require.NotNil(t, qChild) + + x := qChild.ChildNamespace("x", xep0004.FormNamespace) + require.NotNil(t, x) + + form, _ := xep0004.NewFormFromElement(x) + require.NotNil(t, form) + + require.Len(t, form.Fields, 7) +} + +func TestMam_Metadata(t *testing.T) { + // given + routerMock := &routerMock{} + + var respStanzas []stravaganza.Stanza + routerMock.RouteFunc = func(ctx context.Context, stanza stravaganza.Stanza) ([]jid.JID, error) { + respStanzas = append(respStanzas, stanza) + return nil, nil + } + repMock := &repositoryMock{} + repMock.FetchArchiveMetadataFunc = func(ctx context.Context, archiveID string) (*archivemodel.Metadata, error) { + return &archivemodel.Metadata{ + StartId: "s0", + StartTimestamp: "2008-08-22T21:09:04Z", + EndId: "e0", + EndTimestamp: "2020-04-20T14:34:21Z", + }, nil + } + mam := &Mam{ + rep: repMock, + router: routerMock, + logger: kitlog.NewNopLogger(), + } + + iq, _ := stravaganza.NewIQBuilder(). + WithAttribute(stravaganza.ID, "form1"). + WithAttribute(stravaganza.Type, stravaganza.GetType). + WithAttribute(stravaganza.From, "ortuman@jackal.im/chamber"). + WithAttribute(stravaganza.To, "ortuman@jackal.im"). + WithChild( + stravaganza.NewBuilder("metadata"). + WithAttribute(stravaganza.Namespace, mamNamespace). + Build(), + ). + BuildIQ() + + // when + _ = mam.ProcessIQ(context.Background(), iq) + + // then + require.Len(t, respStanzas, 1) + require.Equal(t, "iq", respStanzas[0].Name()) + require.Equal(t, stravaganza.ResultType, respStanzas[0].Type()) + + metadata := respStanzas[0].ChildNamespace("metadata", mamNamespace) + require.NotNil(t, metadata) + + start := metadata.Child("start") + require.NotNil(t, start) + require.Equal(t, "s0", start.Attribute("id")) + require.Equal(t, "2008-08-22T21:09:04Z", start.Attribute("timestamp")) + + end := metadata.Child("end") + require.NotNil(t, start) + require.Equal(t, "e0", end.Attribute("id")) + require.Equal(t, "2020-04-20T14:34:21Z", end.Attribute("timestamp")) +} + +func TestMam_ArchiveMessage(t *testing.T) { + // given + var archivedMessages []*archivemodel.Message + + txMock := &txMock{} + txMock.DeleteArchiveOldestMessagesFunc = func(ctx context.Context, archiveID string, maxElements int) error { + return nil + } + txMock.InsertArchiveMessageFunc = func(ctx context.Context, message *archivemodel.Message) error { + archivedMessages = append(archivedMessages, message) + return nil + } + + repMock := &repositoryMock{} + repMock.InTransactionFunc = func(ctx context.Context, f func(ctx context.Context, tx repository.Transaction) error) error { + return f(ctx, txMock) + } + + hosts := &hostsMock{} + hosts.IsLocalHostFunc = func(h string) bool { return h == "jackal.im" } + + hk := hook.NewHooks() + mam := &Mam{ + hk: hk, + hosts: hosts, + rep: repMock, + logger: kitlog.NewNopLogger(), + } + _ = mam.Start(context.Background()) + t.Cleanup(func() { + _ = mam.Stop(context.Background()) + }) + + msg := testMessageStanzaWithParameters("b0", "ortuman@jackal.im/chamber", "noelia@jackal.im/yard") + + // when + execCtx := &hook.ExecutionContext{ + Info: &hook.C2SStreamInfo{ + Element: msg, + }, + Context: context.Background(), + } + _, err := hk.Run(hook.C2SStreamMessageReceived, execCtx) + require.NoError(t, err) + + _, err = hk.Run(hook.C2SStreamMessageRouted, execCtx) + require.NoError(t, err) + + // then + require.NoError(t, err) + require.Len(t, archivedMessages, 2) + + require.Equal(t, "ortuman", archivedMessages[0].ArchiveId) + require.Equal(t, "noelia", archivedMessages[1].ArchiveId) + + require.Len(t, txMock.DeleteArchiveOldestMessagesCalls(), 2) + require.Len(t, txMock.InsertArchiveMessageCalls(), 2) + + require.True(t, len(ExtractSentArchiveID(execCtx.Context)) > 0) + require.True(t, len(ExtractReceivedArchiveID(execCtx.Context)) > 0) +} + +func TestMam_SendArchiveMessages(t *testing.T) { + // given + archiveMessages := []*archivemodel.Message{ + { + ArchiveId: "ortuman", + Stamp: timestamppb.New(time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)), + FromJid: "ortuman@jackal.im/chamber", + ToJid: "noelia@jackal.im/yard", + Message: testMessageStanzaWithParameters( + "b0", + "ortuman@jackal.im/chamber", + "noelia@jackal.im/yard", + ).Proto(), + }, + { + ArchiveId: "ortuman", + Stamp: timestamppb.New(time.Date(2022, 01, 01, 01, 00, 00, 00, time.UTC)), + FromJid: "noelia@jackal.im/yard", + ToJid: "ortuman@jackal.im/chamber", + Message: testMessageStanzaWithParameters( + "b1", + "noelia@jackal.im/yard", + "ortuman@jackal.im/chamber", + ).Proto(), + }, + { + ArchiveId: "ortuman", + Stamp: timestamppb.New(time.Date(2022, 01, 01, 02, 00, 00, 00, time.UTC)), + FromJid: "ortuman@jackal.im/chamber", + ToJid: "noelia@jackal.im/yard", + Message: testMessageStanzaWithParameters( + "b2", + "ortuman@jackal.im/chamber", + "noelia@jackal.im/yard", + ).Proto(), + }, + } + + c2sInf := c2smodel.NewInfoMap() + + stmMock := &c2sStreamMock{} + stmMock.SetInfoValueFunc = func(ctx context.Context, k string, val interface{}) error { + bVal, ok := val.(bool) + if !ok { + return errors.New("unexpected value type") + } + c2sInf.SetBool(k, bVal) + return nil + } + + c2sRouterMock := &c2sRouterMock{} + c2sRouterMock.LocalStreamFunc = func(username string, resource string) (stream.C2S, error) { + return stmMock, nil + } + + routerMock := &routerMock{} + + var respStanzas []stravaganza.Stanza + routerMock.RouteFunc = func(ctx context.Context, stanza stravaganza.Stanza) ([]jid.JID, error) { + respStanzas = append(respStanzas, stanza) + return nil, nil + } + routerMock.C2SFunc = func() router.C2SRouter { + return c2sRouterMock + } + + repMock := &repositoryMock{} + repMock.FetchArchiveMessagesFunc = func(ctx context.Context, f *archivemodel.Filters, archiveID string) ([]*archivemodel.Message, error) { + return archiveMessages, nil + } + + mam := &Mam{ + rep: repMock, + router: routerMock, + logger: kitlog.NewNopLogger(), + } + + iq, _ := stravaganza.NewIQBuilder(). + WithAttribute(stravaganza.ID, "ortuman1"). + WithAttribute(stravaganza.Type, stravaganza.SetType). + WithAttribute(stravaganza.From, "ortuman@jackal.im/chamber"). + WithAttribute(stravaganza.To, "ortuman@jackal.im"). + WithChild( + stravaganza.NewBuilder("query"). + WithAttribute(stravaganza.Namespace, mamNamespace). + Build(), + ). + BuildIQ() + + // when + _ = mam.ProcessIQ(context.Background(), iq) + + // then + require.Len(t, respStanzas, 4) // 3 messages + result iq + + require.Equal(t, stravaganza.MessageName, respStanzas[0].Name()) + require.Equal(t, stravaganza.MessageName, respStanzas[1].Name()) + require.Equal(t, stravaganza.MessageName, respStanzas[2].Name()) + require.Equal(t, stravaganza.IQName, respStanzas[3].Name()) + + iqRes := respStanzas[3] + require.Equal(t, stravaganza.ResultType, iqRes.Type()) + + finElem := iqRes.ChildNamespace("fin", mamNamespace) + require.NotNil(t, finElem) + + rsmRes := finElem.ChildNamespace("set", xep0059.RSMNamespace) + require.NotNil(t, rsmRes) + + count := rsmRes.Child("count") + require.NotNil(t, count) + require.Equal(t, "3", count.Text()) + + require.Len(t, stmMock.SetInfoValueCalls(), 1) + require.True(t, IsArchiveRequested(c2sInf)) +} + +func TestMam_Forbidden(t *testing.T) { + routerMock := &routerMock{} + + var respStanzas []stravaganza.Stanza + routerMock.RouteFunc = func(ctx context.Context, stanza stravaganza.Stanza) ([]jid.JID, error) { + respStanzas = append(respStanzas, stanza) + return nil, nil + } + + repMock := &repositoryMock{} + + mam := &Mam{ + rep: repMock, + router: routerMock, + logger: kitlog.NewNopLogger(), + } + + iq, _ := stravaganza.NewIQBuilder(). + WithAttribute(stravaganza.ID, "ortuman1"). + WithAttribute(stravaganza.Type, stravaganza.SetType). + WithAttribute(stravaganza.From, "noelia@jackal.im/chamber"). + WithAttribute(stravaganza.To, "ortuman@jackal.im"). + WithChild( + stravaganza.NewBuilder("query"). + WithAttribute(stravaganza.Namespace, mamNamespace). + Build(), + ). + BuildIQ() + + // when + _ = mam.ProcessIQ(context.Background(), iq) + + require.Len(t, respStanzas, 1) + require.Equal(t, stravaganza.ErrorType, respStanzas[0].Attribute(stravaganza.Type)) +} + +func TestMam_DeleteArchive(t *testing.T) { + // given + var deletedArchiveID string + + repMock := &repositoryMock{} + repMock.DeleteArchiveFunc = func(ctx context.Context, archiveID string) error { + deletedArchiveID = archiveID + return nil + } + + hosts := &hostsMock{} + hosts.IsLocalHostFunc = func(h string) bool { return h == "jackal.im" } + + hk := hook.NewHooks() + mam := &Mam{ + hk: hk, + hosts: hosts, + rep: repMock, + logger: kitlog.NewNopLogger(), + } + _ = mam.Start(context.Background()) + t.Cleanup(func() { + _ = mam.Stop(context.Background()) + }) + + // when + _, err := hk.Run(hook.UserDeleted, &hook.ExecutionContext{ + Info: &hook.UserInfo{ + Username: "ortuman", + }, + Context: context.Background(), + }, + ) + + // then + require.NoError(t, err) + require.Len(t, repMock.DeleteArchiveCalls(), 1) + + require.Equal(t, "ortuman", deletedArchiveID) +} + +func TestMam_FormToFields(t *testing.T) { + tcs := map[string]struct { + form *xep0004.DataForm + filters *archivemodel.Filters + }{ + "by jid": { + form: &xep0004.DataForm{ + Type: xep0004.Submit, + Fields: []xep0004.Field{ + {Var: xep0004.FormType, Type: xep0004.Hidden, Values: []string{mamNamespace}}, + {Var: "with", Values: []string{"juliet@capulet.lit"}}, + }, + }, + filters: &archivemodel.Filters{ + With: "juliet@capulet.lit", + }, + }, + "time received": { + form: &xep0004.DataForm{ + Type: xep0004.Submit, + Fields: []xep0004.Field{ + {Var: xep0004.FormType, Type: xep0004.Hidden, Values: []string{mamNamespace}}, + {Var: "start", Values: []string{"2010-06-07T00:00:00Z"}}, + {Var: "end", Values: []string{"2010-07-07T13:23:54Z"}}, + }, + }, + filters: &archivemodel.Filters{ + Start: timestamppb.New(time.Date(2010, 06, 07, 00, 00, 00, 00, time.UTC)), + End: timestamppb.New(time.Date(2010, 07, 07, 13, 23, 54, 00, time.UTC)), + }, + }, + "after/before id": { + form: &xep0004.DataForm{ + Type: xep0004.Submit, + Fields: []xep0004.Field{ + {Var: xep0004.FormType, Type: xep0004.Hidden, Values: []string{mamNamespace}}, + {Var: "after-id", Values: []string{"28482-98726-73623"}}, + {Var: "before-id", Values: []string{"09af3-cc343-b409f"}}, + }, + }, + filters: &archivemodel.Filters{ + AfterId: "28482-98726-73623", + BeforeId: "09af3-cc343-b409f", + }, + }, + "ids": { + form: &xep0004.DataForm{ + Type: xep0004.Submit, + Fields: []xep0004.Field{ + {Var: xep0004.FormType, Type: xep0004.Hidden, Values: []string{mamNamespace}}, + {Var: "ids", Values: []string{"28482-98726-73623", "09af3-cc343-b409f"}}, + }, + }, + filters: &archivemodel.Filters{ + Ids: []string{"28482-98726-73623", "09af3-cc343-b409f"}, + }, + }, + } + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + filters, err := formToFilters(tc.form) + + require.NoError(t, err) + require.Equal(t, tc.filters.String(), filters.String()) + }) + } +} + +func testMessageStanzaWithParameters(body, from, to string) *stravaganza.Message { + b := stravaganza.NewMessageBuilder() + b.WithAttribute("from", from) + b.WithAttribute("to", to) + b.WithChild( + stravaganza.NewBuilder("body"). + WithText(body). + Build(), + ) + msg, _ := b.BuildMessage() + return msg +} diff --git a/pkg/router/router.go b/pkg/router/router.go index 1944ecf2a..beda06ab4 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -72,7 +72,7 @@ type C2SRouter interface { Unregister(stm stream.C2S) error // LocalStream returns local instance stream. - LocalStream(username, resource string) stream.C2S + LocalStream(username, resource string) (stream.C2S, error) // Start starts C2S router subsystem. Start(ctx context.Context) error diff --git a/pkg/s2s/in.go b/pkg/s2s/in.go index 695c8d202..389c121da 100644 --- a/pkg/s2s/in.go +++ b/pkg/s2s/in.go @@ -17,6 +17,7 @@ package s2s import ( "context" "crypto/tls" + "errors" "sync" "sync/atomic" "time" @@ -401,7 +402,7 @@ func (s *inS2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error { if !ok { return nil } - _, err = s.router.Route(ctx, outIQ) + targets, err := s.router.Route(ctx, outIQ) switch err { case router.ErrResourceNotFound: return s.sendElement(ctx, stanzaerror.E(stanzaerror.ServiceUnavailable, iq).Element()) @@ -412,11 +413,12 @@ func (s *inS2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error { case router.ErrRemoteServerTimeout: return s.sendElement(ctx, stanzaerror.E(stanzaerror.RemoteServerTimeout, iq).Element()) - case nil: + case nil, router.ErrUserNotAvailable: _, err = s.runHook(ctx, hook.S2SInStreamIQRouted, &hook.S2SStreamInfo{ ID: s.ID().String(), Sender: s.sender, Target: s.target, + Targets: targets, Element: iq, }) return err @@ -456,7 +458,7 @@ sendMsg: if !ok { return nil } - _, err = s.router.Route(ctx, outMsg) + targets, err := s.router.Route(ctx, outMsg) switch err { case router.ErrResourceNotFound: // treat the stanza as if it were addressed to @@ -475,19 +477,25 @@ sendMsg: case router.ErrRemoteServerTimeout: return s.sendElement(ctx, stanzaerror.E(stanzaerror.RemoteServerTimeout, message).Element()) - case router.ErrUserNotAvailable: - return s.sendElement(ctx, stanzaerror.E(stanzaerror.ServiceUnavailable, message).Element()) - - case nil: - _, err = s.runHook(ctx, hook.S2SInStreamMessageRouted, &hook.S2SStreamInfo{ + case nil, router.ErrUserNotAvailable: + halted, hErr := s.runHook(ctx, hook.S2SInStreamMessageRouted, &hook.S2SStreamInfo{ ID: s.ID().String(), Sender: s.sender, Target: s.target, + Targets: targets, Element: msg, }) + if halted { + return nil + } + if errors.Is(err, router.ErrUserNotAvailable) { + return s.sendElement(ctx, stanzaerror.E(stanzaerror.ServiceUnavailable, message).Element()) + } + return hErr + + default: return err } - return nil } func (s *inS2S) processPresence(ctx context.Context, presence *stravaganza.Presence) error { @@ -520,13 +528,14 @@ func (s *inS2S) processPresence(ctx context.Context, presence *stravaganza.Prese if !ok { return nil } - _, err = s.router.Route(ctx, outPr) + targets, err := s.router.Route(ctx, outPr) switch err { - case nil: + case nil, router.ErrUserNotAvailable: _, err := s.runHook(ctx, hook.S2SInStreamPresenceRouted, &hook.S2SStreamInfo{ ID: s.ID().String(), Sender: s.sender, Target: s.target, + Targets: targets, Element: presence, }) return err diff --git a/pkg/storage/boltdb/archive.go b/pkg/storage/boltdb/archive.go new file mode 100644 index 000000000..093235dc7 --- /dev/null +++ b/pkg/storage/boltdb/archive.go @@ -0,0 +1,274 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package boltdb + +import ( + "context" + "fmt" + + "github.com/golang/protobuf/proto" + "github.com/jackal-xmpp/stravaganza/jid" + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + bolt "go.etcd.io/bbolt" +) + +const archiveStampFormat = "2006-01-02T15:04:05Z" + +type boltDBArchiveRep struct { + tx *bolt.Tx +} + +func newArchiveRep(tx *bolt.Tx) *boltDBArchiveRep { + return &boltDBArchiveRep{tx: tx} +} + +func (r *boltDBArchiveRep) InsertArchiveMessage(_ context.Context, message *archivemodel.Message) error { + op := insertSeqOp{ + tx: r.tx, + bucket: archiveBucket(message.ArchiveId), + obj: message, + } + return op.do() +} + +func (r *boltDBArchiveRep) FetchArchiveMetadata(_ context.Context, archiveID string) (metadata *archivemodel.Metadata, err error) { + bucketID := archiveBucket(archiveID) + + b := r.tx.Bucket([]byte(bucketID)) + if b == nil { + return nil, nil + } + var retVal archivemodel.Metadata + + c := b.Cursor() + _, val := c.First() + + var msg archivemodel.Message + if err := proto.Unmarshal(val, &msg); err != nil { + return nil, err + } + retVal.StartId = msg.Id + retVal.StartTimestamp = msg.Stamp.AsTime().UTC().Format(archiveStampFormat) + + _, val = c.Last() + if err := proto.Unmarshal(val, &msg); err != nil { + return nil, err + } + retVal.EndId = msg.Id + retVal.EndTimestamp = msg.Stamp.AsTime().UTC().Format(archiveStampFormat) + + return &retVal, nil +} + +func (r *boltDBArchiveRep) FetchArchiveMessages(_ context.Context, f *archivemodel.Filters, archiveID string) ([]*archivemodel.Message, error) { + var retVal []*archivemodel.Message + + op := iterKeysOp{ + tx: r.tx, + bucket: archiveBucket(archiveID), + iterFn: func(k, b []byte) error { + var msg archivemodel.Message + if err := proto.Unmarshal(b, &msg); err != nil { + return err + } + retVal = append(retVal, &msg) + return nil + }, + } + if err := op.do(); err != nil { + return nil, err + } + return applyFilters(retVal, f) +} + +func (r *boltDBArchiveRep) DeleteArchiveOldestMessages(_ context.Context, archiveID string, maxElements int) error { + bucketID := archiveBucket(archiveID) + + b := r.tx.Bucket([]byte(bucketID)) + if b == nil { + return nil + } + // count items + var count int + + c := b.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + count++ + } + if count < maxElements { + return nil + } + // store old value keys + var oldKeys [][]byte + + c = b.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + if count <= maxElements { + break + } + count-- + oldKeys = append(oldKeys, k) + } + // delete old values + for _, k := range oldKeys { + if err := b.Delete(k); err != nil { + return err + } + } + return nil +} + +func (r *boltDBArchiveRep) DeleteArchive(_ context.Context, archiveID string) error { + op := delBucketOp{ + tx: r.tx, + bucket: archiveBucket(archiveID), + } + return op.do() +} + +func archiveBucket(archiveID string) string { + return fmt.Sprintf("archive:%s", archiveID) +} + +// InsertArchiveMessage inserts a new message element into an archive queue. +func (r *Repository) InsertArchiveMessage(ctx context.Context, message *archivemodel.Message) error { + return r.db.Update(func(tx *bolt.Tx) error { + return newArchiveRep(tx).InsertArchiveMessage(ctx, message) + }) +} + +// FetchArchiveMetadata returns the metadata value associated to an archive. +func (r *Repository) FetchArchiveMetadata(ctx context.Context, archiveID string) (metadata *archivemodel.Metadata, err error) { + err = r.db.View(func(tx *bolt.Tx) error { + metadata, err = newArchiveRep(tx).FetchArchiveMetadata(ctx, archiveID) + return err + }) + return +} + +// FetchArchiveMessages fetches archive asscociated messages applying the passed f filters. +func (r *Repository) FetchArchiveMessages(ctx context.Context, f *archivemodel.Filters, archiveID string) (messages []*archivemodel.Message, err error) { + err = r.db.View(func(tx *bolt.Tx) error { + messages, err = newArchiveRep(tx).FetchArchiveMessages(ctx, f, archiveID) + return err + }) + return +} + +// DeleteArchiveOldestMessages trims archive oldest messages up to a maxElements total count. +func (r *Repository) DeleteArchiveOldestMessages(ctx context.Context, archiveID string, maxElements int) error { + return r.db.Update(func(tx *bolt.Tx) error { + return newArchiveRep(tx).DeleteArchiveOldestMessages(ctx, archiveID, maxElements) + }) +} + +// DeleteArchive clears an archive queue. +func (r *Repository) DeleteArchive(ctx context.Context, archiveID string) error { + return r.db.Update(func(tx *bolt.Tx) error { + return newArchiveRep(tx).DeleteArchive(ctx, archiveID) + }) +} + +func applyFilters(messages []*archivemodel.Message, f *archivemodel.Filters) ([]*archivemodel.Message, error) { + retVal := messages + + // filtering by JID + if len(f.With) > 0 { + jd, err := jid.NewWithString(f.With, false) + if err != nil { + return nil, err + } + var filtered []*archivemodel.Message + for _, msg := range retVal { + var matches bool + + switch { + case jd.IsFull(): + matches = msg.FromJid == jd.String() || msg.ToJid == jd.String() + + default: + fromJID, _ := jid.NewWithString(msg.FromJid, true) + toJID, _ := jid.NewWithString(msg.ToJid, true) + matches = fromJID.MatchesWithOptions(jd, jid.MatchesBare) || toJID.MatchesWithOptions(jd, jid.MatchesBare) + } + if matches { + filtered = append(filtered, msg) + } + } + retVal = filtered + } + + // filtering by id + if len(f.Ids) > 0 { + idsMap := map[string]struct{}{} + for _, id := range f.Ids { + idsMap[id] = struct{}{} + } + var filtered []*archivemodel.Message + for _, msg := range retVal { + _, ok := idsMap[msg.Id] + if !ok { + continue + } + filtered = append(filtered, msg) + } + retVal = filtered + + } else { + if len(f.BeforeId) > 0 { + for i, msg := range retVal { + if msg.Id != f.BeforeId { + continue + } + retVal = retVal[:i] + break + } + } + if len(f.AfterId) > 0 { + for i, msg := range retVal { + if msg.Id != f.AfterId { + continue + } + retVal = retVal[i+1:] + break + } + } + } + + // filtering by timestamp + if f.Start != nil { + startTm := f.Start.AsTime() + for i, msg := range retVal { + stampTm := msg.Stamp.AsTime() + if !stampTm.After(startTm) { + continue + } + retVal = retVal[i:] + break + } + } + if f.End != nil { + endTm := f.End.AsTime() + for i, msg := range retVal { + stampTm := msg.Stamp.AsTime() + if stampTm.Before(endTm) { + continue + } + retVal = retVal[:i] + break + } + } + return retVal, nil +} diff --git a/pkg/storage/boltdb/archive_test.go b/pkg/storage/boltdb/archive_test.go new file mode 100644 index 000000000..620826715 --- /dev/null +++ b/pkg/storage/boltdb/archive_test.go @@ -0,0 +1,292 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package boltdb + +import ( + "context" + "testing" + "time" + + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + "github.com/stretchr/testify/require" + bolt "go.etcd.io/bbolt" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestBoltDB_InsertArchiveMessage(t *testing.T) { + t.Parallel() + + db := setupDB(t) + t.Cleanup(func() { cleanUp(db) }) + + err := db.Update(func(tx *bolt.Tx) error { + rep := boltDBArchiveRep{tx: tx} + + m0 := testMessageStanza() + + err := rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Message: m0.Proto(), + }) + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) +} + +func TestBoltDB_FetchArchiveMetadata(t *testing.T) { + t.Parallel() + + db := setupDB(t) + t.Cleanup(func() { cleanUp(db) }) + + err := db.Update(func(tx *bolt.Tx) error { + rep := boltDBArchiveRep{tx: tx} + + m0 := testMessageStanza() + m1 := testMessageStanza() + m2 := testMessageStanza() + + now0 := time.Now() + now1 := now0.Add(time.Hour) + now2 := now1.Add(time.Hour) + + err := rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Id: "id0", + Message: m0.Proto(), + Stamp: timestamppb.New(now0), + }) + require.NoError(t, err) + + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Id: "id1", + Message: m1.Proto(), + Stamp: timestamppb.New(now1), + }) + require.NoError(t, err) + + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Id: "id2", + Message: m2.Proto(), + Stamp: timestamppb.New(now2), + }) + require.NoError(t, err) + + metadata, err := rep.FetchArchiveMetadata(context.Background(), "a1234") + require.NoError(t, err) + + require.Equal(t, "id0", metadata.StartId) + require.Equal(t, now0.UTC().Format(archiveStampFormat), metadata.StartTimestamp) + require.Equal(t, "id2", metadata.EndId) + require.Equal(t, now2.UTC().Format(archiveStampFormat), metadata.EndTimestamp) + + return nil + }) + require.NoError(t, err) +} + +func TestBoltDB_DeleteArchive(t *testing.T) { + t.Parallel() + + db := setupDB(t) + t.Cleanup(func() { cleanUp(db) }) + + err := db.Update(func(tx *bolt.Tx) error { + rep := boltDBArchiveRep{tx: tx} + + m0 := testMessageStanza() + m1 := testMessageStanza() + m2 := testMessageStanza() + + err := rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ArchiveId: "a1234", Message: m0.Proto()}) + require.NoError(t, err) + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ArchiveId: "a1234", Message: m1.Proto()}) + require.NoError(t, err) + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ArchiveId: "a1234", Message: m2.Proto()}) + require.NoError(t, err) + + require.Equal(t, 3, countBucketElements(t, tx, archiveBucket("a1234"))) + + require.NoError(t, rep.DeleteArchive(context.Background(), "a1234")) + + require.Equal(t, 0, countBucketElements(t, tx, archiveBucket("a1234"))) + + return nil + }) + require.NoError(t, err) +} + +func TestBoltDB_DeleteArchiveOldestMessages(t *testing.T) { + t.Parallel() + + db := setupDB(t) + t.Cleanup(func() { cleanUp(db) }) + + err := db.Update(func(tx *bolt.Tx) error { + rep := boltDBArchiveRep{tx: tx} + + m0 := testMessageStanza() + m1 := testMessageStanza() + m2 := testMessageStanza() + + err := rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Message: m0.Proto(), + }) + require.NoError(t, err) + + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Message: m1.Proto(), + }) + require.NoError(t, err) + + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Message: m2.Proto(), + }) + require.NoError(t, err) + + require.Equal(t, 3, countBucketElements(t, tx, archiveBucket("a1234"))) + + err = rep.DeleteArchiveOldestMessages(context.Background(), "a1234", 2) + require.NoError(t, err) + + require.Equal(t, 2, countBucketElements(t, tx, archiveBucket("a1234"))) + + return nil + }) + require.NoError(t, err) +} + +func TestBoltDB_FetchArchiveMessages(t *testing.T) { + tcs := map[string]struct { + filters *archivemodel.Filters + expectedResultIDs []string + }{ + "filtering by jid": { + filters: &archivemodel.Filters{ + With: "noelia@jackal.im", + }, + expectedResultIDs: []string{"m0", "m1", "m3"}, + }, + "filtering by full jid": { + filters: &archivemodel.Filters{ + With: "ortuman@jackal.im/firstwitch", + }, + expectedResultIDs: []string{"m2"}, + }, + "filtering by ids": { + filters: &archivemodel.Filters{ + Ids: []string{"m0", "m2"}, + }, + expectedResultIDs: []string{"m0", "m2"}, + }, + "filtering by after id": { + filters: &archivemodel.Filters{ + AfterId: "m1", + }, + expectedResultIDs: []string{"m2", "m3"}, + }, + "filtering by before id": { + filters: &archivemodel.Filters{ + BeforeId: "m2", + }, + expectedResultIDs: []string{"m0", "m1"}, + }, + "filtering by start": { + filters: &archivemodel.Filters{ + Start: timestamppb.New(time.Date(2022, 01, 02, 00, 00, 00, 00, time.UTC)), + }, + expectedResultIDs: []string{"m2", "m3"}, + }, + "filtering by end": { + filters: &archivemodel.Filters{ + End: timestamppb.New(time.Date(2022, 01, 02, 00, 00, 00, 00, time.UTC)), + }, + expectedResultIDs: []string{"m0"}, + }, + } + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + db := setupDB(t) + t.Cleanup(func() { cleanUp(db) }) + + err := db.Update(func(tx *bolt.Tx) error { + rep := boltDBArchiveRep{tx: tx} + + m0 := testMessageStanzaWithParameters("b0", "noelia@jackal.im/yard", "ortuman@jackal.im/chamber") + m1 := testMessageStanzaWithParameters("b1", "noelia@jackal.im/orchard", "ortuman@jackal.im/balcony") + m2 := testMessageStanzaWithParameters("b2", "witch1@jackal.im/yard", "ortuman@jackal.im/firstwitch") + m3 := testMessageStanzaWithParameters("b3", "witch2@jackal.im/yard", "noelia@jackal.im/garden") + + err := rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Id: "m0", + FromJid: "noelia@jackal.im/yard", + ToJid: "ortuman@jackal.im/chamber", + Stamp: timestamppb.New(time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)), + Message: m0.Proto(), + }) + require.NoError(t, err) + + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Id: "m1", + FromJid: "noelia@jackal.im/orchard", + ToJid: "ortuman@jackal.im/balcony", + Stamp: timestamppb.New(time.Date(2022, 01, 02, 00, 00, 00, 00, time.UTC)), + Message: m1.Proto(), + }) + require.NoError(t, err) + + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Id: "m2", + FromJid: "witch1@jackal.im/yard", + ToJid: "ortuman@jackal.im/firstwitch", + Stamp: timestamppb.New(time.Date(2022, 01, 03, 00, 00, 00, 00, time.UTC)), + Message: m2.Proto(), + }) + require.NoError(t, err) + + err = rep.InsertArchiveMessage(context.Background(), &archivemodel.Message{ + ArchiveId: "a1234", + Id: "m3", + FromJid: "witch2@jackal.im/yard", + ToJid: "noelia@jackal.im/garden", + Stamp: timestamppb.New(time.Date(2022, 01, 04, 00, 00, 00, 00, time.UTC)), + Message: m3.Proto(), + }) + require.NoError(t, err) + + messages, err := rep.FetchArchiveMessages(context.Background(), tc.filters, "a1234") + require.NoError(t, err) + + var resultIDs []string + for _, msg := range messages { + resultIDs = append(resultIDs, msg.Id) + } + require.ElementsMatch(t, tc.expectedResultIDs, resultIDs) + return nil + }) + require.NoError(t, err) + }) + } +} diff --git a/pkg/storage/boltdb/offline_test.go b/pkg/storage/boltdb/offline_test.go index cc9955ad3..dff8cbbc0 100644 --- a/pkg/storage/boltdb/offline_test.go +++ b/pkg/storage/boltdb/offline_test.go @@ -31,8 +31,8 @@ func TestBoltDB_InsertAndFetchOfflineMessages(t *testing.T) { err := db.Update(func(tx *bolt.Tx) error { rep := boltDBOfflineRep{tx: tx} - m0 := testMessageStanza("message 0") - m1 := testMessageStanza("message 1") + m0 := testMessageStanza() + m1 := testMessageStanza() err := rep.InsertOfflineMessage(context.Background(), m0, "ortuman") require.NoError(t, err) @@ -45,8 +45,6 @@ func TestBoltDB_InsertAndFetchOfflineMessages(t *testing.T) { require.Len(t, messages, 2) - require.Equal(t, "message 0", messages[0].Child("body").Text()) - require.Equal(t, "message 1", messages[1].Child("body").Text()) return nil }) require.NoError(t, err) @@ -61,8 +59,8 @@ func TestBoltDB_CountOfflineMessages(t *testing.T) { err := db.Update(func(tx *bolt.Tx) error { rep := boltDBOfflineRep{tx: tx} - m0 := testMessageStanza("message 0") - m1 := testMessageStanza("message 1") + m0 := testMessageStanza() + m1 := testMessageStanza() err := rep.InsertOfflineMessage(context.Background(), m0, "ortuman") require.NoError(t, err) @@ -88,8 +86,8 @@ func TestBoltDB_DeleteOfflineMessages(t *testing.T) { err := db.Update(func(tx *bolt.Tx) error { rep := boltDBOfflineRep{tx: tx} - m0 := testMessageStanza("message 0") - m1 := testMessageStanza("message 1") + m0 := testMessageStanza() + m1 := testMessageStanza() err := rep.InsertOfflineMessage(context.Background(), m0, "ortuman") require.NoError(t, err) diff --git a/pkg/storage/boltdb/op.go b/pkg/storage/boltdb/op.go index f92bd3257..dac939cbf 100644 --- a/pkg/storage/boltdb/op.go +++ b/pkg/storage/boltdb/op.go @@ -60,8 +60,7 @@ func (op insertSeqOp) do() error { if err != nil { return err } - k := fmt.Sprintf("%d", seq) - return b.Put([]byte(k), p) + return b.Put([]byte(fmt.Sprintf("%d", seq)), p) } type delBucketOp struct { diff --git a/pkg/storage/boltdb/repository.go b/pkg/storage/boltdb/repository.go index aaa9faa58..dedf3dcaf 100644 --- a/pkg/storage/boltdb/repository.go +++ b/pkg/storage/boltdb/repository.go @@ -39,6 +39,7 @@ type Repository struct { repository.Private repository.Roster repository.VCard + repository.Archive repository.Locker cfg Config diff --git a/pkg/storage/boltdb/repository_test.go b/pkg/storage/boltdb/repository_test.go new file mode 100644 index 000000000..6b576966c --- /dev/null +++ b/pkg/storage/boltdb/repository_test.go @@ -0,0 +1,36 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package boltdb + +import ( + "testing" + + bolt "go.etcd.io/bbolt" +) + +func countBucketElements(t *testing.T, tx *bolt.Tx, bucket string) int { + t.Helper() + + b := tx.Bucket([]byte(bucket)) + if b == nil { + return 0 + } + var count int + c := b.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + count++ + } + return count +} diff --git a/pkg/storage/boltdb/tx.go b/pkg/storage/boltdb/tx.go index 2d7cbc29e..8988ffeb1 100644 --- a/pkg/storage/boltdb/tx.go +++ b/pkg/storage/boltdb/tx.go @@ -28,6 +28,7 @@ type repTx struct { repository.Private repository.Roster repository.VCard + repository.Archive repository.Locker } @@ -41,6 +42,7 @@ func newRepTx(tx *bolt.Tx) *repTx { Private: newPrivateRep(tx), Roster: newRosterRep(tx), VCard: newVCardRep(tx), + Archive: newArchiveRep(tx), Locker: newLockerRep(), } } diff --git a/pkg/storage/boltdb/util_test.go b/pkg/storage/boltdb/util_test.go index 02a9ef18d..04339c9f8 100644 --- a/pkg/storage/boltdb/util_test.go +++ b/pkg/storage/boltdb/util_test.go @@ -20,7 +20,6 @@ import ( "testing" "github.com/jackal-xmpp/stravaganza" - bolt "go.etcd.io/bbolt" ) @@ -41,10 +40,23 @@ func cleanUp(db *bolt.DB) { _ = os.RemoveAll(dbPath) } -func testMessageStanza(body string) *stravaganza.Message { +func testMessageStanza() *stravaganza.Message { b := stravaganza.NewMessageBuilder() b.WithAttribute("from", "noelia@jackal.im/yard") b.WithAttribute("to", "ortuman@jackal.im/balcony") + b.WithChild( + stravaganza.NewBuilder("body"). + WithText("Call me but love, and I'll be new baptized; Henceforth I never will be Romeo."). + Build(), + ) + msg, _ := b.BuildMessage() + return msg +} + +func testMessageStanzaWithParameters(body, from, to string) *stravaganza.Message { + b := stravaganza.NewMessageBuilder() + b.WithAttribute("from", from) + b.WithAttribute("to", to) b.WithChild( stravaganza.NewBuilder("body"). WithText(body). diff --git a/pkg/storage/cached/cached.go b/pkg/storage/cached/cached.go index 5a150b395..798fe99a4 100644 --- a/pkg/storage/cached/cached.go +++ b/pkg/storage/cached/cached.go @@ -68,6 +68,7 @@ type CachedRepository struct { repository.Private repository.Roster repository.VCard + repository.Archive repository.Locker rep repository.Repository @@ -91,6 +92,7 @@ func New(cfg Config, rep repository.Repository, logger kitlog.Logger) (repositor BlockList: &cachedBlockListRep{c: c, rep: rep, logger: logger}, Roster: &cachedRosterRep{c: c, rep: rep, logger: logger}, VCard: &cachedVCardRep{c: c, rep: rep, logger: logger}, + Archive: rep, Offline: rep, Locker: rep, rep: rep, diff --git a/pkg/storage/cached/tx.go b/pkg/storage/cached/tx.go index 5c2084a01..230af7c7d 100644 --- a/pkg/storage/cached/tx.go +++ b/pkg/storage/cached/tx.go @@ -27,6 +27,7 @@ type cachedTx struct { repository.Private repository.Roster repository.VCard + repository.Archive repository.Locker } @@ -39,6 +40,7 @@ func newCacheTx(c Cache, tx repository.Transaction) *cachedTx { BlockList: &cachedBlockListRep{c: c, rep: tx}, Roster: &cachedRosterRep{c: c, rep: tx}, VCard: &cachedVCardRep{c: c, rep: tx}, + Archive: tx, Offline: tx, Locker: tx, } diff --git a/pkg/storage/measured/archive.go b/pkg/storage/measured/archive.go new file mode 100644 index 000000000..c44b4fd66 --- /dev/null +++ b/pkg/storage/measured/archive.go @@ -0,0 +1,63 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package measuredrepository + +import ( + "context" + "time" + + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + "github.com/ortuman/jackal/pkg/storage/repository" +) + +type measuredArchiveRep struct { + rep repository.Archive + inTx bool +} + +func (m *measuredArchiveRep) InsertArchiveMessage(ctx context.Context, message *archivemodel.Message) error { + t0 := time.Now() + err := m.rep.InsertArchiveMessage(ctx, message) + reportOpMetric(upsertOp, time.Since(t0).Seconds(), err == nil, m.inTx) + return err +} + +func (m *measuredArchiveRep) FetchArchiveMetadata(ctx context.Context, archiveID string) (metadata *archivemodel.Metadata, err error) { + t0 := time.Now() + metadata, err = m.rep.FetchArchiveMetadata(ctx, archiveID) + reportOpMetric(fetchOp, time.Since(t0).Seconds(), err == nil, m.inTx) + return +} + +func (m *measuredArchiveRep) FetchArchiveMessages(ctx context.Context, f *archivemodel.Filters, archiveID string) (messages []*archivemodel.Message, err error) { + t0 := time.Now() + messages, err = m.rep.FetchArchiveMessages(ctx, f, archiveID) + reportOpMetric(fetchOp, time.Since(t0).Seconds(), err == nil, m.inTx) + return +} + +func (m *measuredArchiveRep) DeleteArchiveOldestMessages(ctx context.Context, archiveID string, maxElements int) error { + t0 := time.Now() + err := m.rep.DeleteArchiveOldestMessages(ctx, archiveID, maxElements) + reportOpMetric(deleteOp, time.Since(t0).Seconds(), err == nil, m.inTx) + return err +} + +func (m *measuredArchiveRep) DeleteArchive(ctx context.Context, archiveID string) error { + t0 := time.Now() + err := m.rep.DeleteArchive(ctx, archiveID) + reportOpMetric(deleteOp, time.Since(t0).Seconds(), err == nil, m.inTx) + return err +} diff --git a/pkg/storage/measured/archive_test.go b/pkg/storage/measured/archive_test.go new file mode 100644 index 000000000..5cbb98912 --- /dev/null +++ b/pkg/storage/measured/archive_test.go @@ -0,0 +1,99 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package measuredrepository + +import ( + "context" + "testing" + + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + "github.com/stretchr/testify/require" +) + +func TestMeasuredArchiveRep_InsertArchiveMessage(t *testing.T) { + // given + repMock := &repositoryMock{} + repMock.InsertArchiveMessageFunc = func(ctx context.Context, message *archivemodel.Message) error { + return nil + } + m := &measuredArchiveRep{rep: repMock} + + // when + _ = m.InsertArchiveMessage(context.Background(), &archivemodel.Message{ArchiveId: "a1234"}) + + // then + require.Len(t, repMock.InsertArchiveMessageCalls(), 1) +} + +func TestMeasuredArchiveRep_FetchArchiveMetadata(t *testing.T) { + // given + repMock := &repositoryMock{} + repMock.FetchArchiveMetadataFunc = func(ctx context.Context, archiveID string) (*archivemodel.Metadata, error) { + return nil, nil + } + m := &measuredArchiveRep{rep: repMock} + + // when + _, _ = m.FetchArchiveMetadata(context.Background(), "a1234") + + // then + require.Len(t, repMock.FetchArchiveMetadataCalls(), 1) +} + +func TestMeasuredArchiveRep_FetchArchiveMessages(t *testing.T) { + // given + repMock := &repositoryMock{} + repMock.FetchArchiveMessagesFunc = func(ctx context.Context, f *archivemodel.Filters, archiveID string) ([]*archivemodel.Message, error) { + return nil, nil + } + m := &measuredArchiveRep{rep: repMock} + + // when + _, _ = m.FetchArchiveMessages(context.Background(), &archivemodel.Filters{}, "a1234") + + // then + require.Len(t, repMock.FetchArchiveMessagesCalls(), 1) +} + +func TestMeasuredArchiveRep_DeleteArchiveOldestMessages(t *testing.T) { + // given + repMock := &repositoryMock{} + repMock.DeleteArchiveOldestMessagesFunc = func(ctx context.Context, archiveID string, maxElements int) error { + return nil + } + m := &measuredArchiveRep{rep: repMock} + + // when + err := m.DeleteArchiveOldestMessages(context.Background(), "a1234", 10) + + // then + require.Len(t, repMock.DeleteArchiveOldestMessagesCalls(), 1) + require.NoError(t, err) +} + +func TestMeasuredArchiveRep_DeleteArchive(t *testing.T) { + // given + repMock := &repositoryMock{} + repMock.DeleteArchiveFunc = func(ctx context.Context, archiveId string) error { + return nil + } + m := &measuredArchiveRep{rep: repMock} + + // when + _ = m.DeleteArchive(context.Background(), "a1234") + + // then + require.Len(t, repMock.DeleteArchiveCalls(), 1) +} diff --git a/pkg/storage/measured/measured.go b/pkg/storage/measured/measured.go index 01a7b3d1b..914613f24 100644 --- a/pkg/storage/measured/measured.go +++ b/pkg/storage/measured/measured.go @@ -40,6 +40,7 @@ type Measured struct { measuredPrivateRep measuredRosterRep measuredVCardRep + measuredArchiveRep measuredLocker rep repository.Repository } @@ -55,6 +56,7 @@ func New(rep repository.Repository) repository.Repository { measuredPrivateRep: measuredPrivateRep{rep: rep}, measuredRosterRep: measuredRosterRep{rep: rep}, measuredVCardRep: measuredVCardRep{rep: rep}, + measuredArchiveRep: measuredArchiveRep{rep: rep}, measuredLocker: measuredLocker{rep: rep}, rep: rep, } diff --git a/pkg/storage/measured/tx.go b/pkg/storage/measured/tx.go index 992b0379d..316287f56 100644 --- a/pkg/storage/measured/tx.go +++ b/pkg/storage/measured/tx.go @@ -25,6 +25,7 @@ type measuredTx struct { repository.Private repository.Roster repository.VCard + repository.Archive repository.Locker } @@ -38,6 +39,7 @@ func newMeasuredTx(tx repository.Transaction) *measuredTx { Private: &measuredPrivateRep{rep: tx, inTx: true}, Roster: &measuredRosterRep{rep: tx, inTx: true}, VCard: &measuredVCardRep{rep: tx, inTx: true}, + Archive: &measuredArchiveRep{rep: tx, inTx: true}, Locker: &measuredLocker{rep: tx, inTx: true}, } } diff --git a/pkg/storage/pgsql/archive.go b/pkg/storage/pgsql/archive.go new file mode 100644 index 000000000..a27cd148d --- /dev/null +++ b/pkg/storage/pgsql/archive.go @@ -0,0 +1,214 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsqlrepository + +import ( + "context" + "database/sql" + "time" + + sq "github.com/Masterminds/squirrel" + kitlog "github.com/go-kit/log" + "github.com/golang/protobuf/proto" + "github.com/jackal-xmpp/stravaganza" + "github.com/jackal-xmpp/stravaganza/jid" + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + archiveTableName = "archives" + + archiveStampFormat = "2006-01-02T15:04:05Z" +) + +type pgSQLArchiveRep struct { + conn conn + logger kitlog.Logger +} + +func (r *pgSQLArchiveRep) InsertArchiveMessage(ctx context.Context, message *archivemodel.Message) error { + b, err := proto.Marshal(message.Message) + if err != nil { + return err + } + fromJID, _ := jid.NewWithString(message.FromJid, true) + toJID, _ := jid.NewWithString(message.ToJid, true) + + q := sq.Insert(archiveTableName). + Prefix(noLoadBalancePrefix). + Columns("archive_id", "id", `"from"`, "from_bare", `"to"`, "to_bare", "message"). + Values( + message.ArchiveId, + message.Id, + fromJID.String(), + fromJID.ToBareJID().String(), + toJID.String(), + toJID.ToBareJID().String(), + b, + ) + + _, err = q.RunWith(r.conn).ExecContext(ctx) + return err +} + +func (r *pgSQLArchiveRep) FetchArchiveMetadata(ctx context.Context, archiveID string) (*archivemodel.Metadata, error) { + fromExpr := `FROM ` + fromExpr += `(SELECT "id", created_at FROM archives WHERE serial = (SELECT MIN(serial) FROM archives WHERE archive_id = $1)) AS min,` + fromExpr += `(SELECT "id", created_at FROM archives WHERE serial = (SELECT MAX(serial) FROM archives WHERE archive_id = $1)) AS max` + + q := sq.Select("min.id, min.created_at, max.id, max.created_at").Suffix(fromExpr, archiveID) + + var start, end time.Time + var metadata archivemodel.Metadata + + err := q.RunWith(r.conn). + QueryRowContext(ctx). + Scan( + &metadata.StartId, + &start, + &metadata.EndId, + &end, + ) + + switch err { + case nil: + metadata.StartTimestamp = start.UTC().Format(archiveStampFormat) + metadata.EndTimestamp = end.UTC().Format(archiveStampFormat) + return &metadata, nil + + case sql.ErrNoRows: + return nil, nil + + default: + return nil, err + } +} + +func (r *pgSQLArchiveRep) FetchArchiveMessages(ctx context.Context, f *archivemodel.Filters, archiveID string) ([]*archivemodel.Message, error) { + q := sq.Select("id", `"from"`, `"to"`, "message", "created_at"). + From(archiveTableName). + Where(filtersToPred(f, archiveID)). + OrderBy("created_at"). + PlaceholderFormat(sq.Dollar) + + rows, err := q.RunWith(r.conn).QueryContext(ctx) + if err != nil { + return nil, err + } + defer closeRows(rows, r.logger) + + retVal, err := scanArchiveMessages(rows, archiveID) + if err != nil { + return nil, err + } + return retVal, err +} + +func (r *pgSQLArchiveRep) DeleteArchiveOldestMessages(ctx context.Context, archiveID string, maxElements int) error { + q := sq.Delete(archiveTableName). + Prefix(noLoadBalancePrefix). + Where(sq.And{ + sq.Eq{"archive_id": archiveID}, + sq.Expr(`"id" NOT IN (SELECT "id" FROM archives WHERE archive_id = $2 ORDER BY created_at DESC LIMIT $3 OFFSET 0)`, archiveID, maxElements), + }) + _, err := q.RunWith(r.conn).ExecContext(ctx) + return err +} + +func (r *pgSQLArchiveRep) DeleteArchive(ctx context.Context, archiveID string) error { + q := sq.Delete(archiveTableName). + Prefix(noLoadBalancePrefix). + Where(sq.Eq{"archive_id": archiveID}) + _, err := q.RunWith(r.conn).ExecContext(ctx) + return err +} + +func filtersToPred(f *archivemodel.Filters, archiveID string) (interface{}, error) { + pred := sq.And{ + sq.Eq{"archive_id": archiveID}, + } + // filtering by JID + if len(f.With) > 0 { + jd, err := jid.NewWithString(f.With, false) + if err != nil { + return nil, err + } + switch { + case jd.IsFull(): + pred = append(pred, sq.Expr(`("to" = ? OR "from" = ?)`, jd.String(), jd.String())) + + default: + pred = append(pred, sq.Expr(`(to_bare = ? OR from_bare = ?)`, jd.String(), jd.String())) + } + } + + // filtering by id + if len(f.Ids) > 0 { + pred = append(pred, sq.Eq{"id": f.Ids}) + } else { + if len(f.BeforeId) > 0 { + pred = append(pred, sq.Expr(`(serial < (SELECT serial FROM archives WHERE "id" = ? AND archive_id = ?))`, f.BeforeId, archiveID)) + } + if len(f.AfterId) > 0 { + pred = append(pred, sq.Expr(`(serial > (SELECT serial FROM archives WHERE "id" = ? AND archive_id = ?))`, f.AfterId, archiveID)) + } + } + + // filtering by timestamp + if f.Start != nil { + pred = append(pred, sq.Expr("EXTRACT(epoch FROM created_at) > ?", f.Start.GetSeconds())) + } + if f.End != nil { + pred = append(pred, sq.Expr("EXTRACT(epoch FROM created_at) < ?", f.End.GetSeconds())) + } + return pred, nil +} + +func scanArchiveMessages(scanner rowsScanner, archiveID string) ([]*archivemodel.Message, error) { + var ret []*archivemodel.Message + for scanner.Next() { + msg, err := scanArchiveMessage(scanner, archiveID) + if err != nil { + return nil, err + } + ret = append(ret, msg) + } + return ret, nil +} + +func scanArchiveMessage(scanner rowsScanner, archiveID string) (*archivemodel.Message, error) { + var ret archivemodel.Message + + var b []byte + var tm time.Time + + if err := scanner.Scan(&ret.Id, &ret.FromJid, &ret.ToJid, &b, &tm); err != nil { + return nil, err + } + sb, err := stravaganza.NewBuilderFromBinary(b) + if err != nil { + return nil, err + } + msg, err := sb.BuildMessage() + if err != nil { + return nil, err + } + ret.ArchiveId = archiveID + ret.Message = msg.Proto() + ret.Stamp = timestamppb.New(tm) + + return &ret, nil +} diff --git a/pkg/storage/pgsql/archive_test.go b/pkg/storage/pgsql/archive_test.go new file mode 100644 index 000000000..b053941c2 --- /dev/null +++ b/pkg/storage/pgsql/archive_test.go @@ -0,0 +1,217 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsqlrepository + +import ( + "context" + "database/sql/driver" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/golang/protobuf/proto" + "github.com/jackal-xmpp/stravaganza" + archivemodel "github.com/ortuman/jackal/pkg/model/archive" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestPgSQLArchive_InsertArchiveMessage(t *testing.T) { + // given + b := stravaganza.NewMessageBuilder() + b.WithAttribute("from", "noelia@jackal.im/yard") + b.WithAttribute("to", "ortuman@jackal.im/balcony") + b.WithChild( + stravaganza.NewBuilder("body"). + WithText("I'll give thee a wind."). + Build(), + ) + msg, _ := b.BuildMessage() + + aMsg := &archivemodel.Message{ + ArchiveId: "ortuman", + Id: "id1234", + FromJid: "ortuman@jackal.im/local", + ToJid: "ortuman@jabber.org/remote", + Message: msg.Proto(), + } + msgBytes, _ := proto.Marshal(aMsg.Message) + + s, mock := newArchiveMock() + mock.ExpectExec(`INSERT INTO archives \(archive_id,id,"from",from_bare,"to",to_bare,message\) VALUES \(\$1,\$2,\$3,\$4,\$5,\$6,\$7\)`). + WithArgs("ortuman", "id1234", "ortuman@jackal.im/local", "ortuman@jackal.im", "ortuman@jabber.org/remote", "ortuman@jabber.org", msgBytes). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // when + err := s.InsertArchiveMessage(context.Background(), aMsg) + + // then + require.Nil(t, err) + require.Nil(t, mock.ExpectationsWereMet()) +} + +func TestPgSQLArchive_FetchArchiveMetadata(t *testing.T) { + minT := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC) + maxT := time.Date(2022, 12, 12, 00, 00, 00, 00, time.UTC) + + // given + s, mock := newArchiveMock() + mock.ExpectQuery(`SELECT min.id, min.created_at, max.id, max.created_at FROM \(SELECT "id", created_at FROM archives WHERE serial = \(SELECT MIN\(serial\) FROM archives WHERE archive_id = \$1\)\) AS min,\(SELECT "id", created_at FROM archives WHERE serial = \(SELECT MAX\(serial\) FROM archives WHERE archive_id = \$1\)\) AS max`). + WithArgs("ortuman"). + WillReturnRows( + sqlmock.NewRows([]string{"min.id", "min.created_at", "max.id", "max.created_at"}).AddRow("YWxwaGEg", minT, "b21lZ2Eg", maxT), + ) + + // when + metadata, err := s.FetchArchiveMetadata(context.Background(), "ortuman") + + // then + require.Nil(t, err) + require.NotNil(t, metadata) + + require.Equal(t, "YWxwaGEg", metadata.StartId) + require.Equal(t, "2022-01-01T00:00:00Z", metadata.StartTimestamp) + require.Equal(t, "b21lZ2Eg", metadata.EndId) + require.Equal(t, "2022-12-12T00:00:00Z", metadata.EndTimestamp) + + require.Nil(t, mock.ExpectationsWereMet()) +} + +func TestPgSQLArchive_FetchArchiveMessages(t *testing.T) { + starTm := time.Date(2022, time.July, 6, 14, 7, 43, 167051000, time.UTC) + endTm := time.Date(2023, time.July, 7, 15, 7, 43, 167051000, time.UTC) + + tcs := map[string]struct { + filters *archivemodel.Filters + withArgs []driver.Value + expectQuery string + }{ + "by bare jid": { + filters: &archivemodel.Filters{With: "noelia@jackal.im"}, + withArgs: []driver.Value{"ortuman", "noelia@jackal.im", "noelia@jackal.im"}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND \(to_bare = \$2 OR from_bare = \$3\)\) ORDER BY created_at`, + }, + "by full jid": { + filters: &archivemodel.Filters{With: "noelia@jackal.im/yard"}, + withArgs: []driver.Value{"ortuman", "noelia@jackal.im/yard", "noelia@jackal.im/yard"}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND \("to" = \$2 OR "from" = \$3\)\) ORDER BY created_at`, + }, + "by ids": { + filters: &archivemodel.Filters{Ids: []string{"id1234", "id5678"}}, + withArgs: []driver.Value{"ortuman", "id1234", "id5678"}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND id IN \(\$2,\$3\)\) ORDER BY created_at`, + }, + "by before id": { + filters: &archivemodel.Filters{BeforeId: "id1234"}, + withArgs: []driver.Value{"ortuman", "id1234", "ortuman"}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND \(serial < \(SELECT serial FROM archives WHERE "id" = \$2 AND archive_id = \$3\)\)\) ORDER BY created_at`, + }, + "by after id": { + filters: &archivemodel.Filters{AfterId: "id1234"}, + withArgs: []driver.Value{"ortuman", "id1234", "ortuman"}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND \(serial > \(SELECT serial FROM archives WHERE "id" = \$2 AND archive_id = \$3\)\)\) ORDER BY created_at`, + }, + "by before and after id": { + filters: &archivemodel.Filters{BeforeId: "id1234", AfterId: "id5678"}, + withArgs: []driver.Value{"ortuman", "id1234", "ortuman", "id5678", "ortuman"}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND \(serial < \(SELECT serial FROM archives WHERE "id" = \$2 AND archive_id = \$3\)\) AND \(serial > \(SELECT serial FROM archives WHERE "id" = \$4 AND archive_id = \$5\)\)\) ORDER BY created_at`, + }, + "by start timestamp": { + filters: &archivemodel.Filters{Start: timestamppb.New(starTm)}, + withArgs: []driver.Value{"ortuman", starTm.Unix()}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND EXTRACT\(epoch FROM created_at\) > \$2\) ORDER BY created_at`, + }, + "by end timestamp": { + filters: &archivemodel.Filters{End: timestamppb.New(endTm)}, + withArgs: []driver.Value{"ortuman", endTm.Unix()}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND EXTRACT\(epoch FROM created_at\) < \$2\) ORDER BY created_at`, + }, + "by start and end timestamp": { + filters: &archivemodel.Filters{Start: timestamppb.New(starTm), End: timestamppb.New(endTm)}, + withArgs: []driver.Value{"ortuman", starTm.Unix(), endTm.Unix()}, + expectQuery: `SELECT id, "from", "to", message, created_at FROM archives WHERE \(archive_id = \$1 AND EXTRACT\(epoch FROM created_at\) > \$2 AND EXTRACT\(epoch FROM created_at\) < \$3\) ORDER BY created_at`, + }, + } + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + b := stravaganza.NewMessageBuilder() + b.WithAttribute("from", "noelia@jackal.im/yard") + b.WithAttribute("to", "ortuman@jackal.im/balcony") + b.WithChild( + stravaganza.NewBuilder("body"). + WithText("I'll give thee a wind."). + Build(), + ) + msg, _ := b.BuildMessage() + + msgBytes, _ := msg.MarshalBinary() + tmNow := time.Date(2022, time.July, 6, 14, 7, 43, 167051000, time.UTC) + + rows := sqlmock.NewRows([]string{"id", "from", "to", "message", "created_at"}). + AddRow("id1234", "ortuman@jackal.im", "noelia@jackal.im", msgBytes, tmNow) + + s, mock := newArchiveMock() + mock.ExpectQuery(tc.expectQuery). + WithArgs(tc.withArgs...). + WillReturnRows(rows) + + // when + messages, err := s.FetchArchiveMessages(context.Background(), tc.filters, "ortuman") + + require.NoError(t, err) + require.Nil(t, mock.ExpectationsWereMet()) + + // then + require.Len(t, messages, 1) + require.Equal(t, "id1234", messages[0].Id) + require.Equal(t, tmNow, messages[0].Stamp.AsTime()) + }) + } +} + +func TestPgSQLArchive_DeleteArchiveOldestMessages(t *testing.T) { + // given + s, mock := newArchiveMock() + mock.ExpectExec(`DELETE FROM archives WHERE \(archive_id = \$1 AND "id" NOT IN \(SELECT "id" FROM archives WHERE archive_id = \$2 ORDER BY created_at DESC LIMIT \$3 OFFSET 0\)\)`). + WithArgs("ortuman", "ortuman", 1234). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // when + err := s.DeleteArchiveOldestMessages(context.Background(), "ortuman", 1234) + + // then + require.Nil(t, err) + require.Nil(t, mock.ExpectationsWereMet()) +} + +func TestPgSQLArchive_DeleteArchive(t *testing.T) { + // given + s, mock := newArchiveMock() + mock.ExpectExec(`DELETE FROM archives WHERE archive_id = \$1`). + WithArgs("ortuman"). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // when + err := s.DeleteArchive(context.Background(), "ortuman") + + // then + require.Nil(t, err) + require.Nil(t, mock.ExpectationsWereMet()) +} + +func newArchiveMock() (*pgSQLArchiveRep, sqlmock.Sqlmock) { + s, sqlMock := newPgSQLMock() + return &pgSQLArchiveRep{conn: s}, sqlMock +} diff --git a/pkg/storage/pgsql/repository.go b/pkg/storage/pgsql/repository.go index 3a0acb9f9..32a665e83 100644 --- a/pkg/storage/pgsql/repository.go +++ b/pkg/storage/pgsql/repository.go @@ -57,6 +57,7 @@ type Repository struct { repository.Private repository.Roster repository.VCard + repository.Archive repository.Locker host string @@ -120,6 +121,7 @@ func (r *Repository) Start(ctx context.Context) error { r.Private = &pgSQLPrivateRep{conn: db, logger: r.logger} r.Roster = &pgSQLRosterRep{conn: db, logger: r.logger} r.VCard = &pgSQLVCardRep{conn: db, logger: r.logger} + r.Archive = &pgSQLArchiveRep{conn: db, logger: r.logger} r.Locker = &pgSQLLocker{conn: db} return nil } diff --git a/pkg/storage/pgsql/tx.go b/pkg/storage/pgsql/tx.go index 99d6be102..18ccde9a5 100644 --- a/pkg/storage/pgsql/tx.go +++ b/pkg/storage/pgsql/tx.go @@ -29,6 +29,7 @@ type repTx struct { repository.Private repository.Roster repository.VCard + repository.Archive repository.Locker } @@ -42,6 +43,7 @@ func newRepTx(tx *sql.Tx) *repTx { Private: &pgSQLPrivateRep{conn: tx}, Roster: &pgSQLRosterRep{conn: tx}, VCard: &pgSQLVCardRep{conn: tx}, + Archive: &pgSQLArchiveRep{conn: tx}, Locker: &pgSQLLocker{conn: tx}, } } diff --git a/pkg/storage/repository/archive.go b/pkg/storage/repository/archive.go new file mode 100644 index 000000000..9aec3f73b --- /dev/null +++ b/pkg/storage/repository/archive.go @@ -0,0 +1,39 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package repository + +import ( + "context" + + archivemodel "github.com/ortuman/jackal/pkg/model/archive" +) + +// Archive defines storage operations for message archive +type Archive interface { + // InsertArchiveMessage inserts a new message element into an archive queue. + InsertArchiveMessage(ctx context.Context, message *archivemodel.Message) error + + // FetchArchiveMetadata returns the metadata value associated to an archive. + FetchArchiveMetadata(ctx context.Context, archiveID string) (*archivemodel.Metadata, error) + + // FetchArchiveMessages fetches archive asscociated messages applying the passed f filters. + FetchArchiveMessages(ctx context.Context, f *archivemodel.Filters, archiveID string) ([]*archivemodel.Message, error) + + // DeleteArchiveOldestMessages trims archive oldest messages up to a maxElements total count. + DeleteArchiveOldestMessages(ctx context.Context, archiveID string, maxElements int) error + + // DeleteArchive clears an archive queue. + DeleteArchive(ctx context.Context, archiveID string) error +} diff --git a/pkg/storage/repository/repository.go b/pkg/storage/repository/repository.go index 51b9c913b..e5a6fe100 100644 --- a/pkg/storage/repository/repository.go +++ b/pkg/storage/repository/repository.go @@ -38,6 +38,7 @@ type Transaction interface { } type baseRepository interface { + Archive User Last Capabilities diff --git a/pkg/util/xmpp/xmpp.go b/pkg/util/xmpp/xmpp.go index 8aca645d3..bdaac11df 100644 --- a/pkg/util/xmpp/xmpp.go +++ b/pkg/util/xmpp/xmpp.go @@ -22,6 +22,8 @@ import ( "github.com/jackal-xmpp/stravaganza/jid" ) +const delayTimeFormat = "2006-01-02T15:04:05Z" + // MakeResultIQ creates a new result stanza derived from iq. func MakeResultIQ(iq *stravaganza.IQ, queryChild stravaganza.Element) *stravaganza.IQ { b := iq.ResultBuilder() @@ -57,10 +59,53 @@ func MakeDelayMessage(stanza stravaganza.Stanza, stamp time.Time, from, text str stravaganza.NewBuilder("delay"). WithAttribute(stravaganza.Namespace, "urn:xmpp:delay"). WithAttribute(stravaganza.From, from). - WithAttribute("stamp", stamp.UTC().Format("2006-01-02T15:04:05Z")). + WithAttribute("stamp", stamp.UTC().Format(delayTimeFormat)). WithText(text). Build(), ) dMsg, _ := sb.BuildMessage() return dMsg } + +// MakeStanzaIDMessage creates and returns a new message containing a stanza-id element. +func MakeStanzaIDMessage(originalMsg *stravaganza.Message, stanzaID, by string) *stravaganza.Message { + msg, _ := stravaganza.NewBuilderFromElement(originalMsg). + WithChild( + stravaganza.NewBuilder("stanza-id"). + WithAttribute(stravaganza.Namespace, "urn:xmpp:sid:0"). + WithAttribute("by", by). + WithAttribute("id", stanzaID). + Build(), + ). + BuildMessage() + return msg +} + +// MessageStanzaID returns the stanza-id value contained in msg parameter. +func MessageStanzaID(msg *stravaganza.Message) string { + sidElem := msg.ChildNamespace("stanza-id", "urn:xmpp:sid:0") + if sidElem == nil { + return "" + } + return sidElem.Attribute("id") +} + +// MakeForwardedStanza creates a new forwarded element derived from the passed stanza. +func MakeForwardedStanza(stanza stravaganza.Stanza, stamp *time.Time) stravaganza.Element { + b := stravaganza.NewBuilder("forwarded"). + WithAttribute(stravaganza.Namespace, "urn:xmpp:forward:0"). + WithChild( + stravaganza.NewBuilderFromElement(stanza). + WithAttribute(stravaganza.Namespace, "jabber:client"). + Build(), + ) + if stamp != nil { + b.WithChild( + stravaganza.NewBuilder("delay"). + WithAttribute(stravaganza.Namespace, "urn:xmpp:delay"). + WithAttribute("stamp", stamp.UTC().Format(delayTimeFormat)). + Build(), + ) + } + return b.Build() +} diff --git a/pkg/util/xmpp/xmpp_test.go b/pkg/util/xmpp/xmpp_test.go index c11fd36b1..9d7aef9cc 100644 --- a/pkg/util/xmpp/xmpp_test.go +++ b/pkg/util/xmpp/xmpp_test.go @@ -126,3 +126,55 @@ func TestMakeDelayStanza(t *testing.T) { require.Equal(t, "2021-02-15T15:00:00Z", dChild.Attribute("stamp")) require.Equal(t, "Delayed IQ", dChild.Text()) } + +func TestMakeStanzaIDElement(t *testing.T) { + // given + b := stravaganza.NewMessageBuilder() + b.WithAttribute("from", "noelia@jackal.im/yard") + b.WithAttribute("to", "ortuman@jackal.im/balcony") + b.WithChild( + stravaganza.NewBuilder("body"). + WithText("I'll give thee a wind."). + Build(), + ) + msg, _ := b.BuildMessage() + + // when + msg = MakeStanzaIDMessage(msg, "1234", "ortuman@jackal.im") + + // then + elem := msg.ChildNamespace("stanza-id", "urn:xmpp:sid:0") + require.NotNil(t, elem) + + require.Equal(t, "1234", MessageStanzaID(msg)) + require.Equal(t, "ortuman@jackal.im", elem.Attribute("by")) +} + +func TestMakeForwardedElement(t *testing.T) { + // given + b := stravaganza.NewMessageBuilder() + b.WithAttribute("from", "noelia@jackal.im/yard") + b.WithAttribute("to", "ortuman@jackal.im/balcony") + b.WithChild( + stravaganza.NewBuilder("body"). + WithText("I'll give thee a wind."). + Build(), + ) + msg, _ := b.BuildMessage() + + stamp, _ := time.Parse(time.RFC3339, "2021-02-15T15:00:00Z") + forwarded := MakeForwardedStanza(msg, &stamp) + + // when + require.Equal(t, "urn:xmpp:forward:0", forwarded.Attribute(stravaganza.Namespace)) + + dChild := forwarded.Child("delay") + require.NotNil(t, dChild) + require.Equal(t, "2021-02-15T15:00:00Z", dChild.Attribute("stamp")) + + msgEl := forwarded.Child("message") + require.NotNil(t, msgEl) + bodyEl := msgEl.Child("body") + require.NotNil(t, bodyEl) + require.Equal(t, "I'll give thee a wind.", bodyEl.Text()) +} diff --git a/proto/model/v1/archive.proto b/proto/model/v1/archive.proto new file mode 100644 index 000000000..5e3a1bde8 --- /dev/null +++ b/proto/model/v1/archive.proto @@ -0,0 +1,85 @@ +// Copyright 2022 The jackal Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax="proto3"; + +import "google/protobuf/timestamp.proto"; + +import "github.com/jackal-xmpp/stravaganza/stravaganza.proto"; + +package model.archive.v1; + +option go_package = "pkg/model/archive/;archivemodel"; + +// Message represents an archive message entity. +message Message { + // archived_id is the message archive identifier. + string archive_id = 1; + + // id is the message archive unique identifier. + string id = 2; + + // from_jid is the message from jid value. + string from_jid = 3; + + // to_jid is the message from jid value. + string to_jid = 4; + + // message is the archived message. + stravaganza.PBElement message = 5; + + // stamp is the timestamp in which the message was archived. + google.protobuf.Timestamp stamp = 9; +} + +// Messages represents a set of archive messages. +message Messages { + repeated Message archive_messages = 1; +} + +// Metadata represents an archive metadata information. +message Metadata { + // start_timestamp is the identifier of the first archive message. + string start_id = 1; + + // start_timestamp is the timestamp value of the first archive message. + string start_timestamp = 2; + + // end_id is the identifier of the last archive message. + string end_id = 3; + + // end_timestamp is the timestamp value of the last archive message. + string end_timestamp = 4; +} + +// Filters define a set of filters to be applied when fetching archive messages. +message Filters { + // start is used to filter out messages before a certain date/time. + google.protobuf.Timestamp start = 1; + + // end is used to filter out messages after a certain date/time. + google.protobuf.Timestamp end = 2; + + // with contains a JID against which to match messages. + string with = 3; + + // before_id is the id of the newest message user wants to fetch. + string before_id = 4; + + // after_id is the id of the oldest message user wants to fetch. + string after_id = 5; + + // ids contains one or more ids the user wants to fetch. + repeated string ids = 6; +} diff --git a/scripts/genproto.sh b/scripts/genproto.sh index bf7be5bfd..dca80d4d0 100755 --- a/scripts/genproto.sh +++ b/scripts/genproto.sh @@ -15,6 +15,7 @@ FILES=( "admin/v1/users.proto" "c2s/v1/resourceinfo.proto" "cluster/v1/cluster.proto" + "model/v1/archive.proto" "model/v1/user.proto" "model/v1/last.proto" "model/v1/blocklist.proto" diff --git a/sql/postgres.up.psql b/sql/postgres.up.psql index 3be5cf8da..39d0ad1c1 100644 --- a/sql/postgres.up.psql +++ b/sql/postgres.up.psql @@ -170,3 +170,25 @@ CREATE TABLE IF NOT EXISTS vcards ( ); SELECT enable_updated_at('vcards'); + +-- archives + +CREATE TABLE IF NOT EXISTS archives ( + serial SERIAL PRIMARY KEY, + archive_id VARCHAR(1023), + id VARCHAR(255) NOT NULL, + "from" TEXT NOT NULL, + from_bare TEXT NOT NULL, + "to" TEXT NOT NULL, + to_bare TEXT NOT NULL, + message BYTEA NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS i_archives_archive_id ON archives(archive_id); +CREATE INDEX IF NOT EXISTS i_archives_id ON archives(id); +CREATE INDEX IF NOT EXISTS i_archives_to ON archives("to"); +CREATE INDEX IF NOT EXISTS i_archives_to_bare ON archives(to_bare); +CREATE INDEX IF NOT EXISTS i_archives_from ON archives("from"); +CREATE INDEX IF NOT EXISTS i_archives_from_bare ON archives(from_bare); +CREATE INDEX IF NOT EXISTS i_archives_created_at ON archives(created_at);