From 25bc36bc7ae79afe8b5e5f053fbd8bc8fa68acbc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Jan 2024 22:44:06 +0200 Subject: [PATCH] Add more contexts everywhere --- CHANGELOG.md | 4 +- appservice/appservice.go | 10 +- appservice/http.go | 2 +- appservice/intent.go | 53 +- appservice/registration.go | 3 +- bridge/bridge.go | 18 +- bridge/commands/admin.go | 2 +- bridge/crypto.go | 91 ++-- bridge/cryptostore.go | 6 +- bridge/matrix.go | 10 +- client.go | 113 +++-- crypto/account.go | 2 +- crypto/cross_sign_pubkey.go | 4 +- crypto/cross_sign_signing.go | 12 +- crypto/cross_sign_store.go | 10 +- crypto/cross_sign_test.go | 32 +- crypto/cross_sign_validation.go | 14 +- crypto/cryptohelper/cryptohelper.go | 44 +- crypto/decryptmegolm.go | 12 +- crypto/decryptolm.go | 10 +- crypto/devicelist.go | 33 +- crypto/encryptmegolm.go | 22 +- crypto/encryptolm.go | 12 +- crypto/keyimport.go | 16 +- crypto/keysharing.go | 15 +- crypto/machine.go | 54 +- crypto/machine_test.go | 22 +- crypto/sql_store.go | 465 ++++++++---------- crypto/sql_store_upgrade/upgrade.go | 3 +- crypto/store.go | 138 +++--- crypto/store_test.go | 34 +- crypto/verification.go | 4 +- go.mod | 2 +- go.sum | 4 +- sqlstatestore/statestore.go | 309 ++++++------ .../v05-mark-encryption-state-resync.go | 7 +- statestore.go | 134 ++--- 37 files changed, 886 insertions(+), 840 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e6b61ed..7abbe58 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,8 @@ * **Breaking change *(bridge)*** Added raw event to portal membership handling functions. -* **Breaking change *(client)*** Added context parameters to all functions - (thanks to [@recht] in [#144]). +* **Breaking change *(everything)*** Added context parameters to all functions + (started by [@recht] in [#144]). * *(crypto)* Added experimental pure Go Olm implementation to replace libolm (thanks to [@DerLukas15] in [#106]). * You can use the `goolm` build tag to the new implementation. diff --git a/appservice/appservice.go b/appservice/appservice.go index 98d1463..dc5e82b 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -93,12 +93,12 @@ type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) type StateStore interface { mautrix.StateStore - IsRegistered(userID id.UserID) bool - MarkRegistered(userID id.UserID) + IsRegistered(ctx context.Context, userID id.UserID) (bool, error) + MarkRegistered(ctx context.Context, userID id.UserID) error - GetPowerLevel(roomID id.RoomID, userID id.UserID) int - GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int - HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool + GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) + GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) + HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) } // AppService is the main config for all appservices. diff --git a/appservice/http.go b/appservice/http.go index 2219687..1d4c7f2 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -236,7 +236,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def } if evt.Type.IsState() { - mautrix.UpdateStateStore(as.StateStore, evt) + mautrix.UpdateStateStore(ctx, as.StateStore, evt) } var ch chan *event.Event if evt.Type.Class == event.ToDeviceEventType { diff --git a/appservice/intent.go b/appservice/intent.go index f5f066d..bdf0f06 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -13,6 +13,8 @@ import ( "strings" "sync" + "github.com/rs/zerolog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -57,17 +59,26 @@ func (intent *IntentAPI) Register(ctx context.Context) error { } func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { + if intent.IsCustomPuppet { + return nil + } intent.registerLock.Lock() defer intent.registerLock.Unlock() - if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) { + isRegistered, err := intent.as.StateStore.IsRegistered(ctx, intent.UserID) + if err != nil { + return fmt.Errorf("failed to check if user is registered: %w", err) + } else if isRegistered { return nil } - err := intent.Register(ctx) + err = intent.Register(ctx) if err != nil && !errors.Is(err, mautrix.MUserInUse) { return fmt.Errorf("failed to ensure registered: %w", err) } - intent.as.StateStore.MarkRegistered(intent.UserID) + err = intent.as.StateStore.MarkRegistered(ctx, intent.UserID) + if err != nil { + return fmt.Errorf("failed to mark user as registered in state store: %w", err) + } return nil } @@ -83,7 +94,7 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext } else if len(extra) == 1 { params = extra[0] } - if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache { + if intent.as.StateStore.IsInRoom(ctx, roomID, intent.UserID) && !params.IgnoreCache { return nil } @@ -111,7 +122,10 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return fmt.Errorf("failed to ensure joined after invite: %w", err) } } - intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin) + err = intent.as.StateStore.SetMembership(ctx, resp.RoomID, intent.UserID, event.MembershipJoin) + if err != nil { + return fmt.Errorf("failed to set membership in state store: %w", err) + } return nil } @@ -205,13 +219,14 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i Membership: membership, Reason: reason, } - memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target) - if !ok { + memberContent, err := intent.as.StateStore.TryGetMember(ctx, roomID, target) + if err != nil { + return nil, fmt.Errorf("failed to get old member content from state store: %w", err) + } else if memberContent == nil { if intent.as.GetProfile != nil { memberContent = intent.as.GetProfile(target, roomID) - ok = memberContent != nil } - if !ok { + if memberContent == nil { profile, err := intent.GetProfile(ctx, target) if err != nil { intent.Log.Debug().Err(err). @@ -224,7 +239,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i } } } - if ok && memberContent != nil { + if memberContent != nil { content.Displayname = memberContent.Displayname content.AvatarURL = memberContent.AvatarURL } @@ -297,15 +312,25 @@ func (intent *IntentAPI) UnbanUser(ctx context.Context, roomID id.RoomID, req *m } func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id.UserID) *event.MemberEventContent { - member, ok := intent.as.StateStore.TryGetMember(roomID, userID) - if !ok { + member, err := intent.as.StateStore.TryGetMember(ctx, roomID, userID) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Str("room_id", roomID.String()). + Str("user_id", userID.String()). + Msg("Failed to get member from state store") + } + if member == nil { _ = intent.StateEvent(ctx, roomID, event.StateMember, string(userID), &member) } return member } func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { - pl = intent.as.StateStore.GetPowerLevels(roomID) + pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to get cached power levels: %w", err) + return + } if pl == nil { pl = &event.PowerLevelsEventContent{} err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) @@ -417,7 +442,7 @@ func (intent *IntentAPI) Whoami(ctx context.Context) (*mautrix.RespWhoami, error } func (intent *IntentAPI) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error { - if !intent.as.StateStore.IsInvited(roomID, userID) { + if !intent.as.StateStore.IsInvited(ctx, roomID, userID) { _, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ UserID: userID, }) diff --git a/appservice/registration.go b/appservice/registration.go index 464ea1d..b11bd84 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -10,9 +10,8 @@ import ( "os" "regexp" - "gopkg.in/yaml.v3" - "go.mau.fi/util/random" + "gopkg.in/yaml.v3" ) // Registration contains the data in a Matrix appservice registration. diff --git a/bridge/bridge.go b/bridge/bridge.go index 960dce9..6ad1972 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -215,15 +215,15 @@ type Bridge struct { type Crypto interface { HandleMemberEvent(*event.Event) - Decrypt(*event.Event) (*event.Event, error) - Encrypt(id.RoomID, event.Type, *event.Content) error - WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + Decrypt(context.Context, *event.Event) (*event.Event, error) + Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error + WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - ResetSession(id.RoomID) - Init() error + ResetSession(context.Context, id.RoomID) + Init(ctx context.Context) error Start() Stop() - Reset(startAfterReset bool) + Reset(ctx context.Context, startAfterReset bool) Client() *mautrix.Client ShareKeys(context.Context) error } @@ -650,10 +650,10 @@ func (br *Bridge) WaitWebsocketConnected() { func (br *Bridge) start() { br.ZLog.Debug().Msg("Running database upgrades") - err := br.DB.Upgrade() + err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO())) if err != nil { br.LogDBUpgradeErrorAndExit("main", err) - } else if err = br.StateStore.Upgrade(); err != nil { + } else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil { br.LogDBUpgradeErrorAndExit("matrix_state", err) } @@ -679,7 +679,7 @@ func (br *Bridge) start() { go br.fetchMediaConfig(ctx) if br.Crypto != nil { - err = br.Crypto.Init() + err = br.Crypto.Init(ctx) if err != nil { br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption") os.Exit(19) diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go index cf38d6c..ff3340e 100644 --- a/bridge/commands/admin.go +++ b/bridge/commands/admin.go @@ -17,7 +17,7 @@ var CommandDiscardMegolmSession = &FullHandler{ if ce.Bridge.Crypto == nil { ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") } else { - ce.Bridge.Crypto.ResetSession(ce.RoomID) + ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID) ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.") } }, diff --git a/bridge/crypto.go b/bridge/crypto.go index a1a76eb..872bf8a 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -61,7 +61,7 @@ func NewCryptoHelper(bridge *Bridge) Crypto { } } -func (helper *CryptoHelper) Init() error { +func (helper *CryptoHelper) Init(ctx context.Context) error { if len(helper.bridge.CryptoPickleKey) == 0 { panic("CryptoPickleKey not set") } @@ -75,13 +75,13 @@ func (helper *CryptoHelper) Init() error { helper.bridge.CryptoPickleKey, ) - err := helper.store.DB.Upgrade() + err := helper.store.DB.Upgrade(ctx) if err != nil { helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) } var isExistingDevice bool - helper.client, isExistingDevice, err = helper.loginBot(context.TODO()) + helper.client, isExistingDevice, err = helper.loginBot(ctx) if err != nil { return err } @@ -111,7 +111,7 @@ func (helper *CryptoHelper) Init() error { } if encryptionConfig.DeleteKeys.DeleteOutdatedInbound { - deleted, err := helper.store.RedactOutdatedGroupSessions() + deleted, err := helper.store.RedactOutdatedGroupSessions(ctx) if err != nil { return err } @@ -123,12 +123,12 @@ func (helper *CryptoHelper) Init() error { helper.client.Syncer = &cryptoSyncer{helper.mach} helper.client.Store = helper.store - err = helper.mach.Load() + err = helper.mach.Load(ctx) if err != nil { return err } if isExistingDevice { - helper.verifyKeysAreOnServer(context.TODO()) + helper.verifyKeysAreOnServer(ctx) } go helper.resyncEncryptionInfo(context.TODO()) @@ -138,22 +138,16 @@ func (helper *CryptoHelper) Init() error { func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() - rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return } - var roomIDs []id.RoomID - for rows.Next() { - var roomID id.RoomID - err = rows.Scan(&roomID) - if err != nil { - log.Err(err).Msg("Failed to scan room ID") - continue - } - roomIDs = append(roomIDs, roomID) + roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + if err != nil { + log.Err(err).Msg("Failed to scan rooms for resync") + return } - _ = rows.Close() if len(roomIDs) > 0 { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { @@ -161,7 +155,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") - _, err = helper.bridge.DB.ExecContext(ctx, ` + _, err = helper.bridge.DB.Exec(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { @@ -182,7 +176,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { Int("max_messages", maxMessages). Interface("content", &evt). Msg("Resynced encryption event") - _, err = helper.bridge.DB.ExecContext(ctx, ` + _, err = helper.bridge.DB.Exec(ctx, ` UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL @@ -223,8 +217,10 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device } func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) { - deviceID := helper.store.FindDeviceID() - if len(deviceID) > 0 { + deviceID, err := helper.store.FindDeviceID(ctx) + if err != nil { + return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) + } else if len(deviceID) > 0 { helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") } // Create a new client instance with the default AS settings (including as_token), @@ -270,7 +266,7 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { return } helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") - helper.Reset(false) + helper.Reset(ctx, false) } func (helper *CryptoHelper) Start() { @@ -306,16 +302,16 @@ func (helper *CryptoHelper) Stop() { helper.syncDone.Wait() } -func (helper *CryptoHelper) clearDatabase() { - _, err := helper.store.DB.Exec("DELETE FROM crypto_account") +func (helper *CryptoHelper) clearDatabase(ctx context.Context) { + _, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account") if err != nil { helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table") } - _, err = helper.store.DB.Exec("DELETE FROM crypto_olm_session") + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session") if err != nil { helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table") } - _, err = helper.store.DB.Exec("DELETE FROM crypto_megolm_outbound_session") + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session") if err != nil { helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table") } @@ -325,22 +321,22 @@ func (helper *CryptoHelper) clearDatabase() { //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures") } -func (helper *CryptoHelper) Reset(startAfterReset bool) { +func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) { helper.lock.Lock() defer helper.lock.Unlock() helper.log.Info().Msg("Resetting end-to-bridge encryption device") helper.Stop() helper.log.Debug().Msg("Crypto syncer stopped, clearing database") - helper.clearDatabase() + helper.clearDatabase(ctx) helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") - _, err := helper.client.LogoutAll(context.TODO()) + _, err := helper.client.LogoutAll(ctx) if err != nil { helper.log.Warn().Err(err).Msg("Failed to log out all devices") } helper.client = nil helper.store = nil helper.mach = nil - err = helper.Init() + err = helper.Init(ctx) if err != nil { helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption") os.Exit(50) @@ -355,25 +351,24 @@ func (helper *CryptoHelper) Client() *mautrix.Client { return helper.client } -func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) { - return helper.mach.DecryptMegolmEvent(context.TODO(), evt) +func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return helper.mach.DecryptMegolmEvent(ctx, evt) } -func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { +func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { helper.lock.RLock() defer helper.lock.RUnlock() var encrypted *event.EncryptedEventContent - ctx := context.TODO() encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { return } helper.log.Debug().Err(err). Str("room_id", roomID.String()). Msg("Got error while encrypting event for room, sharing group session and trying again...") var users []id.UserID - users, err = helper.store.GetRoomJoinedOrInvitedMembers(roomID) + users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID) if err != nil { err = fmt.Errorf("failed to get room member list: %w", err) } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { @@ -389,10 +384,10 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten return } -func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { +func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { helper.lock.RLock() defer helper.lock.RUnlock() - return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) + return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) } func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { @@ -419,10 +414,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) ResetSession(roomID id.RoomID) { +func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) { helper.lock.RLock() defer helper.lock.RUnlock() - err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID) + err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) if err != nil { helper.log.Debug().Err(err). Str("room_id", roomID.String()). @@ -499,18 +494,18 @@ type cryptoStateStore struct { var _ crypto.StateStore = (*cryptoStateStore)(nil) -func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool { +func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) { portal := c.bridge.Child.GetIPortal(id) if portal != nil { - return portal.IsEncrypted() + return portal.IsEncrypted(), nil } - return c.bridge.StateStore.IsEncrypted(id) + return c.bridge.StateStore.IsEncrypted(ctx, id) } -func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID { - return c.bridge.StateStore.FindSharedRooms(id) +func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) { + return c.bridge.StateStore.FindSharedRooms(ctx, id) } -func (c *cryptoStateStore) GetEncryptionEvent(id id.RoomID) *event.EncryptionEventContent { - return c.bridge.StateStore.GetEncryptionEvent(id) +func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) { + return c.bridge.StateStore.GetEncryptionEvent(ctx, id) } diff --git a/bridge/cryptostore.go b/bridge/cryptostore.go index e199f5a..dde48a2 100644 --- a/bridge/cryptostore.go +++ b/bridge/cryptostore.go @@ -9,6 +9,8 @@ package bridge import ( + "context" + "github.com/lib/pq" "go.mau.fi/util/dbutil" @@ -36,9 +38,9 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id } } -func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) { +func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { var rows dbutil.Rows - rows, err = store.DB.Query(` + rows, err = store.DB.Query(ctx, ` SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND (membership='join' OR membership='invite') diff --git a/bridge/matrix.go b/bridge/matrix.go index 90453c1..00994dd 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -494,7 +494,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { log.Debug().Msg("Decrypting received event") decryptionStart := time.Now() - decrypted, err := mx.bridge.Crypto.Decrypt(evt) + decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) decryptionRetryCount := 0 if errors.Is(err, NoSessionFound) { decryptionRetryCount = 1 @@ -502,9 +502,9 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0) - if mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = mx.bridge.Crypto.Decrypt(evt) + decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt) } else { go mx.waitLongerForSession(ctx, evt, decryptionStart) return @@ -529,14 +529,14 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) - if !mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { + if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up trying to decrypt event") mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true) return } log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") - decrypted, err := mx.bridge.Crypto.Decrypt(evt) + decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) if err != nil { log.Error().Err(err).Msg("Failed to decrypt event") mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true) diff --git a/client.go b/client.go index 1c2b968..d1a6d8f 100644 --- a/client.go +++ b/client.go @@ -27,11 +27,11 @@ import ( ) type CryptoHelper interface { - Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error) - Decrypt(*event.Event) (*event.Event, error) - WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + Encrypt(context.Context, id.RoomID, event.Type, any) (*event.EncryptedEventContent, error) + Decrypt(context.Context, *event.Event) (*event.Event, error) + WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - Init() error + Init(context.Context) error } // Deprecated: switch to zerolog @@ -846,7 +846,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin } _, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) + err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) + if err != nil { + err = fmt.Errorf("failed to update state store: %w", err) + } } return } @@ -858,7 +861,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) { _, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) + err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) + if err != nil { + err = fmt.Errorf("failed to update state store: %w", err) + } } return } @@ -1000,13 +1006,20 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event queryParams["fi.mau.event_id"] = req.MeowEventID.String() } - if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted && cli.StateStore.IsEncrypted(roomID) { - contentJSON, err = cli.Crypto.Encrypt(roomID, eventType, contentJSON) + if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { + var isEncrypted bool + isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) if err != nil { - err = fmt.Errorf("failed to encrypt event: %w", err) + err = fmt.Errorf("failed to check if room is encrypted: %w", err) return } - eventType = event.EventEncrypted + if isEncrypted { + if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { + err = fmt.Errorf("failed to encrypt event: %w", err) + return + } + eventType = event.EventEncrypted + } } urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID} @@ -1021,7 +1034,7 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return } @@ -1034,7 +1047,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, }) _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return } @@ -1100,19 +1113,29 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re urlPath := cli.BuildClientURL("v3", "createRoom") _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) + storeErr := cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) + if storeErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(storeErr). + Stringer("creator_user_id", cli.UserID). + Msg("Failed to update creator membership in state store after creating room") + } for _, evt := range req.InitialState { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(ctx, cli.StateStore, evt) } inviteMembership := event.MembershipInvite if req.BeeperAutoJoinInvites { inviteMembership = event.MembershipJoin } for _, invitee := range req.Invite { - cli.StateStore.SetMembership(resp.RoomID, invitee, inviteMembership) + storeErr = cli.StateStore.SetMembership(ctx, resp.RoomID, invitee, inviteMembership) + if storeErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(storeErr). + Stringer("invitee_user_id", invitee). + Msg("Failed to update membership in state store after creating room") + } } for _, evt := range req.InitialState { - cli.updateStoreWithOutgoingEvent(resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) + cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) } } return @@ -1129,7 +1152,10 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq u := cli.BuildClientURL("v3", "rooms", roomID, "leave") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave) + err = cli.StateStore.SetMembership(ctx, roomID, cli.UserID, event.MembershipLeave) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1146,7 +1172,10 @@ func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInv u := cli.BuildClientURL("v3", "rooms", roomID, "invite") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipInvite) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1163,7 +1192,10 @@ func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickU u := cli.BuildClientURL("v3", "rooms", roomID, "kick") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1173,7 +1205,10 @@ func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUse u := cli.BuildClientURL("v3", "rooms", roomID, "ban") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipBan) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1183,7 +1218,10 @@ func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnba u := cli.BuildClientURL("v3", "rooms", roomID, "unban") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1216,7 +1254,7 @@ func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err return } -func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { +func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { if cli.StateStore == nil { return } @@ -1246,7 +1284,7 @@ func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType even } return } - UpdateStateStore(cli.StateStore, fakeEvt) + UpdateStateStore(ctx, cli.StateStore, fakeEvt) } // StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with @@ -1256,7 +1294,7 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest(ctx, "GET", u, nil, outContent) if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent) + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, outContent) } return } @@ -1310,10 +1348,13 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt Handler: parseRoomStateArray, }) if err == nil && cli.StateStore != nil { - cli.StateStore.ClearCachedMembers(roomID) + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID) + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching state") for _, evts := range stateMap { for _, evt := range evts { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(ctx, cli.StateStore, evt) } } } @@ -1630,13 +1671,22 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin) + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching joined members") for userID, member := range resp.Joined { - cli.StateStore.SetMember(roomID, userID, &event.MemberEventContent{ + updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ Membership: event.MembershipJoin, AvatarURL: id.ContentURIString(member.AvatarURL), Displayname: member.DisplayName, }) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Stringer("user_id", userID). + Msg("Failed to update membership in state store after fetching joined members") + } } } return @@ -1665,10 +1715,13 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb clearMemberships = append(clearMemberships, extra.Membership) } if extra.NotMembership == "" { - cli.StateStore.ClearCachedMembers(roomID, clearMemberships...) + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching joined members") } for _, evt := range resp.Chunk { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(ctx, cli.StateStore, evt) } } return diff --git a/crypto/account.go b/crypto/account.go index a667825..0eb18a2 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 9f4f358..77efab5 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -42,7 +42,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *Cross } func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) { - dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) + dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { return nil, fmt.Errorf("failed to get keys from database: %w", err) } diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 1a5a023..f6c37a9 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -34,8 +34,8 @@ var ( ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC") ) -func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) { - crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID) +func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) { + crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID) if err != nil { return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err) } @@ -85,7 +85,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe Str("signature", signature). Msg("Signed master key of user with our user-signing key") - if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } @@ -137,7 +137,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { return fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures) } - if err := mach.CryptoStore.PutSignature(userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } @@ -178,7 +178,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er Str("signature", signature). Msg("Signed own device key with self-signing key") - if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index f1008eb..88fcd0e 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,7 +19,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log := mach.machOrContextLog(ctx) for userID, userKeys := range crossSigningKeys { log := log.With().Str("user_id", userID.String()).Logger() - currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) + currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") @@ -32,7 +32,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK if newKeyUsage == curKeyUsage { if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil { + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { log.Error().Err(err).Msg("Error deleting old signatures made by user") } else { log.Debug(). @@ -50,7 +50,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key") - if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil { + if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { log.Error().Err(err).Msg("Error storing cross-signing key") } } @@ -85,7 +85,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } else { if verified { log.Debug().Err(err).Msg("Cross-signing key signature verified") - err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature) + err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature) if err != nil { log.Error().Err(err).Msg("Error storing cross-signing key signature") } diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index 847c87f..b53da10 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -32,7 +32,7 @@ func getOlmMachine(t *testing.T) *OlmMachine { t.Fatalf("Error opening db: %v", err) } sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(); err != nil { + if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { t.Fatalf("Error creating tables: %v", err) } @@ -41,9 +41,9 @@ func getOlmMachine(t *testing.T) *OlmMachine { ssk, _ := olm.NewPkSigning() usk, _ := olm.NewPkSigning() - sqlStore.PutCrossSigningKey(userID, id.XSUsageMaster, mk.PublicKey) - sqlStore.PutCrossSigningKey(userID, id.XSUsageSelfSigning, ssk.PublicKey) - sqlStore.PutCrossSigningKey(userID, id.XSUsageUserSigning, usk.PublicKey) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey) return &OlmMachine{ CryptoStore: sqlStore, @@ -70,9 +70,9 @@ func TestTrustOwnDevice(t *testing.T) { t.Error("Own device trusted while it shouldn't be") } - m.CryptoStore.PutSignature(ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") - m.CryptoStore.PutSignature(ownDevice.UserID, ownDevice.SigningKey, + m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { @@ -91,20 +91,20 @@ func TestTrustOtherUser(t *testing.T) { } theirMasterKey, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) - m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") // sign them with self-signing instead of user-signing key - m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { t.Error("Other user trusted before their master key has been signed with our user-signing key") } - m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { @@ -128,27 +128,27 @@ func TestTrustOtherDevice(t *testing.T) { } theirMasterKey, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) theirSSK, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey) + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey) - m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") - m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { t.Error("Other user not trusted while they should be") } - m.CryptoStore.PutSignature(otherUser, theirSSK.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey, otherUser, theirMasterKey.PublicKey, "sig3") if m.IsDeviceTrusted(theirDevice) { t.Error("Other device trusted before it has been signed with user's SSK") } - m.CryptoStore.PutSignature(otherUser, theirDevice.SigningKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey, "sig4") if !m.IsDeviceTrusted(theirDevice) { diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index 27afeb7..ff2452e 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -23,7 +23,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted { return device.Trust, nil } - theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID) + theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", device.UserID.String()). @@ -44,7 +44,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi Msg("Self-signing key of user not found") return id.TrustStateUnset, nil } - sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key) + sskSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, theirSSK.Key, device.UserID, theirMSK.Key) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", device.UserID.String()). @@ -57,7 +57,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi Msg("Self-signing key of user is not signed by their master key") return id.TrustStateUnset, nil } - deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key) + deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSSK.Key) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", device.UserID.String()). @@ -97,14 +97,14 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo return true, nil } // first we verify our user-signing key - ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey) + ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey) if err != nil { mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database") return false, err } else if !ourUserSigningKeyTrusted { return false, nil } - theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) + theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", userID.String()). @@ -118,7 +118,7 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo Msg("Master key of user not found") return false, nil } - sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey) + sigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", userID.String()). diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 293166f..eb7d7a7 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -105,7 +105,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH }, nil } -func (helper *CryptoHelper) Init() error { +func (helper *CryptoHelper) Init(ctx context.Context) error { if helper == nil { return fmt.Errorf("crypto helper is nil") } @@ -116,7 +116,7 @@ func (helper *CryptoHelper) Init() error { var stateStore crypto.StateStore if helper.managedStateStore != nil { - err := helper.managedStateStore.Upgrade() + err := helper.managedStateStore.Upgrade(ctx) if err != nil { return fmt.Errorf("failed to upgrade client state store: %w", err) } @@ -124,7 +124,6 @@ func (helper *CryptoHelper) Init() error { } else { stateStore = helper.client.StateStore.(crypto.StateStore) } - ctx := context.TODO() var cryptoStore crypto.Store if helper.unmanagedCryptoStore == nil { managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey) @@ -133,11 +132,14 @@ func (helper *CryptoHelper) Init() error { } else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory { helper.client.Store = managedCryptoStore } - err := managedCryptoStore.DB.Upgrade() + err := managedCryptoStore.DB.Upgrade(ctx) if err != nil { return fmt.Errorf("failed to upgrade crypto state store: %w", err) } - storedDeviceID := managedCryptoStore.FindDeviceID() + storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx) + if err != nil { + return fmt.Errorf("failed to find existing device ID: %w", err) + } if helper.LoginAs != nil { if storedDeviceID != "" { helper.LoginAs.DeviceID = storedDeviceID @@ -168,7 +170,7 @@ func (helper *CryptoHelper) Init() error { return fmt.Errorf("the client must be logged in") } helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore) - err := helper.mach.Load() + err := helper.mach.Load(ctx) if err != nil { return fmt.Errorf("failed to load olm account: %w", err) } else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil { @@ -253,17 +255,18 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event. Str("session_id", content.SessionID.String()). Logger() log.Debug().Msg("Decrypting received event") + ctx := log.WithContext(context.TODO()) - decrypted, err := helper.Decrypt(evt) + decrypted, err := helper.Decrypt(ctx, evt) if errors.Is(err, NoSessionFound) { log.Debug(). Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") - if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = helper.Decrypt(evt) + decrypted, err = helper.Decrypt(ctx, evt) } else { - go helper.waitLongerForSession(log, src, evt) + go helper.waitLongerForSession(ctx, log, src, evt) return } } @@ -306,20 +309,20 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) { +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, src mautrix.EventSource, evt *event.Event) { content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { + if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up") helper.DecryptErrorCallback(evt, NoSessionFound) return } log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") - decrypted, err := helper.Decrypt(evt) + decrypted, err := helper.Decrypt(ctx, evt) if err != nil { log.Error().Err(err).Msg("Failed to decrypt event") helper.DecryptErrorCallback(evt, err) @@ -329,32 +332,31 @@ func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix helper.postDecrypt(src, decrypted) } -func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { +func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { if helper == nil { return false } helper.lock.RLock() defer helper.lock.RUnlock() - return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) + return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) } -func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) { +func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { if helper == nil { return nil, fmt.Errorf("crypto helper is nil") } - return helper.mach.DecryptMegolmEvent(context.TODO(), evt) + return helper.mach.DecryptMegolmEvent(ctx, evt) } -func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { +func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { if helper == nil { return nil, fmt.Errorf("crypto helper is nil") } helper.lock.RLock() defer helper.lock.RUnlock() - ctx := context.TODO() encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession { + if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { return } helper.log.Debug(). @@ -362,7 +364,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten Str("room_id", roomID.String()). Msg("Got session error while encrypting event, sharing group session and trying again") var users []id.UserID - users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID) + users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(ctx, roomID) if err != nil { err = fmt.Errorf("failed to get room member list: %w", err) } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index eaff136..540f99c 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -91,7 +91,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } else { forwardedKeys = true lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1] - device, _ = mach.CryptoStore.FindDeviceByKey(evt.Sender, id.IdentityKey(lastChainItem)) + device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem)) if device != nil { trustLevel = mach.ResolveTrust(device) } else { @@ -188,7 +188,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve mach.megolmDecryptLock.Lock() defer mach.megolmDecryptLock.Unlock() - sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { @@ -250,7 +250,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve Int("max_messages", sess.MaxMessages). Logger() if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt { - err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached") + err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") return sess, plaintext, messageIndex, RatchetError @@ -261,14 +261,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") return sess, plaintext, messageIndex, RatchetError - } else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + } else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") return sess, plaintext, messageIndex, RatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { - if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") return sess, plaintext, messageIndex, RatchetError } else { diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 57b39f0..f99c7db 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -159,7 +159,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U } endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second) - err = mach.CryptoStore.UpdateSession(senderKey, session) + err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting") @@ -170,7 +170,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second) - sessions, err := mach.CryptoStore.GetSessions(senderKey) + sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey) endTimeTrace() if err != nil { return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err) @@ -199,7 +199,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) - err = mach.CryptoStore.UpdateSession(senderKey, session) + err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting") @@ -217,7 +217,7 @@ func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.S return nil, err } mach.saveAccount() - err = mach.CryptoStore.AddSession(senderKey, session) + err = mach.CryptoStore.AddSession(ctx, senderKey, session) if err != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session") } diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 8514275..e554480 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -53,7 +53,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id Str("signed_device_id", deviceID.String()). Str("signature", signature). Msg("Verified self-signing signature") - err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature) + err = mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, pubKey, signature) if err != nil { log.Warn().Err(err). Str("signer_user_id", signerUserID.String()). @@ -74,7 +74,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id } // save signature of device made by its own device signing key if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok { - err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature) + err := mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature) if err != nil { log.Warn().Err(err). Str("signer_user_id", signerUserID.String()). @@ -96,7 +96,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT log := mach.machOrContextLog(ctx) if !includeUntracked { var err error - users, err = mach.CryptoStore.FilterTrackedUsers(users) + users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users) if err != nil { log.Warn().Err(err).Msg("Failed to filter tracked user list") } @@ -123,7 +123,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*id.Device) - existingDevices, err := mach.CryptoStore.GetDevices(userID) + existingDevices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Warn().Err(err).Msg("Failed to get existing devices for user") existingDevices = make(map[id.DeviceID]*id.Device) @@ -151,7 +151,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT } } log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list") - err = mach.CryptoStore.PutDevices(userID, newDevices) + err = mach.CryptoStore.PutDevices(ctx, userID, newDevices) if err != nil { log.Warn().Err(err).Msg("Failed to update device list") } @@ -169,7 +169,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT Str("identity_key", device.IdentityKey.String()). Str("signing_key", device.SigningKey.String()). Logger() - sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed") + sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, "", device.IdentityKey, "device removed") if err != nil { log.Err(err).Msg("Failed to redact megolm sessions from deleted device") } else { @@ -179,7 +179,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT } } } - mach.OnDevicesChanged(userID) + mach.OnDevicesChanged(ctx, userID) } } for userID := range req.DeviceKeys { @@ -197,18 +197,25 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT // // This is called automatically whenever a device list change is noticed in ProcessSyncResponse and usually does // not need to be called manually. -func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) { +func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) { if mach.DisableDeviceChangeKeyRotation { return } - for _, roomID := range mach.StateStore.FindSharedRooms(userID) { - mach.Log.Debug(). + rooms, err := mach.StateStore.FindSharedRooms(ctx, userID) + if err != nil { + mach.machOrContextLog(ctx).Err(err). + Stringer("with_user_id", userID). + Msg("Failed to find shared rooms to invalidate group sessions") + return + } + for _, roomID := range rooms { + mach.machOrContextLog(ctx).Debug(). Str("user_id", userID.String()). Str("room_id", roomID.String()). Msg("Invalidating group session in room due to device change notification") - err := mach.CryptoStore.RemoveOutboundGroupSession(roomID) + err = mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) if err != nil { - mach.Log.Warn().Err(err). + mach.machOrContextLog(ctx).Err(err). Str("user_id", userID.String()). Str("room_id", roomID.String()). Msg("Failed to invalidate outbound group session") diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 078ef51..1eee2fe 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -84,7 +84,7 @@ func parseMessageIndex(ciphertext []byte) (uint, error) { func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) { mach.megolmEncryptLock.Lock() defer mach.megolmEncryptLock.Unlock() - session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) + session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID) if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { @@ -116,7 +116,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID log = log.With().Uint("message_index", idx).Logger() } log.Debug().Msg("Encrypted event successfully") - err = mach.CryptoStore.UpdateOutboundGroupSession(session) + err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session) if err != nil { log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting") } @@ -137,7 +137,13 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID } func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession { - session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID)) + encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) + if err != nil { + mach.machOrContextLog(ctx).Err(err). + Stringer("room_id", roomID). + Msg("Failed to get encryption event in room") + } + session := NewOutboundGroupSession(roomID, encryptionEvent) if !mach.DontStoreOutboundKeys { signingKey, idKey := mach.account.Keys() mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) @@ -165,7 +171,7 @@ func strishArray[T ~string](arr []T) []string { func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error { mach.megolmEncryptLock.Lock() defer mach.megolmEncryptLock.Unlock() - session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) + session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID) if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { @@ -192,7 +198,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, for _, userID := range users { log := log.With().Str("target_user_id", userID.String()).Logger() - devices, err := mach.CryptoStore.GetDevices(userID) + devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Error().Err(err).Msg("Failed to get devices of user") } else if devices == nil { @@ -292,7 +298,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, log.Debug().Msg("Group session successfully shared") session.Shared = true - return mach.CryptoStore.AddOutboundGroupSession(session) + return mach.CryptoStore.AddOutboundGroupSession(ctx, session) } func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error { @@ -367,7 +373,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out Reason: "This device does not encrypt messages for unverified devices", }} session.Users[userKey] = OGSIgnored - } else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil { + } else if deviceSession, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey); err != nil { log.Error().Err(err).Msg("Failed to get olm session to encrypt group session") } else if deviceSession == nil { log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session") diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index f21ecd0..3b1d40d 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -38,7 +38,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession Str("olm_session_description", session.Describe()). Msg("Encrypting olm message") msgType, ciphertext := session.Encrypt(plaintext) - err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session) + err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") } @@ -54,8 +54,8 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession } } -func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool { - if !mach.CryptoStore.HasSession(identityKey) { +func (mach *OlmMachine) shouldCreateNewSession(ctx context.Context, identityKey id.IdentityKey) bool { + if !mach.CryptoStore.HasSession(ctx, identityKey) { return true } mach.devicesToUnwedgeLock.Lock() @@ -72,7 +72,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id for userID, devices := range input { request[userID] = make(map[id.DeviceID]id.KeyAlgorithm) for deviceID, identity := range devices { - if mach.shouldCreateNewSession(identity.IdentityKey) { + if mach.shouldCreateNewSession(ctx, identity.IdentityKey) { request[userID][deviceID] = id.KeyAlgorithmSignedCurve25519 } } @@ -117,7 +117,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key") } else { wrapped := wrapSession(sess) - err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped) + err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped) if err != nil { log.Error().Err(err).Msg("Failed to store created outbound session") } else { diff --git a/crypto/keyimport.go b/crypto/keyimport.go index ed66f23..2d9f348 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package crypto import ( "bytes" + "context" "crypto/aes" "crypto/cipher" "crypto/hmac" @@ -91,7 +92,7 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession, return sessionsJSON, nil } -func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, error) { +func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) { if session.Algorithm != id.AlgorithmMegolmV1 { return false, ErrInvalidExportedAlgorithm } @@ -112,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er ReceivedAt: time.Now().UTC(), } - existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID()) + existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { // We already have an equivalent or better session in the store, so don't override it. return false, nil } - err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs) if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } @@ -127,7 +128,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er // ImportKeys imports data that was exported with the format specified in the Matrix spec. // See https://spec.matrix.org/v1.2/client-server-api/#key-exports -func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, error) { +func (mach *OlmMachine) ImportKeys(ctx context.Context, passphrase string, data []byte) (int, int, error) { exportData, err := decodeKeyExport(data) if err != nil { return 0, 0, err @@ -143,8 +144,11 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er Str("room_id", session.RoomID.String()). Str("session_id", session.SessionID.String()). Logger() - imported, err := mach.importExportedRoomKey(session) + imported, err := mach.importExportedRoomKey(ctx, session) if err != nil { + if ctx.Err() != nil { + return count, len(sessions), ctx.Err() + } log.Error().Err(err).Msg("Failed to import Megolm session from file") } else if imported { log.Debug().Msg("Imported Megolm session from file") diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 9b8eef7..8cf15d3 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -152,7 +152,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt Msg("Mismatched session ID while creating inbound group session from forward") return false } - config := mach.StateStore.GetEncryptionEvent(content.RoomID) + config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID) + if err != nil { + log.Error().Err(err).Msg("Failed to get encryption event for room") + } var maxAge time.Duration var maxMessages int if config != nil { @@ -178,7 +181,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxMessages: maxMessages, IsScheduled: content.IsScheduled, } - err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs) if err != nil { log.Error().Err(err).Msg("Failed to store new inbound group session") return false @@ -274,7 +277,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User return } - igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID) + igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") @@ -331,7 +334,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us Int("first_message_index", content.FirstMessageIndex). Logger() - sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Acked group session was already redacted") @@ -351,7 +354,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 { log.Debug().Msg("Redacting inbound copy of outbound group session after ack") - err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked") + err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked") if err != nil { log.Err(err).Msg("Failed to redact group session") } diff --git a/crypto/machine.go b/crypto/machine.go index 35f8c12..da78eaf 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -80,11 +80,11 @@ type OlmMachine struct { // StateStore is used by OlmMachine to get room state information that's needed for encryption. type StateStore interface { // IsEncrypted returns whether a room is encrypted. - IsEncrypted(id.RoomID) bool + IsEncrypted(context.Context, id.RoomID) (bool, error) // GetEncryptionEvent returns the encryption event's content for an encrypted room. - GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent + GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error) // FindSharedRooms returns the encrypted rooms that another user is also in for a user ID. - FindSharedRooms(id.UserID) []id.RoomID + FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error) } // NewOlmMachine creates an OlmMachine with the given client, logger and stores. @@ -131,8 +131,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger { // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. -func (mach *OlmMachine) Load() (err error) { - mach.account, err = mach.CryptoStore.GetAccount() +func (mach *OlmMachine) Load(ctx context.Context) (err error) { + mach.account, err = mach.CryptoStore.GetAccount(ctx) if err != nil { return } @@ -143,15 +143,15 @@ func (mach *OlmMachine) Load() (err error) { } func (mach *OlmMachine) saveAccount() { - err := mach.CryptoStore.PutAccount(mach.account) + err := mach.CryptoStore.PutAccount(context.TODO(), mach.account) if err != nil { mach.Log.Error().Err(err).Msg("Failed to save account") } } // FlushStore calls the Flush method of the CryptoStore. -func (mach *OlmMachine) FlushStore() error { - return mach.CryptoStore.Flush() +func (mach *OlmMachine) FlushStore(ctx context.Context) error { + return mach.CryptoStore.Flush(ctx) } func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { @@ -284,7 +284,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string // // client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent) func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) { - if !mach.StateStore.IsEncrypted(evt.RoomID) { + ctx := context.TODO() + if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil { + mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID). + Msg("Failed to check if room is encrypted to handle member event") + return + } else if !isEncrypted { return } content := evt.Content.AsMember() @@ -311,7 +316,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even Str("prev_membership", string(prevContent.Membership)). Str("new_membership", string(content.Membership)). Msg("Got membership state change, invalidating group session in room") - err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID) + err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) if err != nil { mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session") } @@ -405,7 +410,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { // GetOrFetchDevice attempts to retrieve the device identity for the given device from the store // and if it's not found it asks the server for it. func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { - device, err := mach.CryptoStore.GetDevice(userID, deviceID) + device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID) if err != nil { return nil, fmt.Errorf("failed to get sender device from store: %w", err) } else if device != nil { @@ -425,7 +430,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, // store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with // the given identity key. func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { - deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey) + deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey) if err != nil || deviceIdentity != nil { return deviceIdentity, err } @@ -455,7 +460,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De mach.olmLock.Lock() defer mach.olmLock.Unlock() - olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey) + olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey) if err != nil { return err } @@ -499,7 +504,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen Msg("Mismatched session ID while creating inbound group session") return } - err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs) if err != nil { log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return @@ -525,7 +530,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) { } // WaitForSession waits for the given Megolm session to arrive. -func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { +func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { mach.keyWaitersLock.Lock() ch, ok := mach.keyWaiters[sessionID] if !ok { @@ -534,7 +539,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, } mach.keyWaitersLock.Unlock() // Handle race conditions where a session appears between the failed decryption and WaitForSession call. - sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) if sess != nil || errors.Is(err, ErrGroupSessionWithheld) { return true } @@ -542,10 +547,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, case <-ch: return true case <-time.After(timeout): - sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID) + sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) // Check if the session somehow appeared in the store without telling us // We accept withheld sessions as received, as then the decryption attempt will show the error. return sess != nil || errors.Is(err, ErrGroupSessionWithheld) + case <-ctx.Done(): + return false } } @@ -568,7 +575,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve return } - config := mach.StateStore.GetEncryptionEvent(content.RoomID) + config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID) + if err != nil { + log.Error().Err(err).Msg("Failed to get encryption event for room") + } var maxAge time.Duration var maxMessages int if config != nil { @@ -589,7 +599,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve } if mach.DeletePreviousKeysOnReceive && !content.IsScheduled { log.Debug().Msg("Redacting previous megolm sessions from sender in room") - sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device") + sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device") if err != nil { log.Err(err).Msg("Failed to redact previous megolm sessions") } else { @@ -606,7 +616,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event") return } - err := mach.CryptoStore.PutWithheldGroupSession(*content) + err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content) if err != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event") } @@ -662,7 +672,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) { log := mach.Log.With().Str("action", "redact expired sessions").Logger() for { - sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions() + sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx) if err != nil { log.Err(err).Msg("Failed to redact expired megolm sessions") } else if len(sessionIDs) > 0 { diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 0271104..f1d00eb 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -20,18 +20,18 @@ import ( type mockStateStore struct{} -func (mockStateStore) IsEncrypted(id.RoomID) bool { - return true +func (mockStateStore) IsEncrypted(context.Context, id.RoomID) (bool, error) { + return true, nil } -func (mockStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent { +func (mockStateStore) GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error) { return &event.EncryptionEventContent{ RotationPeriodMessages: 3, - } + }, nil } -func (mockStateStore) FindSharedRooms(id.UserID) []id.RoomID { - return []id.RoomID{"room1"} +func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error) { + return []id.RoomID{"room1"}, nil } func newMachine(t *testing.T, userID id.UserID) *OlmMachine { @@ -47,7 +47,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine { } machine := NewOlmMachine(client, nil, gobStore, mockStateStore{}) - if err := machine.Load(); err != nil { + if err := machine.Load(context.TODO()); err != nil { t.Fatalf("Error creating account: %v", err) } @@ -57,7 +57,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine { func TestRatchetMegolmSession(t *testing.T) { mach := newMachine(t, "user1") outSess := mach.newOutboundGroupSession(context.TODO(), "meow") - inSess, err := mach.CryptoStore.GetGroupSession("meow", mach.OwnIdentity().IdentityKey, outSess.ID()) + inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID()) require.NoError(t, err) assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex()) err = inSess.RatchetTo(10) @@ -85,7 +85,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { } // store sender device identity in receiving machine store - machineIn.CryptoStore.PutDevices("user1", map[id.DeviceID]*id.Device{ + machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{ "device1": { UserID: "user1", DeviceID: "device1", @@ -97,7 +97,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // create & store outbound megolm session for sending the event later megolmOutSession := machineOut.newOutboundGroupSession(context.TODO(), "room1") megolmOutSession.Shared = true - machineOut.CryptoStore.AddOutboundGroupSession(megolmOutSession) + machineOut.CryptoStore.AddOutboundGroupSession(context.TODO(), megolmOutSession) // encrypt m.room_key event with olm session deviceIdentity := &id.Device{ @@ -125,7 +125,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { if err != nil { t.Errorf("Error creating inbound megolm session: %v", err) } - if err = machineIn.CryptoStore.PutGroupSession("room1", senderKey, igs.ID(), igs); err != nil { + if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil { t.Errorf("Error storing inbound megolm session: %v", err) } } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 64d62bc..8c85f6d 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -27,7 +27,7 @@ import ( "maunium.net/go/mautrix/id" ) -var PostgresArrayWrapper func(interface{}) interface { +var PostgresArrayWrapper func(any) interface { driver.Valuer sql.Scanner } @@ -62,21 +62,21 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID } // Flush does nothing for this implementation as data is already persisted in the database. -func (store *SQLCryptoStore) Flush() error { +func (store *SQLCryptoStore) Flush(_ context.Context) error { return nil } // PutNextBatch stores the next sync batch token for the current account. func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error { store.SyncToken = nextBatch - _, err := store.DB.ExecContext(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) + _, err := store.DB.Exec(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) return err } // GetNextBatch retrieves the next sync batch token for the current account. func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) { if store.SyncToken == "" { - err := store.DB. + err := store.DB.Conn(ctx). QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). Scan(&store.SyncToken) if !errors.Is(err, sql.ErrNoRows) { @@ -111,20 +111,19 @@ func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) (st return nb, nil } -func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) { - err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - // TODO return error - store.DB.Log.Warn("Failed to scan device ID: %v", err) +func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.DeviceID, err error) { + err = store.DB.QueryRow(ctx, "SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID) + if errors.Is(err, sql.ErrNoRows) { + err = nil } return } // PutAccount stores an OlmAccount in the database. -func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error { +func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error { store.Account = account bytes := account.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(` + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account, account_id=excluded.account_id @@ -133,9 +132,9 @@ func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error { } // GetAccount retrieves an OlmAccount from the database. -func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) { +func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { - row := store.DB.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) + row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) acc := &OlmAccount{Internal: *olm.NewBlankAccount()} var accountBytes []byte err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes) @@ -154,7 +153,7 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) { } // HasSession returns whether there is an Olm session for the given sender key. -func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { +func (store *SQLCryptoStore) HasSession(ctx context.Context, key id.SenderKey) bool { store.olmSessionCacheLock.Lock() cache, ok := store.olmSessionCache[key] store.olmSessionCacheLock.Unlock() @@ -162,17 +161,17 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { return true } var sessionID id.SessionID - err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1", + err := store.DB.QueryRow(ctx, "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1", key, store.AccountID).Scan(&sessionID) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return false } return len(sessionID) > 0 } // GetSessions returns all the known Olm sessions for a sender key. -func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) { - rows, err := store.DB.Query("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC", +func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) (OlmSessionList, error) { + rows, err := store.DB.Query(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC", key, store.AccountID) if err != nil { return nil, err @@ -212,11 +211,11 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session } // GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID. -func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) { +func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() - row := store.DB.QueryRow("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", + row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", key, store.AccountID) sess := OlmSession{Internal: *olm.NewBlankSession()} @@ -224,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er var sessionID id.SessionID err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err @@ -242,20 +241,20 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er } // AddSession persists an Olm session for a sender in the database. -func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error { +func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID) store.getOlmSessionCache(key)[session.ID()] = session return err } // UpdateSession replaces the Olm session for a sender in the database. -func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error { +func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", + _, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID) return err } @@ -275,14 +274,14 @@ func datePtr(t time.Time) *time.Time { } // PutGroupSession stores an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { +func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) if err != nil { return fmt.Errorf("failed to marshal ratchet safety info: %w", err) } - _, err = store.DB.Exec(` + _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id @@ -301,19 +300,19 @@ func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.Send } // GetGroupSession retrieves an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { +func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString var sessionBytes, ratchetSafetyBytes []byte var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 var isScheduled bool - err := store.DB.QueryRow(` + err := store.DB.QueryRow(ctx, ` SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err @@ -327,22 +326,7 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send Reason: withheldReason.String, } } - igs := olm.NewBlankInboundGroupSession() - err = igs.Unpickle(sessionBytes, store.PickleKey) - if err != nil { - return nil, err - } - var chains []string - if forwardingChains.String != "" { - chains = strings.Split(forwardingChains.String, ",") - } - var rs RatchetSafety - if len(ratchetSafetyBytes) > 0 { - err = json.Unmarshal(ratchetSafetyBytes, &rs) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) - } - } + igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) if senderKey == "" { senderKey = id.Curve25519(senderKeyDB.String) } @@ -360,8 +344,8 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send }, nil } -func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error { - _, err := store.DB.Exec(` +func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error { + _, err := store.DB.Exec(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL @@ -369,27 +353,24 @@ func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, ses return err } -func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { +func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { if roomID == "" && senderKey == "" { return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions") } - res, err := store.DB.Query(` + res, err := store.DB.Query(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5 AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID) - var sessionIDs []id.SessionID - for res.Next() { - var sessionID id.SessionID - _ = res.Scan(&sessionID) - sessionIDs = append(sessionIDs, sessionID) + if err != nil { + return nil, err } - return sessionIDs, err + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } -func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) { +func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) { var query string switch store.DB.Dialect { case dbutil.Postgres: @@ -413,46 +394,40 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error default: return nil, fmt.Errorf("unsupported dialect") } - res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) - var sessionIDs []id.SessionID - for res.Next() { - var sessionID id.SessionID - _ = res.Scan(&sessionID) - sessionIDs = append(sessionIDs, sessionID) + res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) + if err != nil { + return nil, err } - return sessionIDs, err + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } -func (store *SQLCryptoStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) { - res, err := store.DB.Query(` +func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) { + res, err := store.DB.Query(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID) - var sessionIDs []id.SessionID - for res.Next() { - var sessionID id.SessionID - _ = res.Scan(&sessionID) - sessionIDs = append(sessionIDs, sessionID) + if err != nil { + return nil, err } - return sessionIDs, err + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } -func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error { - _, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", +func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID) return err } -func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { +func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { var code, reason sql.NullString - err := store.DB.QueryRow(` + err := store.DB.QueryRow(ctx, ` SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, ).Scan(&code, &reason) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil || !code.Valid { return nil, err @@ -467,82 +442,79 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey }, nil } -func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*InboundGroupSession, err error) { - for rows.Next() { - var roomID id.RoomID - var signingKey, senderKey, forwardingChains sql.NullString - var sessionBytes, ratchetSafetyBytes []byte - var receivedAt sql.NullTime - var maxAge, maxMessages sql.NullInt64 - var isScheduled bool - err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) +func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { + igs = olm.NewBlankInboundGroupSession() + err = igs.Unpickle(sessionBytes, store.PickleKey) + if err != nil { + return + } + if forwardingChains != "" { + chains = strings.Split(forwardingChains, ",") + } + var rs RatchetSafety + if len(ratchetSafetyBytes) > 0 { + err = json.Unmarshal(ratchetSafetyBytes, &rs) if err != nil { - return + err = fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) } - igs := olm.NewBlankInboundGroupSession() - err = igs.Unpickle(sessionBytes, store.PickleKey) - if err != nil { - return - } - var chains []string - if forwardingChains.String != "" { - chains = strings.Split(forwardingChains.String, ",") - } - var rs RatchetSafety - if len(ratchetSafetyBytes) > 0 { - err = json.Unmarshal(ratchetSafetyBytes, &rs) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) - } - } - result = append(result, &InboundGroupSession{ - Internal: *igs, - SigningKey: id.Ed25519(signingKey.String), - SenderKey: id.Curve25519(senderKey.String), - RoomID: roomID, - ForwardingChains: chains, - RatchetSafety: rs, - ReceivedAt: receivedAt.Time, - MaxAge: maxAge.Int64, - MaxMessages: int(maxMessages.Int64), - IsScheduled: isScheduled, - }) } return } -func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) { - rows, err := store.DB.Query(` - SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled +func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*InboundGroupSession, error) { + var roomID id.RoomID + var signingKey, senderKey, forwardingChains sql.NullString + var sessionBytes, ratchetSafetyBytes []byte + var receivedAt sql.NullTime + var maxAge, maxMessages sql.NullInt64 + var isScheduled bool + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) + if err != nil { + return nil, err + } + igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) + return &InboundGroupSession{ + Internal: *igs, + SigningKey: id.Ed25519(signingKey.String), + SenderKey: id.Curve25519(senderKey.String), + RoomID: roomID, + ForwardingChains: chains, + RatchetSafety: rs, + ReceivedAt: receivedAt.Time, + MaxAge: maxAge.Int64, + MaxMessages: int(maxMessages.Int64), + IsScheduled: isScheduled, + }, nil +} + +func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { + rows, err := store.DB.Query(ctx, ` + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) - if err == sql.ErrNoRows { - return []*InboundGroupSession{}, nil - } else if err != nil { + if err != nil { return nil, err } - return store.scanGroupSessionList(rows) + return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() } -func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { - rows, err := store.DB.Query(` - SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled +func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) { + rows, err := store.DB.Query(ctx, ` + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`, store.AccountID, ) - if err == sql.ErrNoRows { - return []*InboundGroupSession{}, nil - } else if err != nil { + if err != nil { return nil, err } - return store.scanGroupSessionList(rows) + return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() } // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. -func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSession) error { +func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(` + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -556,24 +528,24 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi } // UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID. -func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSession) error { +func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", + _, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID) return err } // GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID. -func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) { +func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { var ogs OutboundGroupSession var sessionBytes []byte var maxAgeMS int64 - err := store.DB.QueryRow(` + err := store.DB.QueryRow(ctx, ` SELECT session, shared, max_messages, message_count, max_age, created_at, last_used FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`, roomID, store.AccountID, ).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err @@ -590,8 +562,8 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun } // RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID. -func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error { - _, err := store.DB.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2", +func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roomID id.RoomID) error { + _, err := store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2", roomID, store.AccountID) return err } @@ -608,7 +580,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey ` var expectedEventID id.EventID var expectedTimestamp int64 - err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp) + err := store.DB.QueryRow(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp) if err != nil { return false, err } else if expectedEventID != eventID || expectedTimestamp != timestamp { @@ -623,69 +595,58 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey return true, nil } +func scanDevice(rows dbutil.Scannable) (*id.Device, error) { + var device id.Device + err := rows.Scan(&device.UserID, &device.DeviceID, &device.IdentityKey, &device.SigningKey, &device.Trust, &device.Deleted, &device.Name) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err + } + return &device, nil +} + // GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID. -func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) { +func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) { var ignore id.UserID - err := store.DB.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore) - if err == sql.ErrNoRows { + err := store.DB.QueryRow(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore) + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err } - rows, err := store.DB.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) + rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) if err != nil { return nil, err } data := make(map[id.DeviceID]*id.Device) - for rows.Next() { - var identity id.Device - err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) - if err != nil { - return nil, err - } - identity.UserID = userID - data[identity.DeviceID] = &identity + err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) { + data[device.DeviceID] = device + return true, nil + }) + if err != nil { + return nil, err } return data, nil } // GetDevice returns the device dentity for a given user and device ID. -func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { - var identity id.Device - err := store.DB.QueryRow(` - SELECT identity_key, signing_key, trust, deleted, name +func (store *SQLCryptoStore) GetDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { + return scanDevice(store.DB.QueryRow(ctx, ` + SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND device_id=$2`, userID, deviceID, - ).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - identity.UserID = userID - identity.DeviceID = deviceID - return &identity, nil + )) } // FindDeviceByKey finds a specific device by its sender key. -func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { - var identity id.Device - err := store.DB.QueryRow(` - SELECT device_id, signing_key, trust, deleted, name +func (store *SQLCryptoStore) FindDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { + return scanDevice(store.DB.QueryRow(ctx, ` + SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND identity_key=$2`, userID, identityKey, - ).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - identity.UserID = userID - identity.IdentityKey = identityKey - return &identity, nil + )) } const deviceInsertQuery = ` @@ -698,106 +659,84 @@ ON CONFLICT (user_id, device_id) DO UPDATE var deviceMassInsertTemplate = strings.ReplaceAll(deviceInsertQuery, "($1, $2, $3, $4, $5, $6, $7)", "%s") // PutDevice stores a single device for a user, replacing it if it exists already. -func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *id.Device) error { - _, err := store.DB.Exec(deviceInsertQuery, +func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, device *id.Device) error { + _, err := store.DB.Exec(ctx, deviceInsertQuery, userID, device.DeviceID, device.IdentityKey, device.SigningKey, device.Trust, device.Deleted, device.Name) return err } // PutDevices stores the device identity information for the given user ID. -func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error { - tx, err := store.DB.Begin() - if err != nil { - return err - } - - _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) - if err != nil { - return fmt.Errorf("failed to add user to tracked users list: %w", err) - } - - _, err = tx.Exec("UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID) - if err != nil { - _ = tx.Rollback() - return fmt.Errorf("failed to delete old devices: %w", err) - } - if len(devices) == 0 { - err = tx.Commit() +func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { + return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) if err != nil { - return fmt.Errorf("failed to commit changes (no devices added): %w", err) + return fmt.Errorf("failed to add user to tracked users list: %w", err) + } + + _, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID) + if err != nil { + return fmt.Errorf("failed to delete old devices: %w", err) + } + if len(devices) == 0 { + return nil + } + deviceBatchLen := 5 // how many devices will be inserted per query + deviceIDs := make([]id.DeviceID, 0, len(devices)) + for deviceID := range devices { + deviceIDs = append(deviceIDs, deviceID) + } + const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)" + for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen { + var batchDevices []id.DeviceID + if batchDeviceIdx+deviceBatchLen < len(deviceIDs) { + batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen] + } else { + batchDevices = deviceIDs[batchDeviceIdx:] + } + values := make([]interface{}, 1, len(devices)*6+1) + values[0] = userID + valueStrings := make([]string, 0, len(devices)) + i := 2 + for _, deviceID := range batchDevices { + identity := devices[deviceID] + values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name) + valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5)) + i += 6 + } + valueString := strings.Join(valueStrings, ",") + _, err = store.DB.Exec(ctx, fmt.Sprintf(deviceMassInsertTemplate, valueString), values...) + if err != nil { + return fmt.Errorf("failed to insert new devices: %w", err) + } } return nil - } - deviceBatchLen := 5 // how many devices will be inserted per query - deviceIDs := make([]id.DeviceID, 0, len(devices)) - for deviceID := range devices { - deviceIDs = append(deviceIDs, deviceID) - } - const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)" - for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen { - var batchDevices []id.DeviceID - if batchDeviceIdx+deviceBatchLen < len(deviceIDs) { - batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen] - } else { - batchDevices = deviceIDs[batchDeviceIdx:] - } - values := make([]interface{}, 1, len(devices)*6+1) - values[0] = userID - valueStrings := make([]string, 0, len(devices)) - i := 2 - for _, deviceID := range batchDevices { - identity := devices[deviceID] - values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name) - valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5)) - i += 6 - } - valueString := strings.Join(valueStrings, ",") - _, err = tx.Exec(fmt.Sprintf(deviceMassInsertTemplate, valueString), values...) - if err != nil { - _ = tx.Rollback() - return fmt.Errorf("failed to insert new devices: %w", err) - } - } - err = tx.Commit() - if err != nil { - return fmt.Errorf("failed to commit changes: %w", err) - } - return nil + }) } // FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information. -func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) { +func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) { var rows dbutil.Rows var err error if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { - rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) + rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) } else { queryString := make([]string, len(users)) params := make([]interface{}, len(users)) for i, user := range users { - queryString[i] = fmt.Sprintf("$%d", i+1) + queryString[i] = fmt.Sprintf("?%d", i+1) params[i] = user } - rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) + rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) } if err != nil { return users, err } - var ptr int - for rows.Next() { - err = rows.Scan(&users[ptr]) - if err != nil { - return users, err - } else { - ptr++ - } - } - return users[:ptr], nil + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() } // PutCrossSigningKey stores a cross-signing key of some user along with its usage. -func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { - _, err := store.DB.Exec(` +func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) VALUES ($1, $2, $3, $4) ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key `, userID, usage, key, key) @@ -805,8 +744,8 @@ func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.Cross } // GetCrossSigningKeys retrieves a user's stored cross-signing keys. -func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { - rows, err := store.DB.Query("SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID) +func (store *SQLCryptoStore) GetCrossSigningKeys(ctx context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { + rows, err := store.DB.Query(ctx, "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID) if err != nil { return nil, err } @@ -825,8 +764,8 @@ func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.Cross } // PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key. -func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { - _, err := store.DB.Exec(` +func (store *SQLCryptoStore) PutSignature(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_cross_signing_signatures (signed_user_id, signed_key, signer_user_id, signer_key, signature) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature `, signedUserID, signedKey, signerUserID, signerKey, signature) @@ -834,8 +773,8 @@ func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.E } // GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer. -func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { - rows, err := store.DB.Query("SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID) +func (store *SQLCryptoStore) GetSignaturesForKeyBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { + rows, err := store.DB.Query(ctx, "SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID) if err != nil { return nil, err } @@ -854,18 +793,18 @@ func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25 } // IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer. -func (store *SQLCryptoStore) IsKeySignedBy(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) { +func (store *SQLCryptoStore) IsKeySignedBy(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) { q := `SELECT EXISTS( SELECT 1 FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4 )` - err = store.DB.QueryRow(q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned) + err = store.DB.QueryRow(ctx, q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned) return } // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. -func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) { - res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key) +func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id.UserID, key id.Ed25519) (int64, error) { + res, err := store.DB.Exec(ctx, "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key) if err != nil { return 0, err } diff --git a/crypto/sql_store_upgrade/upgrade.go b/crypto/sql_store_upgrade/upgrade.go index c9541b9..08c995d 100644 --- a/crypto/sql_store_upgrade/upgrade.go +++ b/crypto/sql_store_upgrade/upgrade.go @@ -7,6 +7,7 @@ package sql_store_upgrade import ( + "context" "embed" "fmt" @@ -21,7 +22,7 @@ const VersionTableName = "crypto_version" var fs embed.FS func init() { - Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error { + Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error { return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+") }) Table.RegisterFS(fs) diff --git a/crypto/store.go b/crypto/store.go index 99e464d..09393a5 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -26,64 +26,64 @@ var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{} type Store interface { // Flush ensures that everything in the store is persisted to disk. // This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately. - Flush() error + Flush(context.Context) error // PutAccount updates the OlmAccount in the store. - PutAccount(*OlmAccount) error + PutAccount(context.Context, *OlmAccount) error // GetAccount returns the OlmAccount in the store that was previously inserted with PutAccount. - GetAccount() (*OlmAccount, error) + GetAccount(ctx context.Context) (*OlmAccount, error) // AddSession inserts an Olm session into the store. - AddSession(id.SenderKey, *OlmSession) error + AddSession(context.Context, id.SenderKey, *OlmSession) error // HasSession returns whether or not the store has an Olm session with the given sender key. - HasSession(id.SenderKey) bool + HasSession(context.Context, id.SenderKey) bool // GetSessions returns all Olm sessions in the store with the given sender key. - GetSessions(id.SenderKey) (OlmSessionList, error) + GetSessions(context.Context, id.SenderKey) (OlmSessionList, error) // GetLatestSession returns the session with the highest session ID (lexiographically sorting). // It's usually safe to return the most recently added session if sorting by session ID is too difficult. - GetLatestSession(id.SenderKey) (*OlmSession, error) + GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error) // UpdateSession updates a session that has previously been inserted with AddSession. - UpdateSession(id.SenderKey, *OlmSession) error + UpdateSession(context.Context, id.SenderKey, *OlmSession) error // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace // sessions inserted with this call. - PutGroupSession(id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error + PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error // GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld // (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the // ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details. - GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) + GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) // RedactGroupSession removes the session data for the given inbound Megolm session from the store. - RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error + RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error // RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room. - RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error) + RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error) // RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired. - RedactExpiredGroupSessions() ([]id.SessionID, error) + RedactExpiredGroupSessions(context.Context) ([]id.SessionID, error) // RedactOutdatedGroupSessions removes the session data for all inbound Megolm sessions that are lacking the expiration metadata. - RedactOutdatedGroupSessions() ([]id.SessionID, error) + RedactOutdatedGroupSessions(context.Context) ([]id.SessionID, error) // PutWithheldGroupSession tells the store that a specific Megolm session was withheld. - PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error + PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error // GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession. - GetWithheldGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error) + GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error) // GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key // export files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetGroupSessionsForRoom(id.RoomID) ([]*InboundGroupSession, error) + GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error) // GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export // files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetAllGroupSessions() ([]*InboundGroupSession, error) + GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error) // AddOutboundGroupSession inserts the given outbound Megolm session into the store. // // The store should index inserted sessions by the RoomID field to support getting and removing sessions. // There will only be one outbound session per room ID at a time. - AddOutboundGroupSession(*OutboundGroupSession) error + AddOutboundGroupSession(context.Context, *OutboundGroupSession) error // UpdateOutboundGroupSession updates the given outbound Megolm session in the store. - UpdateOutboundGroupSession(*OutboundGroupSession) error + UpdateOutboundGroupSession(context.Context, *OutboundGroupSession) error // GetOutboundGroupSession gets the stored outbound Megolm session for the given room ID from the store. - GetOutboundGroupSession(id.RoomID) (*OutboundGroupSession, error) + GetOutboundGroupSession(context.Context, id.RoomID) (*OutboundGroupSession, error) // RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID. - RemoveOutboundGroupSession(id.RoomID) error + RemoveOutboundGroupSession(context.Context, id.RoomID) error // ValidateMessageIndex validates that the given message details aren't from a replay attack. // @@ -96,29 +96,29 @@ type Store interface { ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) // GetDevices returns a map from device ID to id.Device struct containing all devices of a given user. - GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error) + GetDevices(context.Context, id.UserID) (map[id.DeviceID]*id.Device, error) // GetDevice returns a specific device of a given user. - GetDevice(id.UserID, id.DeviceID) (*id.Device, error) + GetDevice(context.Context, id.UserID, id.DeviceID) (*id.Device, error) // PutDevice stores a single device for a user, replacing it if it exists already. - PutDevice(id.UserID, *id.Device) error + PutDevice(context.Context, id.UserID, *id.Device) error // PutDevices overrides the stored device list for the given user with the given list. - PutDevices(id.UserID, map[id.DeviceID]*id.Device) error + PutDevices(context.Context, id.UserID, map[id.DeviceID]*id.Device) error // FindDeviceByKey finds a specific device by its identity key. - FindDeviceByKey(id.UserID, id.IdentityKey) (*id.Device, error) + FindDeviceByKey(context.Context, id.UserID, id.IdentityKey) (*id.Device, error) // FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists // have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty. - FilterTrackedUsers([]id.UserID) ([]id.UserID, error) + FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error) // PutCrossSigningKey stores a cross-signing key of some user along with its usage. - PutCrossSigningKey(id.UserID, id.CrossSigningUsage, id.Ed25519) error + PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error // GetCrossSigningKeys retrieves a user's stored cross-signing keys. - GetCrossSigningKeys(id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) + GetCrossSigningKeys(context.Context, id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) // PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key. - PutSignature(signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error + PutSignature(ctx context.Context, signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error // IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer. - IsKeySignedBy(userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) + IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. - DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error) + DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error) } type messageIndexKey struct { @@ -170,18 +170,18 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { } } -func (gs *MemoryStore) Flush() error { +func (gs *MemoryStore) Flush(_ context.Context) error { gs.lock.Lock() err := gs.save() gs.lock.Unlock() return err } -func (gs *MemoryStore) GetAccount() (*OlmAccount, error) { +func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) { return gs.Account, nil } -func (gs *MemoryStore) PutAccount(account *OlmAccount) error { +func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error { gs.lock.Lock() gs.Account = account err := gs.save() @@ -189,7 +189,7 @@ func (gs *MemoryStore) PutAccount(account *OlmAccount) error { return err } -func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) { +func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) { gs.lock.Lock() sessions, ok := gs.Sessions[senderKey] if !ok { @@ -200,7 +200,7 @@ func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, erro return sessions, nil } -func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error { +func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error { gs.lock.Lock() sessions, _ := gs.Sessions[senderKey] gs.Sessions[senderKey] = append(sessions, session) @@ -210,19 +210,19 @@ func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) e return err } -func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error { +func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error { // we don't need to do anything here because the session is a pointer and already stored in our map return gs.save() } -func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool { +func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool { gs.lock.RLock() sessions, ok := gs.Sessions[senderKey] gs.lock.RUnlock() return ok && len(sessions) > 0 && !sessions[0].Expired() } -func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) { +func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() sessions, ok := gs.Sessions[senderKey] gs.lock.RUnlock() @@ -246,7 +246,7 @@ func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey return sender } -func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error { +func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error { gs.lock.Lock() gs.getGroupSessions(roomID, senderKey)[sessionID] = igs err := gs.save() @@ -254,7 +254,7 @@ func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, return err } -func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { gs.lock.Lock() session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID] if !ok { @@ -269,7 +269,7 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, return session, nil } -func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error { +func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error { gs.lock.Lock() delete(gs.getGroupSessions(roomID, senderKey), sessionID) err := gs.save() @@ -277,7 +277,7 @@ func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderK return err } -func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { +func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { gs.lock.Lock() var sessionIDs []id.SessionID if roomID != "" && senderKey != "" { @@ -315,11 +315,11 @@ func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.Sender return sessionIDs, err } -func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) { +func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) { return nil, fmt.Errorf("not implemented") } -func (gs *MemoryStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) { +func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.SessionID, error) { return nil, fmt.Errorf("not implemented") } @@ -337,7 +337,7 @@ func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.S return sender } -func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error { +func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error { gs.lock.Lock() gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content err := gs.save() @@ -345,7 +345,7 @@ func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEven return err } -func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { +func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { gs.lock.Lock() session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] gs.lock.Unlock() @@ -355,7 +355,7 @@ func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Se return session, nil } -func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { gs.lock.Lock() defer gs.lock.Unlock() room, ok := gs.GroupSessions[roomID] @@ -371,7 +371,7 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGrou return result, nil } -func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) { gs.lock.Lock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { @@ -385,7 +385,7 @@ func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { return result, nil } -func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error { +func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error { gs.lock.Lock() gs.OutGroupSessions[session.RoomID] = session err := gs.save() @@ -393,12 +393,12 @@ func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) er return err } -func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error { +func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error { // we don't need to do anything here because the session is a pointer and already stored in our map return gs.save() } -func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) { +func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { gs.lock.RLock() session, ok := gs.OutGroupSessions[roomID] gs.lock.RUnlock() @@ -408,7 +408,7 @@ func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroup return session, nil } -func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error { +func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error { gs.lock.Lock() session, ok := gs.OutGroupSessions[roomID] if !ok || session == nil { @@ -443,7 +443,7 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send return true, nil } -func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) { +func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) { gs.lock.RLock() devices, ok := gs.Devices[userID] if !ok { @@ -453,7 +453,7 @@ func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, return devices, nil } -func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { +func (gs *MemoryStore) GetDevice(_ context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { gs.lock.RLock() defer gs.lock.RUnlock() devices, ok := gs.Devices[userID] @@ -467,7 +467,7 @@ func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.De return device, nil } -func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { +func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { gs.lock.RLock() defer gs.lock.RUnlock() devices, ok := gs.Devices[userID] @@ -482,7 +482,7 @@ func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.Identity return nil, nil } -func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error { +func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error { gs.lock.Lock() devices, ok := gs.Devices[userID] if !ok { @@ -495,7 +495,7 @@ func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error { return err } -func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error { +func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { gs.lock.Lock() gs.Devices[userID] = devices err := gs.save() @@ -503,7 +503,7 @@ func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id. return err } -func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) { +func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) { gs.lock.RLock() var ptr int for _, userID := range users { @@ -517,7 +517,7 @@ func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error return users[:ptr], nil } -func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { +func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { gs.lock.RLock() userKeys, ok := gs.CrossSigningKeys[userID] if !ok { @@ -539,7 +539,7 @@ func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSignin return err } -func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { +func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { gs.lock.RLock() defer gs.lock.RUnlock() keys, ok := gs.CrossSigningKeys[userID] @@ -549,7 +549,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSignin return keys, nil } -func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { +func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { gs.lock.RLock() signedUserSigs, ok := gs.KeySignatures[signedUserID] if !ok { @@ -572,7 +572,7 @@ func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519 return err } -func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { +func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { gs.lock.RLock() defer gs.lock.RUnlock() userKeys, ok := gs.KeySignatures[userID] @@ -590,8 +590,8 @@ func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, s return sigsBySigner, nil } -func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) { - sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID) +func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) { + sigs, err := gs.GetSignaturesForKeyBy(ctx, userID, key, signerID) if err != nil { return false, err } @@ -599,7 +599,7 @@ func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID return ok, nil } -func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) { +func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) { var count int64 gs.lock.RLock() for _, userSigs := range gs.KeySignatures { diff --git a/crypto/store_test.go b/crypto/store_test.go index ebeef39..665e3ef 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -36,7 +36,7 @@ func getCryptoStores(t *testing.T) map[string]Store { t.Fatalf("Error opening db: %v", err) } sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(); err != nil { + if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { t.Fatalf("Error creating tables: %v", err) } @@ -65,8 +65,8 @@ func TestPutAccount(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - store.PutAccount(acc) - retrieved, err := store.GetAccount() + store.PutAccount(context.TODO(), acc) + retrieved, err := store.GetAccount(context.TODO()) if err != nil { t.Fatalf("Error retrieving account: %v", err) } @@ -105,7 +105,7 @@ func TestStoreOlmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - if store.HasSession(olmSessID) { + if store.HasSession(context.TODO(), olmSessID) { t.Error("Found Olm session before inserting it") } olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) @@ -117,15 +117,15 @@ func TestStoreOlmSession(t *testing.T) { id: olmSessID, Internal: *olmInternal, } - err = store.AddSession(olmSessID, &olmSess) + err = store.AddSession(context.TODO(), olmSessID, &olmSess) if err != nil { t.Errorf("Error storing Olm session: %v", err) } - if !store.HasSession(olmSessID) { + if !store.HasSession(context.TODO(), olmSessID) { t.Error("Not found Olm session after inserting it") } - retrieved, err := store.GetLatestSession(olmSessID) + retrieved, err := store.GetLatestSession(context.TODO(), olmSessID) if err != nil { t.Errorf("Failed retrieving Olm session: %v", err) } @@ -158,12 +158,12 @@ func TestStoreMegolmSession(t *testing.T) { RoomID: "room1", } - err = store.PutGroupSession("room1", acc.IdentityKey(), igs.ID(), igs) + err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs) if err != nil { t.Errorf("Error storing inbound group session: %v", err) } - retrieved, err := store.GetGroupSession("room1", acc.IdentityKey(), igs.ID()) + retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID()) if err != nil { t.Errorf("Error retrieving inbound group session: %v", err) } @@ -179,7 +179,7 @@ func TestStoreOutboundMegolmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - sess, err := store.GetOutboundGroupSession("room1") + sess, err := store.GetOutboundGroupSession(context.TODO(), "room1") if sess != nil { t.Error("Got outbound session before inserting") } @@ -188,12 +188,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) { } outbound := NewOutboundGroupSession("room1", nil) - err = store.AddOutboundGroupSession(outbound) + err = store.AddOutboundGroupSession(context.TODO(), outbound) if err != nil { t.Errorf("Error inserting outbound session: %v", err) } - sess, err = store.GetOutboundGroupSession("room1") + sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") if sess == nil { t.Error("Did not get outbound session after inserting") } @@ -201,12 +201,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) { t.Errorf("Error retrieving outbound session: %v", err) } - err = store.RemoveOutboundGroupSession("room1") + err = store.RemoveOutboundGroupSession(context.TODO(), "room1") if err != nil { t.Errorf("Error deleting outbound session: %v", err) } - sess, err = store.GetOutboundGroupSession("room1") + sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") if sess != nil { t.Error("Got outbound session after deleting") } @@ -232,11 +232,11 @@ func TestStoreDevices(t *testing.T) { SigningKey: acc.SigningKey(), } } - err := store.PutDevices("user1", deviceMap) + err := store.PutDevices(context.TODO(), "user1", deviceMap) if err != nil { t.Errorf("Error string devices: %v", err) } - devs, err := store.GetDevices("user1") + devs, err := store.GetDevices(context.TODO(), "user1") if err != nil { t.Errorf("Error getting devices: %v", err) } @@ -250,7 +250,7 @@ func TestStoreDevices(t *testing.T) { t.Errorf("Last device identity key does not match") } - filtered, err := store.FilterTrackedUsers([]id.UserID{"user0", "user1", "user2"}) + filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"}) if err != nil { t.Errorf("Error filtering tracked users: %v", err) } else if len(filtered) != 1 || filtered[0] != "user1" { diff --git a/crypto/verification.go b/crypto/verification.go index be24687..31608bf 100644 --- a/crypto/verification.go +++ b/crypto/verification.go @@ -507,7 +507,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use // we can finally trust this device device.Trust = id.TrustStateVerified - err = mach.CryptoStore.PutDevice(device.UserID, device) + err = mach.CryptoStore.PutDevice(ctx, device.UserID, device) if err != nil { mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err) } @@ -521,7 +521,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID) } } else { - masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID) + masterKey, err := mach.fetchMasterKey(ctx, device, content, verState, transactionID) if err != nil { mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err) } else { diff --git a/go.mod b/go.mod index 07e3efd..8484acc 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 + go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.17.0 golang.org/x/exp v0.0.0-20231226003508-02704c960a9b diff --git a/go.sum b/go.sum index f52dc4c..d923c7b 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 h1:zcfIxHgzZpgGSJv/FUVbOjO4ZWa12En4TGhxgUI/QH0= -go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 531b71e..cd94215 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,6 +7,7 @@ package sqlstatestore import ( + "context" "database/sql" "embed" "encoding/json" @@ -15,6 +16,7 @@ import ( "strconv" "strings" + "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/event" @@ -44,26 +46,28 @@ func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge b } } -func (store *SQLStateStore) IsRegistered(userID id.UserID) bool { +func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) (bool, error) { var isRegistered bool err := store. - QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID). + QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID). Scan(&isRegistered) - if err != nil { - store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err) + if errors.Is(err, sql.ErrNoRows) { + err = nil } - return isRegistered + return isRegistered, err } -func (store *SQLStateStore) MarkRegistered(userID id.UserID) { - _, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) - if err != nil { - store.Log.Warn("Failed to mark %s as registered: %v", userID, err) - } +func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error { + _, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) + return err } -func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent { - members := make(map[id.UserID]*event.MemberEventContent) +type Member struct { + id.UserID + event.MemberEventContent +} + +func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) { args := make([]any, len(memberships)+1) args[0] = roomID query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1" @@ -75,25 +79,26 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...even } query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ",")) } - rows, err := store.Query(query, args...) + rows, err := store.Query(ctx, query, args...) if err != nil { - return members + return nil, err } - var userID id.UserID - var member event.MemberEventContent - for rows.Next() { - err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL) - if err != nil { - store.Log.Warn("Failed to scan member in %s: %v", roomID, err) - } else { - members[userID] = &member - } - } - return members + members := make(map[id.UserID]*event.MemberEventContent) + return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) { + err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL) + return + }).Iter(func(m Member) (bool, error) { + members[m.UserID] = &m.MemberEventContent + return true, nil + }) } -func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) { - memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite) +func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { + var memberMap map[id.UserID]*event.MemberEventContent + memberMap, err = store.GetRoomMembers(ctx, roomID, event.MembershipJoin, event.MembershipInvite) + if err != nil { + return + } members = make([]id.UserID, len(memberMap)) i := 0 for userID := range memberMap { @@ -103,37 +108,39 @@ func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (mem return } -func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership { - membership := event.MembershipLeave - err := store. - QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). +func (store *SQLStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (membership event.Membership, err error) { + err = store. + QueryRow(ctx, "SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). Scan(&membership) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err) + if errors.Is(err, sql.ErrNoRows) { + membership = event.MembershipLeave + err = nil } - return membership + return } -func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { - member, ok := store.TryGetMember(roomID, userID) - if !ok { - member.Membership = event.MembershipLeave +func (store *SQLStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + member, err := store.TryGetMember(ctx, roomID, userID) + if member == nil && err == nil { + member = &event.MemberEventContent{Membership: event.MembershipLeave} } - return member + return member, err } -func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) { +func (store *SQLStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { var member event.MemberEventContent err := store. - QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). + QueryRow(ctx, "SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). Scan(&member.Membership, &member.Displayname, &member.AvatarURL) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err } - return &member, err == nil + return &member, nil } -func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) { +func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) { query := ` SELECT room_id FROM mx_user_profile LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id @@ -141,38 +148,32 @@ func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID ` if !store.IsBridge { query = ` - SELECT mx_user_profile.room_id FROM mx_user_profile - LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id - WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL - ` + SELECT mx_user_profile.room_id FROM mx_user_profile + LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id + WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL + ` } - rows, err := store.Query(query, userID) + rows, err := store.Query(ctx, query, userID) if err != nil { - store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err) - return + return nil, err } - for rows.Next() { - var roomID id.RoomID - err = rows.Scan(&roomID) - if err != nil { - store.Log.Warn("Failed to scan room ID: %v", err) - } else { - rooms = append(rooms, roomID) - } + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() +} + +func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join") +} + +func (store *SQLStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join", "invite") +} + +func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + membership, err := store.GetMembership(ctx, roomID, userID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get membership") + return false } - return -} - -func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join") -} - -func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join", "invite") -} - -func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { - membership := store.GetMembership(roomID, userID) for _, allowedMembership := range allowedMemberships { if allowedMembership == membership { return true @@ -181,27 +182,23 @@ func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, all return false } -func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) { - _, err := store.Exec(` +func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '') ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership `, roomID, userID, membership) - if err != nil { - store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err) - } + return err } -func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) { - _, err := store.Exec(` +func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL) - if err != nil { - store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err) - } + return err } -func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) { +func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { query := "DELETE FROM mx_user_profile WHERE room_id=$1" params := make([]any, len(memberships)+1) params[0] = roomID @@ -213,109 +210,85 @@ func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ... } query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ",")) } - _, err := store.Exec(query, params...) - if err != nil { - store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err) - } + _, err := store.Exec(ctx, query, params...) + return err } -func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) { +func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { contentBytes, err := json.Marshal(content) if err != nil { - store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err) - return + return fmt.Errorf("failed to marshal content JSON: %w", err) } - _, err = store.Exec(` + _, err = store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption `, roomID, contentBytes) - if err != nil { - store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err) - } + return err } -func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent { +func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { var data []byte err := store. - QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). + QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). Scan(&data) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err) - } - return nil + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err } else if data == nil { - return nil + return nil, nil } - content := &event.EncryptionEventContent{} - err = json.Unmarshal(data, content) + var content event.EncryptionEventContent + err = json.Unmarshal(data, &content) if err != nil { - store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err) - return nil + return nil, fmt.Errorf("failed to parse content JSON: %w", err) } - return content + return &content, nil } -func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool { - cfg := store.GetEncryptionEvent(roomID) - return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1 +func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { + cfg, err := store.GetEncryptionEvent(ctx, roomID) + return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err } -func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) { - levelsBytes, err := json.Marshal(levels) - if err != nil { - store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err) - return - } - _, err = store.Exec(` +func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels - `, roomID, levelsBytes) - if err != nil { - store.Log.Warn("Failed to store power levels of %s: %v", roomID, err) - } + `, roomID, dbutil.JSON{Data: levels}) + return err } -func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) { - var data []byte - err := store. - QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). - Scan(&data) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err) - } - return - } else if data == nil { - return - } - levels = &event.PowerLevelsEventContent{} - err = json.Unmarshal(data, levels) - if err != nil { - store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err) - return nil +func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { + err = store. + QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). + Scan(&dbutil.JSON{Data: &levels}) + if errors.Is(err, sql.ErrNoRows) { + err = nil } return } -func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int { +func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { if store.Dialect == dbutil.Postgres { var powerLevel int err := store. - QueryRow(` + QueryRow(ctx, ` SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) FROM mx_room_state WHERE room_id=$1 `, roomID, userID). Scan(&powerLevel) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err) + return powerLevel, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } - return powerLevel + return levels.GetUserLevel(userID), nil } - return store.GetPowerLevels(roomID).GetUserLevel(userID) } -func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int { +func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { if store.Dialect == dbutil.Postgres { defaultType := "events_default" defaultValue := 0 @@ -325,23 +298,26 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType } var powerLevel int err := store. - QueryRow(` + QueryRow(ctx, ` SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4) FROM mx_room_state WHERE room_id=$1 `, roomID, eventType.Type, defaultType, defaultValue). Scan(&powerLevel) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err) - } - return defaultValue + if errors.Is(err, sql.ErrNoRows) { + err = nil + powerLevel = defaultValue } - return powerLevel + return powerLevel, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err + } + return levels.GetEventLevel(eventType), nil } - return store.GetPowerLevels(roomID).GetEventLevel(eventType) } -func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool { +func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { if store.Dialect == dbutil.Postgres { defaultType := "events_default" defaultValue := 0 @@ -351,19 +327,22 @@ func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, ev } var hasPower bool err := store. - QueryRow(`SELECT + QueryRow(ctx, `SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) >= COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5) FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue). Scan(&hasPower) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err) - } - return defaultValue == 0 + if errors.Is(err, sql.ErrNoRows) { + err = nil + hasPower = defaultValue == 0 } - return hasPower + return hasPower, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return false, err + } + return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil } - return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType) } diff --git a/sqlstatestore/v05-mark-encryption-state-resync.go b/sqlstatestore/v05-mark-encryption-state-resync.go index d66a9e9..bf44d30 100644 --- a/sqlstatestore/v05-mark-encryption-state-resync.go +++ b/sqlstatestore/v05-mark-encryption-state-resync.go @@ -1,19 +1,20 @@ package sqlstatestore import ( + "context" "fmt" "go.mau.fi/util/dbutil" ) func init() { - UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error { - portalExists, err := db.TableExists(tx, "portal") + UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error { + portalExists, err := db.TableExists(ctx, "portal") if err != nil { return fmt.Errorf("failed to check if portal table exists") } if portalExists { - _, err = tx.Exec(` + _, err = db.Exec(ctx, ` INSERT INTO mx_room_state (room_id, encryption) SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL ON CONFLICT (room_id) DO UPDATE diff --git a/statestore.go b/statestore.go index 2c0a8fd..63a5bfb 100644 --- a/statestore.go +++ b/statestore.go @@ -7,33 +7,37 @@ package mautrix import ( + "context" "sync" + "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // StateStore is an interface for storing basic room state information. type StateStore interface { - IsInRoom(roomID id.RoomID, userID id.UserID) bool - IsInvited(roomID id.RoomID, userID id.UserID) bool - IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool - GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent - TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) - SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) - SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) - ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) + IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool + IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool + IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool + GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) + TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) + SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error + SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error + ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error - SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) - GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent + SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error + GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) - SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) - IsEncrypted(roomID id.RoomID) bool + SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error + IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) - GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) + GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) } -func UpdateStateStore(store StateStore, evt *event.Event) { +func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { if store == nil || evt == nil || evt.StateKey == nil { return } @@ -41,13 +45,20 @@ func UpdateStateStore(store StateStore, evt *event.Event) { if evt.Type != event.StateMember && evt.GetStateKey() != "" { return } + var err error switch content := evt.Content.Parsed.(type) { case *event.MemberEventContent: - store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content) + err = store.SetMember(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), content) case *event.PowerLevelsEventContent: - store.SetPowerLevels(evt.RoomID, content) + err = store.SetPowerLevels(ctx, evt.RoomID, content) case *event.EncryptionEventContent: - store.SetEncryptionEvent(evt.RoomID, content) + err = store.SetEncryptionEvent(ctx, evt.RoomID, content) + } + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("event_id", evt.ID). + Str("event_type", evt.Type.Type). + Msg("Failed to update state store") } } @@ -57,7 +68,7 @@ func UpdateStateStore(store StateStore, evt *event.Event) { // // DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default). func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(cli.Log.WithContext(context.TODO()), cli.StateStore, evt) } type MemoryStateStore struct { @@ -81,20 +92,21 @@ func NewMemoryStateStore() StateStore { } } -func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool { +func (store *MemoryStateStore) IsRegistered(_ context.Context, userID id.UserID) (bool, error) { store.registrationsLock.RLock() defer store.registrationsLock.RUnlock() registered, ok := store.Registrations[userID] - return ok && registered + return ok && registered, nil } -func (store *MemoryStateStore) MarkRegistered(userID id.UserID) { +func (store *MemoryStateStore) MarkRegistered(_ context.Context, userID id.UserID) error { store.registrationsLock.Lock() defer store.registrationsLock.Unlock() store.Registrations[userID] = true + return nil } -func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent { +func (store *MemoryStateStore) GetRoomMembers(_ context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { store.membersLock.RLock() members, ok := store.Members[roomID] store.membersLock.RUnlock() @@ -104,11 +116,14 @@ func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*e store.Members[roomID] = members store.membersLock.Unlock() } - return members + return members, nil } -func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) { - members := store.GetRoomMembers(roomID) +func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + members, err := store.GetRoomMembers(ctx, roomID) + if err != nil { + return nil, err + } ids := make([]id.UserID, 0, len(members)) for id := range members { ids = append(ids, id) @@ -116,39 +131,39 @@ func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ( return ids, nil } -func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership { - return store.GetMember(roomID, userID).Membership +func (store *MemoryStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (event.Membership, error) { + return exerrors.Must(store.GetMember(ctx, roomID, userID)).Membership, nil } -func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { - member, ok := store.TryGetMember(roomID, userID) - if !ok { +func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + member, err := store.TryGetMember(ctx, roomID, userID) + if member == nil && err == nil { member = &event.MemberEventContent{Membership: event.MembershipLeave} } - return member + return member, err } -func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) { +func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) { store.membersLock.RLock() defer store.membersLock.RUnlock() members, membersOk := store.Members[roomID] if !membersOk { return } - member, ok = members[userID] + member = members[userID] return } -func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join") +func (store *MemoryStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join") } -func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join", "invite") +func (store *MemoryStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join", "invite") } -func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { - membership := store.GetMembership(roomID, userID) +func (store *MemoryStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + membership := exerrors.Must(store.GetMembership(ctx, roomID, userID)) for _, allowedMembership := range allowedMemberships { if allowedMembership == membership { return true @@ -157,7 +172,7 @@ func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, return false } -func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) { +func (store *MemoryStateStore) SetMembership(_ context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { store.membersLock.Lock() members, ok := store.Members[roomID] if !ok { @@ -175,9 +190,10 @@ func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, } store.Members[roomID] = members store.membersLock.Unlock() + return nil } -func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) { +func (store *MemoryStateStore) SetMember(_ context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { store.membersLock.Lock() members, ok := store.Members[roomID] if !ok { @@ -189,14 +205,15 @@ func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, mem } store.Members[roomID] = members store.membersLock.Unlock() + return nil } -func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) { +func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.RoomID, memberships ...event.Membership) error { store.membersLock.Lock() defer store.membersLock.Unlock() members, ok := store.Members[roomID] if !ok { - return + return nil } for userID, member := range members { for _, membership := range memberships { @@ -206,46 +223,49 @@ func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships } } } + return nil } -func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) { +func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { store.powerLevelsLock.Lock() store.PowerLevels[roomID] = levels store.powerLevelsLock.Unlock() + return nil } -func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) { +func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { store.powerLevelsLock.RLock() levels = store.PowerLevels[roomID] store.powerLevelsLock.RUnlock() return } -func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int { - return store.GetPowerLevels(roomID).GetUserLevel(userID) +func (store *MemoryStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { + return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetUserLevel(userID), nil } -func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int { - return store.GetPowerLevels(roomID).GetEventLevel(eventType) +func (store *MemoryStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { + return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetEventLevel(eventType), nil } -func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool { - return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType) +func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { + return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil } -func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) { +func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { store.encryptionLock.Lock() store.Encryption[roomID] = content store.encryptionLock.Unlock() + return nil } -func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent { +func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { store.encryptionLock.RLock() defer store.encryptionLock.RUnlock() - return store.Encryption[roomID] + return store.Encryption[roomID], nil } -func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool { - cfg := store.GetEncryptionEvent(roomID) - return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1 +func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { + cfg, err := store.GetEncryptionEvent(ctx, roomID) + return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err }