Skip to content

Commit

Permalink
Add more contexts everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jan 7, 2024
1 parent 0a302c7 commit 25bc36b
Show file tree
Hide file tree
Showing 37 changed files with 879 additions and 833 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

* **Breaking change *(bridge)*** Added raw event to portal membership handling
functions.
* **Breaking change *(client)*** Added context parameters to all functions
(thanks to [@recht] in [#144]).
* **Breaking change *(everything)*** Added context parameters to all functions
(started by [@recht] in [#144]).
* *(crypto)* Added experimental pure Go Olm implementation to replace libolm
(thanks to [@DerLukas15] in [#106]).
* You can use the `goolm` build tag to the new implementation.
Expand Down
10 changes: 5 additions & 5 deletions appservice/appservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
type StateStore interface {
mautrix.StateStore

IsRegistered(userID id.UserID) bool
MarkRegistered(userID id.UserID)
IsRegistered(ctx context.Context, userID id.UserID) (bool, error)
MarkRegistered(ctx context.Context, userID id.UserID) error

GetPowerLevel(roomID id.RoomID, userID id.UserID) int
GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int
HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool
GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error)
GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error)
HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error)
}

// AppService is the main config for all appservices.
Expand Down
2 changes: 1 addition & 1 deletion appservice/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def
}

if evt.Type.IsState() {
mautrix.UpdateStateStore(as.StateStore, evt)
mautrix.UpdateStateStore(ctx, as.StateStore, evt)
}
var ch chan *event.Event
if evt.Type.Class == event.ToDeviceEventType {
Expand Down
53 changes: 39 additions & 14 deletions appservice/intent.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"strings"
"sync"

"github.com/rs/zerolog"

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
Expand Down Expand Up @@ -57,17 +59,26 @@ func (intent *IntentAPI) Register(ctx context.Context) error {
}

func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error {
if intent.IsCustomPuppet {
return nil
}
intent.registerLock.Lock()
defer intent.registerLock.Unlock()
if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) {
isRegistered, err := intent.as.StateStore.IsRegistered(ctx, intent.UserID)
if err != nil {
return fmt.Errorf("failed to check if user is registered: %w", err)
} else if isRegistered {
return nil
}

err := intent.Register(ctx)
err = intent.Register(ctx)
if err != nil && !errors.Is(err, mautrix.MUserInUse) {
return fmt.Errorf("failed to ensure registered: %w", err)
}
intent.as.StateStore.MarkRegistered(intent.UserID)
err = intent.as.StateStore.MarkRegistered(ctx, intent.UserID)
if err != nil {
return fmt.Errorf("failed to mark user as registered in state store: %w", err)
}
return nil
}

Expand All @@ -83,7 +94,7 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
} else if len(extra) == 1 {
params = extra[0]
}
if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache {
if intent.as.StateStore.IsInRoom(ctx, roomID, intent.UserID) && !params.IgnoreCache {
return nil
}

Expand Down Expand Up @@ -111,7 +122,10 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
return fmt.Errorf("failed to ensure joined after invite: %w", err)
}
}
intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin)
err = intent.as.StateStore.SetMembership(ctx, resp.RoomID, intent.UserID, event.MembershipJoin)
if err != nil {
return fmt.Errorf("failed to set membership in state store: %w", err)
}
return nil
}

Expand Down Expand Up @@ -205,13 +219,14 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
Membership: membership,
Reason: reason,
}
memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target)
if !ok {
memberContent, err := intent.as.StateStore.TryGetMember(ctx, roomID, target)
if err != nil {
return nil, fmt.Errorf("failed to get old member content from state store: %w", err)
} else if memberContent == nil {
if intent.as.GetProfile != nil {
memberContent = intent.as.GetProfile(target, roomID)
ok = memberContent != nil
}
if !ok {
if memberContent == nil {
profile, err := intent.GetProfile(ctx, target)
if err != nil {
intent.Log.Debug().Err(err).
Expand All @@ -224,7 +239,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
}
}
}
if ok && memberContent != nil {
if memberContent != nil {
content.Displayname = memberContent.Displayname
content.AvatarURL = memberContent.AvatarURL
}
Expand Down Expand Up @@ -297,15 +312,25 @@ func (intent *IntentAPI) UnbanUser(ctx context.Context, roomID id.RoomID, req *m
}

func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := intent.as.StateStore.TryGetMember(roomID, userID)
if !ok {
member, err := intent.as.StateStore.TryGetMember(ctx, roomID, userID)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Str("room_id", roomID.String()).
Str("user_id", userID.String()).
Msg("Failed to get member from state store")
}
if member == nil {
_ = intent.StateEvent(ctx, roomID, event.StateMember, string(userID), &member)
}
return member
}

func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
pl = intent.as.StateStore.GetPowerLevels(roomID)
pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to get cached power levels: %w", err)
return
}
if pl == nil {
pl = &event.PowerLevelsEventContent{}
err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl)
Expand Down Expand Up @@ -417,7 +442,7 @@ func (intent *IntentAPI) Whoami(ctx context.Context) (*mautrix.RespWhoami, error
}

func (intent *IntentAPI) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error {
if !intent.as.StateStore.IsInvited(roomID, userID) {
if !intent.as.StateStore.IsInvited(ctx, roomID, userID) {
_, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{
UserID: userID,
})
Expand Down
3 changes: 1 addition & 2 deletions appservice/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ import (
"os"
"regexp"

"gopkg.in/yaml.v3"

"go.mau.fi/util/random"
"gopkg.in/yaml.v3"
)

// Registration contains the data in a Matrix appservice registration.
Expand Down
18 changes: 9 additions & 9 deletions bridge/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,15 @@ type Bridge struct {

type Crypto interface {
HandleMemberEvent(*event.Event)
Decrypt(*event.Event) (*event.Event, error)
Encrypt(id.RoomID, event.Type, *event.Content) error
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
Decrypt(context.Context, *event.Event) (*event.Event, error)
Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error
WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
ResetSession(id.RoomID)
Init() error
ResetSession(context.Context, id.RoomID)
Init(ctx context.Context) error
Start()
Stop()
Reset(startAfterReset bool)
Reset(ctx context.Context, startAfterReset bool)
Client() *mautrix.Client
ShareKeys(context.Context) error
}
Expand Down Expand Up @@ -650,10 +650,10 @@ func (br *Bridge) WaitWebsocketConnected() {

func (br *Bridge) start() {
br.ZLog.Debug().Msg("Running database upgrades")
err := br.DB.Upgrade()
err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO()))
if err != nil {
br.LogDBUpgradeErrorAndExit("main", err)
} else if err = br.StateStore.Upgrade(); err != nil {
} else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil {
br.LogDBUpgradeErrorAndExit("matrix_state", err)
}

Expand All @@ -679,7 +679,7 @@ func (br *Bridge) start() {
go br.fetchMediaConfig(ctx)

if br.Crypto != nil {
err = br.Crypto.Init()
err = br.Crypto.Init(ctx)
if err != nil {
br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption")
os.Exit(19)
Expand Down
2 changes: 1 addition & 1 deletion bridge/commands/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var CommandDiscardMegolmSession = &FullHandler{
if ce.Bridge.Crypto == nil {
ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled")
} else {
ce.Bridge.Crypto.ResetSession(ce.RoomID)
ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID)
ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.")
}
},
Expand Down
Loading

0 comments on commit 25bc36b

Please sign in to comment.